From d089df385e737f71067309ff7abae15994d581ec Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Thu, 8 Aug 2019 16:54:33 -0400 Subject: [PATCH] Black (#3142) * precommit hook * azure checks * readme * readme * azure checks * flake8 & isort settings * black as separate azure job * black --check * notes on managing merge conflicts * format with black * flake8 fixes * move comment back --- .github/PULL_REQUEST_TEMPLATE.md | 1 + .pre-commit-config.yaml | 12 + README.rst | 3 + asv_bench/benchmarks/__init__.py | 2 + asv_bench/benchmarks/combine.py | 18 +- asv_bench/benchmarks/dataarray_missing.py | 43 +- asv_bench/benchmarks/dataset_io.py | 375 +- asv_bench/benchmarks/indexing.py | 99 +- asv_bench/benchmarks/interp.py | 33 +- asv_bench/benchmarks/reindexing.py | 30 +- asv_bench/benchmarks/rolling.py | 53 +- asv_bench/benchmarks/unstacking.py | 8 +- azure-pipelines.yml | 12 +- conftest.py | 20 +- doc/conf.py | 221 +- doc/contributing.rst | 54 +- doc/examples/_code/accessor_example.py | 4 +- doc/examples/_code/weather_data_setup.py | 12 +- doc/gallery/plot_cartopy_facetgrid.py | 15 +- doc/gallery/plot_colorbar_center.py | 26 +- doc/gallery/plot_control_colorbar.py | 9 +- doc/gallery/plot_lines_from_2d.py | 10 +- doc/gallery/plot_rasterio.py | 27 +- doc/gallery/plot_rasterio_rgb.py | 8 +- properties/test_encode_decode.py | 18 +- setup.cfg | 16 +- setup.py | 74 +- versioneer.py | 258 +- xarray/__init__.py | 16 +- xarray/_version.py | 154 +- xarray/backends/__init__.py | 26 +- xarray/backends/api.py | 616 ++- xarray/backends/cfgrib_.py | 16 +- xarray/backends/common.py | 107 +- xarray/backends/file_manager.py | 56 +- xarray/backends/h5netcdf_.py | 143 +- xarray/backends/locks.py | 22 +- xarray/backends/lru_cache.py | 7 +- xarray/backends/netCDF4_.py | 308 +- xarray/backends/netcdf3.py | 54 +- xarray/backends/pseudonetcdf_.py | 24 +- xarray/backends/pydap_.py | 27 +- xarray/backends/pynio_.py | 25 +- xarray/backends/rasterio_.py | 125 +- xarray/backends/scipy_.py | 116 +- xarray/backends/zarr.py | 275 +- xarray/coding/cftime_offsets.py | 342 +- xarray/coding/cftimeindex.py | 180 +- xarray/coding/strings.py | 116 +- xarray/coding/times.py | 241 +- xarray/coding/variables.py | 142 +- xarray/conventions.py | 366 +- xarray/convert.py | 160 +- xarray/core/accessor_dt.py | 87 +- xarray/core/accessor_str.py | 275 +- xarray/core/alignment.py | 170 +- xarray/core/arithmetic.py | 87 +- xarray/core/combine.py | 357 +- xarray/core/common.py | 371 +- xarray/core/computation.py | 616 ++- xarray/core/concat.py | 186 +- xarray/core/coordinates.py | 94 +- xarray/core/dask_array_compat.py | 62 +- xarray/core/dask_array_ops.py | 27 +- xarray/core/dataarray.py | 838 ++-- xarray/core/dataset.py | 1526 +++--- xarray/core/dtypes.py | 16 +- xarray/core/duck_array_ops.py | 200 +- xarray/core/extensions.py | 10 +- xarray/core/formatting.py | 324 +- xarray/core/groupby.py | 229 +- xarray/core/indexes.py | 13 +- xarray/core/indexing.py | 344 +- xarray/core/merge.py | 187 +- xarray/core/missing.py | 294 +- xarray/core/nanops.py | 85 +- xarray/core/npcompat.py | 60 +- xarray/core/nputils.py | 64 +- xarray/core/ops.py | 147 +- xarray/core/options.py | 54 +- xarray/core/pdcompat.py | 3 +- xarray/core/pycompat.py | 4 +- xarray/core/resample.py | 60 +- xarray/core/resample_cftime.py | 84 +- xarray/core/rolling.py | 185 +- xarray/core/rolling_exp.py | 15 +- xarray/core/utils.py | 179 +- xarray/core/variable.py | 647 +-- xarray/plot/__init__.py | 18 +- xarray/plot/dataset_plot.py | 312 +- 
xarray/plot/facetgrid.py | 248 +- xarray/plot/plot.py | 433 +- xarray/plot/utils.py | 300 +- xarray/testing.py | 132 +- xarray/tests/__init__.py | 101 +- xarray/tests/test_accessor_dt.py | 245 +- xarray/tests/test_accessor_str.py | 442 +- xarray/tests/test_backends.py | 3082 ++++++------ xarray/tests/test_backends_api.py | 15 +- xarray/tests/test_backends_common.py | 10 +- xarray/tests/test_backends_file_manager.py | 103 +- xarray/tests/test_backends_locks.py | 6 +- xarray/tests/test_backends_lru_cache.py | 70 +- xarray/tests/test_cftime_offsets.py | 1180 +++-- xarray/tests/test_cftimeindex.py | 673 +-- xarray/tests/test_cftimeindex_resample.py | 166 +- xarray/tests/test_coding.py | 33 +- xarray/tests/test_coding_strings.py | 125 +- xarray/tests/test_coding_times.py | 815 ++-- xarray/tests/test_combine.py | 700 +-- xarray/tests/test_computation.py | 811 ++-- xarray/tests/test_concat.py | 431 +- xarray/tests/test_conventions.py | 273 +- xarray/tests/test_dask.py | 487 +- xarray/tests/test_dataarray.py | 3880 ++++++++------- xarray/tests/test_dataset.py | 5110 +++++++++++--------- xarray/tests/test_distributed.py | 162 +- xarray/tests/test_dtypes.py | 84 +- xarray/tests/test_duck_array_ops.py | 373 +- xarray/tests/test_extensions.py | 34 +- xarray/tests/test_formatting.py | 250 +- xarray/tests/test_groupby.py | 160 +- xarray/tests/test_indexing.py | 486 +- xarray/tests/test_interp.py | 554 ++- xarray/tests/test_merge.py | 211 +- xarray/tests/test_missing.py | 299 +- xarray/tests/test_nputils.py | 18 +- xarray/tests/test_options.py | 65 +- xarray/tests/test_plot.py | 1415 +++--- xarray/tests/test_print_versions.py | 2 +- xarray/tests/test_sparse.py | 926 ++-- xarray/tests/test_tutorial.py | 13 +- xarray/tests/test_ufuncs.py | 115 +- xarray/tests/test_utils.py | 153 +- xarray/tests/test_variable.py | 1566 +++--- xarray/tutorial.py | 61 +- xarray/ufuncs.py | 35 +- xarray/util/print_versions.py | 52 +- 138 files changed, 22296 insertions(+), 17964 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 9d4024a69f6..34b9997ba07 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -2,4 +2,5 @@ - [ ] Closes #xxxx - [ ] Tests added + - [ ] Passes `black .` & `flake8` - [ ] Fully documented, including `whats-new.rst` for all changes and `api.rst` for new API diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000000..bb627986e0f --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,12 @@ +# https://pre-commit.com/ +# https://github.com/python/black#version-control-integration +repos: + - repo: https://github.com/python/black + rev: stable + hooks: + - id: black + language_version: python3.7 + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.2.3 + hooks: + - id: flake8 diff --git a/README.rst b/README.rst index f2f754f37c9..53f51392a1a 100644 --- a/README.rst +++ b/README.rst @@ -11,6 +11,9 @@ xarray: N-D labeled arrays and datasets :target: https://pandas.pydata.org/speed/xarray/ .. image:: https://img.shields.io/pypi/v/xarray.svg :target: https://pypi.python.org/pypi/xarray/ +.. 
image:: https://img.shields.io/badge/code%20style-black-000000.svg + :target: https://github.com/python/black + **xarray** (formerly **xray**) is an open source project and Python package that makes working with labelled multi-dimensional arrays simple, diff --git a/asv_bench/benchmarks/__init__.py b/asv_bench/benchmarks/__init__.py index d0eb6282fce..2ee5350ba90 100644 --- a/asv_bench/benchmarks/__init__.py +++ b/asv_bench/benchmarks/__init__.py @@ -12,6 +12,7 @@ def decorator(func): func.param_names = names func.params = params return func + return decorator @@ -28,6 +29,7 @@ def randn(shape, frac_nan=None, chunks=None, seed=0): x = rng.standard_normal(shape) else: import dask.array as da + rng = da.random.RandomState(seed) x = rng.standard_normal(shape, chunks=chunks) diff --git a/asv_bench/benchmarks/combine.py b/asv_bench/benchmarks/combine.py index 8670760abc1..9314361e998 100644 --- a/asv_bench/benchmarks/combine.py +++ b/asv_bench/benchmarks/combine.py @@ -13,22 +13,22 @@ def setup(self): data = np.random.randn(t_size, x_size, y_size) self.dsA0 = xr.Dataset( - {'A': xr.DataArray(data, coords={'T': t}, - dims=('T', 'X', 'Y'))}) + {"A": xr.DataArray(data, coords={"T": t}, dims=("T", "X", "Y"))} + ) self.dsA1 = xr.Dataset( - {'A': xr.DataArray(data, coords={'T': t + t_size}, - dims=('T', 'X', 'Y'))}) + {"A": xr.DataArray(data, coords={"T": t + t_size}, dims=("T", "X", "Y"))} + ) self.dsB0 = xr.Dataset( - {'B': xr.DataArray(data, coords={'T': t}, - dims=('T', 'X', 'Y'))}) + {"B": xr.DataArray(data, coords={"T": t}, dims=("T", "X", "Y"))} + ) self.dsB1 = xr.Dataset( - {'B': xr.DataArray(data, coords={'T': t + t_size}, - dims=('T', 'X', 'Y'))}) + {"B": xr.DataArray(data, coords={"T": t + t_size}, dims=("T", "X", "Y"))} + ) def time_combine_manual(self): datasets = [[self.dsA0, self.dsA1], [self.dsB0, self.dsB1]] - xr.combine_manual(datasets, concat_dim=[None, 't']) + xr.combine_manual(datasets, concat_dim=[None, "t"]) def time_auto_combine(self): """Also has to load and arrange t coordinate""" diff --git a/asv_bench/benchmarks/dataarray_missing.py b/asv_bench/benchmarks/dataarray_missing.py index 29a9e78f82c..b119d14e73a 100644 --- a/asv_bench/benchmarks/dataarray_missing.py +++ b/asv_bench/benchmarks/dataarray_missing.py @@ -14,9 +14,8 @@ def make_bench_data(shape, frac_nan, chunks): vals = randn(shape, frac_nan) - coords = {'time': pd.date_range('2000-01-01', freq='D', - periods=shape[0])} - da = xr.DataArray(vals, dims=('time', 'x', 'y'), coords=coords) + coords = {"time": pd.date_range("2000-01-01", freq="D", periods=shape[0])} + da = xr.DataArray(vals, dims=("time", "x", "y"), coords=coords) if chunks is not None: da = da.chunk(chunks) @@ -28,44 +27,50 @@ def time_interpolate_na(shape, chunks, method, limit): if chunks is not None: requires_dask() da = make_bench_data(shape, 0.1, chunks=chunks) - actual = da.interpolate_na(dim='time', method='linear', limit=limit) + actual = da.interpolate_na(dim="time", method="linear", limit=limit) if chunks is not None: actual = actual.compute() -time_interpolate_na.param_names = ['shape', 'chunks', 'method', 'limit'] -time_interpolate_na.params = ([(3650, 200, 400), (100, 25, 25)], - [None, {'x': 25, 'y': 25}], - ['linear', 'spline', 'quadratic', 'cubic'], - [None, 3]) +time_interpolate_na.param_names = ["shape", "chunks", "method", "limit"] +time_interpolate_na.params = ( + [(3650, 200, 400), (100, 25, 25)], + [None, {"x": 25, "y": 25}], + ["linear", "spline", "quadratic", "cubic"], + [None, 3], +) def time_ffill(shape, chunks, limit): 
da = make_bench_data(shape, 0.1, chunks=chunks) - actual = da.ffill(dim='time', limit=limit) + actual = da.ffill(dim="time", limit=limit) if chunks is not None: actual = actual.compute() -time_ffill.param_names = ['shape', 'chunks', 'limit'] -time_ffill.params = ([(3650, 200, 400), (100, 25, 25)], - [None, {'x': 25, 'y': 25}], - [None, 3]) +time_ffill.param_names = ["shape", "chunks", "limit"] +time_ffill.params = ( + [(3650, 200, 400), (100, 25, 25)], + [None, {"x": 25, "y": 25}], + [None, 3], +) def time_bfill(shape, chunks, limit): da = make_bench_data(shape, 0.1, chunks=chunks) - actual = da.bfill(dim='time', limit=limit) + actual = da.bfill(dim="time", limit=limit) if chunks is not None: actual = actual.compute() -time_bfill.param_names = ['shape', 'chunks', 'limit'] -time_bfill.params = ([(3650, 200, 400), (100, 25, 25)], - [None, {'x': 25, 'y': 25}], - [None, 3]) +time_bfill.param_names = ["shape", "chunks", "limit"] +time_bfill.params = ( + [(3650, 200, 400), (100, 25, 25)], + [None, {"x": 25, "y": 25}], + [None, 3], +) diff --git a/asv_bench/benchmarks/dataset_io.py b/asv_bench/benchmarks/dataset_io.py index 07bcc6d71b4..1542d2c857a 100644 --- a/asv_bench/benchmarks/dataset_io.py +++ b/asv_bench/benchmarks/dataset_io.py @@ -16,7 +16,7 @@ pass -os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE' +os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" class IOSingleNetCDF: @@ -25,7 +25,7 @@ class IOSingleNetCDF: xarray """ - timeout = 300. + timeout = 300.0 repeat = 1 number = 5 @@ -37,67 +37,74 @@ def make_ds(self): self.nx = 90 self.ny = 45 - self.block_chunks = {'time': self.nt / 4, - 'lon': self.nx / 3, - 'lat': self.ny / 3} - - self.time_chunks = {'time': int(self.nt / 36)} - - times = pd.date_range('1970-01-01', periods=self.nt, freq='D') - lons = xr.DataArray(np.linspace(0, 360, self.nx), dims=('lon', ), - attrs={'units': 'degrees east', - 'long_name': 'longitude'}) - lats = xr.DataArray(np.linspace(-90, 90, self.ny), dims=('lat', ), - attrs={'units': 'degrees north', - 'long_name': 'latitude'}) - self.ds['foo'] = xr.DataArray(randn((self.nt, self.nx, self.ny), - frac_nan=0.2), - coords={'lon': lons, 'lat': lats, - 'time': times}, - dims=('time', 'lon', 'lat'), - name='foo', encoding=None, - attrs={'units': 'foo units', - 'description': 'a description'}) - self.ds['bar'] = xr.DataArray(randn((self.nt, self.nx, self.ny), - frac_nan=0.2), - coords={'lon': lons, 'lat': lats, - 'time': times}, - dims=('time', 'lon', 'lat'), - name='bar', encoding=None, - attrs={'units': 'bar units', - 'description': 'a description'}) - self.ds['baz'] = xr.DataArray(randn((self.nx, self.ny), - frac_nan=0.2).astype(np.float32), - coords={'lon': lons, 'lat': lats}, - dims=('lon', 'lat'), - name='baz', encoding=None, - attrs={'units': 'baz units', - 'description': 'a description'}) - - self.ds.attrs = {'history': 'created for xarray benchmarking'} - - self.oinds = {'time': randint(0, self.nt, 120), - 'lon': randint(0, self.nx, 20), - 'lat': randint(0, self.ny, 10)} - self.vinds = {'time': xr.DataArray(randint(0, self.nt, 120), - dims='x'), - 'lon': xr.DataArray(randint(0, self.nx, 120), - dims='x'), - 'lat': slice(3, 20)} + self.block_chunks = { + "time": self.nt / 4, + "lon": self.nx / 3, + "lat": self.ny / 3, + } + + self.time_chunks = {"time": int(self.nt / 36)} + + times = pd.date_range("1970-01-01", periods=self.nt, freq="D") + lons = xr.DataArray( + np.linspace(0, 360, self.nx), + dims=("lon",), + attrs={"units": "degrees east", "long_name": "longitude"}, + ) + lats = xr.DataArray( + np.linspace(-90, 90, 
self.ny), + dims=("lat",), + attrs={"units": "degrees north", "long_name": "latitude"}, + ) + self.ds["foo"] = xr.DataArray( + randn((self.nt, self.nx, self.ny), frac_nan=0.2), + coords={"lon": lons, "lat": lats, "time": times}, + dims=("time", "lon", "lat"), + name="foo", + encoding=None, + attrs={"units": "foo units", "description": "a description"}, + ) + self.ds["bar"] = xr.DataArray( + randn((self.nt, self.nx, self.ny), frac_nan=0.2), + coords={"lon": lons, "lat": lats, "time": times}, + dims=("time", "lon", "lat"), + name="bar", + encoding=None, + attrs={"units": "bar units", "description": "a description"}, + ) + self.ds["baz"] = xr.DataArray( + randn((self.nx, self.ny), frac_nan=0.2).astype(np.float32), + coords={"lon": lons, "lat": lats}, + dims=("lon", "lat"), + name="baz", + encoding=None, + attrs={"units": "baz units", "description": "a description"}, + ) + + self.ds.attrs = {"history": "created for xarray benchmarking"} + + self.oinds = { + "time": randint(0, self.nt, 120), + "lon": randint(0, self.nx, 20), + "lat": randint(0, self.ny, 10), + } + self.vinds = { + "time": xr.DataArray(randint(0, self.nt, 120), dims="x"), + "lon": xr.DataArray(randint(0, self.nx, 120), dims="x"), + "lat": slice(3, 20), + } class IOWriteSingleNetCDF3(IOSingleNetCDF): def setup(self): - self.format = 'NETCDF3_64BIT' + self.format = "NETCDF3_64BIT" self.make_ds() def time_write_dataset_netcdf4(self): - self.ds.to_netcdf('test_netcdf4_write.nc', engine='netcdf4', - format=self.format) + self.ds.to_netcdf("test_netcdf4_write.nc", engine="netcdf4", format=self.format) def time_write_dataset_scipy(self): - self.ds.to_netcdf('test_scipy_write.nc', engine='scipy', - format=self.format) + self.ds.to_netcdf("test_scipy_write.nc", engine="scipy", format=self.format) class IOReadSingleNetCDF4(IOSingleNetCDF): @@ -105,19 +112,19 @@ def setup(self): self.make_ds() - self.filepath = 'test_single_file.nc4.nc' - self.format = 'NETCDF4' + self.filepath = "test_single_file.nc4.nc" + self.format = "NETCDF4" self.ds.to_netcdf(self.filepath, format=self.format) def time_load_dataset_netcdf4(self): - xr.open_dataset(self.filepath, engine='netcdf4').load() + xr.open_dataset(self.filepath, engine="netcdf4").load() def time_orthogonal_indexing(self): - ds = xr.open_dataset(self.filepath, engine='netcdf4') + ds = xr.open_dataset(self.filepath, engine="netcdf4") ds = ds.isel(**self.oinds).load() def time_vectorized_indexing(self): - ds = xr.open_dataset(self.filepath, engine='netcdf4') + ds = xr.open_dataset(self.filepath, engine="netcdf4") ds = ds.isel(**self.vinds).load() @@ -126,19 +133,19 @@ def setup(self): self.make_ds() - self.filepath = 'test_single_file.nc3.nc' - self.format = 'NETCDF3_64BIT' + self.filepath = "test_single_file.nc3.nc" + self.format = "NETCDF3_64BIT" self.ds.to_netcdf(self.filepath, format=self.format) def time_load_dataset_scipy(self): - xr.open_dataset(self.filepath, engine='scipy').load() + xr.open_dataset(self.filepath, engine="scipy").load() def time_orthogonal_indexing(self): - ds = xr.open_dataset(self.filepath, engine='scipy') + ds = xr.open_dataset(self.filepath, engine="scipy") ds = ds.isel(**self.oinds).load() def time_vectorized_indexing(self): - ds = xr.open_dataset(self.filepath, engine='scipy') + ds = xr.open_dataset(self.filepath, engine="scipy") ds = ds.isel(**self.vinds).load() @@ -149,37 +156,37 @@ def setup(self): self.make_ds() - self.filepath = 'test_single_file.nc4.nc' - self.format = 'NETCDF4' + self.filepath = "test_single_file.nc4.nc" + self.format = "NETCDF4" 
self.ds.to_netcdf(self.filepath, format=self.format) def time_load_dataset_netcdf4_with_block_chunks(self): - xr.open_dataset(self.filepath, engine='netcdf4', - chunks=self.block_chunks).load() + xr.open_dataset( + self.filepath, engine="netcdf4", chunks=self.block_chunks + ).load() def time_load_dataset_netcdf4_with_block_chunks_oindexing(self): - ds = xr.open_dataset(self.filepath, engine='netcdf4', - chunks=self.block_chunks) + ds = xr.open_dataset(self.filepath, engine="netcdf4", chunks=self.block_chunks) ds = ds.isel(**self.oinds).load() def time_load_dataset_netcdf4_with_block_chunks_vindexing(self): - ds = xr.open_dataset(self.filepath, engine='netcdf4', - chunks=self.block_chunks) + ds = xr.open_dataset(self.filepath, engine="netcdf4", chunks=self.block_chunks) ds = ds.isel(**self.vinds).load() def time_load_dataset_netcdf4_with_block_chunks_multiprocessing(self): with dask.config.set(scheduler="multiprocessing"): - xr.open_dataset(self.filepath, engine='netcdf4', - chunks=self.block_chunks).load() + xr.open_dataset( + self.filepath, engine="netcdf4", chunks=self.block_chunks + ).load() def time_load_dataset_netcdf4_with_time_chunks(self): - xr.open_dataset(self.filepath, engine='netcdf4', - chunks=self.time_chunks).load() + xr.open_dataset(self.filepath, engine="netcdf4", chunks=self.time_chunks).load() def time_load_dataset_netcdf4_with_time_chunks_multiprocessing(self): with dask.config.set(scheduler="multiprocessing"): - xr.open_dataset(self.filepath, engine='netcdf4', - chunks=self.time_chunks).load() + xr.open_dataset( + self.filepath, engine="netcdf4", chunks=self.time_chunks + ).load() class IOReadSingleNetCDF3Dask(IOReadSingleNetCDF4Dask): @@ -189,29 +196,29 @@ def setup(self): self.make_ds() - self.filepath = 'test_single_file.nc3.nc' - self.format = 'NETCDF3_64BIT' + self.filepath = "test_single_file.nc3.nc" + self.format = "NETCDF3_64BIT" self.ds.to_netcdf(self.filepath, format=self.format) def time_load_dataset_scipy_with_block_chunks(self): with dask.config.set(scheduler="multiprocessing"): - xr.open_dataset(self.filepath, engine='scipy', - chunks=self.block_chunks).load() + xr.open_dataset( + self.filepath, engine="scipy", chunks=self.block_chunks + ).load() def time_load_dataset_scipy_with_block_chunks_oindexing(self): - ds = xr.open_dataset(self.filepath, engine='scipy', - chunks=self.block_chunks) + ds = xr.open_dataset(self.filepath, engine="scipy", chunks=self.block_chunks) ds = ds.isel(**self.oinds).load() def time_load_dataset_scipy_with_block_chunks_vindexing(self): - ds = xr.open_dataset(self.filepath, engine='scipy', - chunks=self.block_chunks) + ds = xr.open_dataset(self.filepath, engine="scipy", chunks=self.block_chunks) ds = ds.isel(**self.vinds).load() def time_load_dataset_scipy_with_time_chunks(self): with dask.config.set(scheduler="multiprocessing"): - xr.open_dataset(self.filepath, engine='scipy', - chunks=self.time_chunks).load() + xr.open_dataset( + self.filepath, engine="scipy", chunks=self.time_chunks + ).load() class IOMultipleNetCDF: @@ -220,7 +227,7 @@ class IOMultipleNetCDF: xarray """ - timeout = 300. 
+ timeout = 300.0 repeat = 1 number = 5 @@ -233,71 +240,78 @@ def make_ds(self, nfiles=10): self.ny = 45 self.nfiles = nfiles - self.block_chunks = {'time': self.nt / 4, - 'lon': self.nx / 3, - 'lat': self.ny / 3} + self.block_chunks = { + "time": self.nt / 4, + "lon": self.nx / 3, + "lat": self.ny / 3, + } - self.time_chunks = {'time': int(self.nt / 36)} + self.time_chunks = {"time": int(self.nt / 36)} self.time_vars = np.split( - pd.date_range('1970-01-01', periods=self.nt, freq='D'), - self.nfiles) + pd.date_range("1970-01-01", periods=self.nt, freq="D"), self.nfiles + ) self.ds_list = [] self.filenames_list = [] for i, times in enumerate(self.time_vars): ds = xr.Dataset() nt = len(times) - lons = xr.DataArray(np.linspace(0, 360, self.nx), dims=('lon', ), - attrs={'units': 'degrees east', - 'long_name': 'longitude'}) - lats = xr.DataArray(np.linspace(-90, 90, self.ny), dims=('lat', ), - attrs={'units': 'degrees north', - 'long_name': 'latitude'}) - ds['foo'] = xr.DataArray(randn((nt, self.nx, self.ny), - frac_nan=0.2), - coords={'lon': lons, 'lat': lats, - 'time': times}, - dims=('time', 'lon', 'lat'), - name='foo', encoding=None, - attrs={'units': 'foo units', - 'description': 'a description'}) - ds['bar'] = xr.DataArray(randn((nt, self.nx, self.ny), - frac_nan=0.2), - coords={'lon': lons, 'lat': lats, - 'time': times}, - dims=('time', 'lon', 'lat'), - name='bar', encoding=None, - attrs={'units': 'bar units', - 'description': 'a description'}) - ds['baz'] = xr.DataArray(randn((self.nx, self.ny), - frac_nan=0.2).astype(np.float32), - coords={'lon': lons, 'lat': lats}, - dims=('lon', 'lat'), - name='baz', encoding=None, - attrs={'units': 'baz units', - 'description': 'a description'}) - - ds.attrs = {'history': 'created for xarray benchmarking'} + lons = xr.DataArray( + np.linspace(0, 360, self.nx), + dims=("lon",), + attrs={"units": "degrees east", "long_name": "longitude"}, + ) + lats = xr.DataArray( + np.linspace(-90, 90, self.ny), + dims=("lat",), + attrs={"units": "degrees north", "long_name": "latitude"}, + ) + ds["foo"] = xr.DataArray( + randn((nt, self.nx, self.ny), frac_nan=0.2), + coords={"lon": lons, "lat": lats, "time": times}, + dims=("time", "lon", "lat"), + name="foo", + encoding=None, + attrs={"units": "foo units", "description": "a description"}, + ) + ds["bar"] = xr.DataArray( + randn((nt, self.nx, self.ny), frac_nan=0.2), + coords={"lon": lons, "lat": lats, "time": times}, + dims=("time", "lon", "lat"), + name="bar", + encoding=None, + attrs={"units": "bar units", "description": "a description"}, + ) + ds["baz"] = xr.DataArray( + randn((self.nx, self.ny), frac_nan=0.2).astype(np.float32), + coords={"lon": lons, "lat": lats}, + dims=("lon", "lat"), + name="baz", + encoding=None, + attrs={"units": "baz units", "description": "a description"}, + ) + + ds.attrs = {"history": "created for xarray benchmarking"} self.ds_list.append(ds) - self.filenames_list.append('test_netcdf_%i.nc' % i) + self.filenames_list.append("test_netcdf_%i.nc" % i) class IOWriteMultipleNetCDF3(IOMultipleNetCDF): def setup(self): self.make_ds() - self.format = 'NETCDF3_64BIT' + self.format = "NETCDF3_64BIT" def time_write_dataset_netcdf4(self): - xr.save_mfdataset(self.ds_list, self.filenames_list, - engine='netcdf4', - format=self.format) + xr.save_mfdataset( + self.ds_list, self.filenames_list, engine="netcdf4", format=self.format + ) def time_write_dataset_scipy(self): - xr.save_mfdataset(self.ds_list, self.filenames_list, - engine='scipy', - format=self.format) + xr.save_mfdataset( + 
self.ds_list, self.filenames_list, engine="scipy", format=self.format + ) class IOReadMultipleNetCDF4(IOMultipleNetCDF): @@ -306,15 +320,14 @@ def setup(self): requires_dask() self.make_ds() - self.format = 'NETCDF4' - xr.save_mfdataset(self.ds_list, self.filenames_list, - format=self.format) + self.format = "NETCDF4" + xr.save_mfdataset(self.ds_list, self.filenames_list, format=self.format) def time_load_dataset_netcdf4(self): - xr.open_mfdataset(self.filenames_list, engine='netcdf4').load() + xr.open_mfdataset(self.filenames_list, engine="netcdf4").load() def time_open_dataset_netcdf4(self): - xr.open_mfdataset(self.filenames_list, engine='netcdf4') + xr.open_mfdataset(self.filenames_list, engine="netcdf4") class IOReadMultipleNetCDF3(IOReadMultipleNetCDF4): @@ -323,15 +336,14 @@ def setup(self): requires_dask() self.make_ds() - self.format = 'NETCDF3_64BIT' - xr.save_mfdataset(self.ds_list, self.filenames_list, - format=self.format) + self.format = "NETCDF3_64BIT" + xr.save_mfdataset(self.ds_list, self.filenames_list, format=self.format) def time_load_dataset_scipy(self): - xr.open_mfdataset(self.filenames_list, engine='scipy').load() + xr.open_mfdataset(self.filenames_list, engine="scipy").load() def time_open_dataset_scipy(self): - xr.open_mfdataset(self.filenames_list, engine='scipy') + xr.open_mfdataset(self.filenames_list, engine="scipy") class IOReadMultipleNetCDF4Dask(IOMultipleNetCDF): @@ -340,45 +352,52 @@ def setup(self): requires_dask() self.make_ds() - self.format = 'NETCDF4' - xr.save_mfdataset(self.ds_list, self.filenames_list, - format=self.format) + self.format = "NETCDF4" + xr.save_mfdataset(self.ds_list, self.filenames_list, format=self.format) def time_load_dataset_netcdf4_with_block_chunks(self): - xr.open_mfdataset(self.filenames_list, engine='netcdf4', - chunks=self.block_chunks).load() + xr.open_mfdataset( + self.filenames_list, engine="netcdf4", chunks=self.block_chunks + ).load() def time_load_dataset_netcdf4_with_block_chunks_multiprocessing(self): with dask.config.set(scheduler="multiprocessing"): - xr.open_mfdataset(self.filenames_list, engine='netcdf4', - chunks=self.block_chunks).load() + xr.open_mfdataset( + self.filenames_list, engine="netcdf4", chunks=self.block_chunks + ).load() def time_load_dataset_netcdf4_with_time_chunks(self): - xr.open_mfdataset(self.filenames_list, engine='netcdf4', - chunks=self.time_chunks).load() + xr.open_mfdataset( + self.filenames_list, engine="netcdf4", chunks=self.time_chunks + ).load() def time_load_dataset_netcdf4_with_time_chunks_multiprocessing(self): with dask.config.set(scheduler="multiprocessing"): - xr.open_mfdataset(self.filenames_list, engine='netcdf4', - chunks=self.time_chunks).load() + xr.open_mfdataset( + self.filenames_list, engine="netcdf4", chunks=self.time_chunks + ).load() def time_open_dataset_netcdf4_with_block_chunks(self): - xr.open_mfdataset(self.filenames_list, engine='netcdf4', - chunks=self.block_chunks) + xr.open_mfdataset( + self.filenames_list, engine="netcdf4", chunks=self.block_chunks + ) def time_open_dataset_netcdf4_with_block_chunks_multiprocessing(self): with dask.config.set(scheduler="multiprocessing"): - xr.open_mfdataset(self.filenames_list, engine='netcdf4', - chunks=self.block_chunks) + xr.open_mfdataset( + self.filenames_list, engine="netcdf4", chunks=self.block_chunks + ) def time_open_dataset_netcdf4_with_time_chunks(self): - xr.open_mfdataset(self.filenames_list, engine='netcdf4', - chunks=self.time_chunks) + xr.open_mfdataset( + self.filenames_list, engine="netcdf4", 
chunks=self.time_chunks + ) def time_open_dataset_netcdf4_with_time_chunks_multiprocessing(self): with dask.config.set(scheduler="multiprocessing"): - xr.open_mfdataset(self.filenames_list, engine='netcdf4', - chunks=self.time_chunks) + xr.open_mfdataset( + self.filenames_list, engine="netcdf4", chunks=self.time_chunks + ) class IOReadMultipleNetCDF3Dask(IOReadMultipleNetCDF4Dask): @@ -387,36 +406,40 @@ def setup(self): requires_dask() self.make_ds() - self.format = 'NETCDF3_64BIT' - xr.save_mfdataset(self.ds_list, self.filenames_list, - format=self.format) + self.format = "NETCDF3_64BIT" + xr.save_mfdataset(self.ds_list, self.filenames_list, format=self.format) def time_load_dataset_scipy_with_block_chunks(self): with dask.config.set(scheduler="multiprocessing"): - xr.open_mfdataset(self.filenames_list, engine='scipy', - chunks=self.block_chunks).load() + xr.open_mfdataset( + self.filenames_list, engine="scipy", chunks=self.block_chunks + ).load() def time_load_dataset_scipy_with_time_chunks(self): with dask.config.set(scheduler="multiprocessing"): - xr.open_mfdataset(self.filenames_list, engine='scipy', - chunks=self.time_chunks).load() + xr.open_mfdataset( + self.filenames_list, engine="scipy", chunks=self.time_chunks + ).load() def time_open_dataset_scipy_with_block_chunks(self): with dask.config.set(scheduler="multiprocessing"): - xr.open_mfdataset(self.filenames_list, engine='scipy', - chunks=self.block_chunks) + xr.open_mfdataset( + self.filenames_list, engine="scipy", chunks=self.block_chunks + ) def time_open_dataset_scipy_with_time_chunks(self): with dask.config.set(scheduler="multiprocessing"): - xr.open_mfdataset(self.filenames_list, engine='scipy', - chunks=self.time_chunks) + xr.open_mfdataset( + self.filenames_list, engine="scipy", chunks=self.time_chunks + ) def create_delayed_write(): import dask.array as da + vals = da.random.random(300, chunks=(1,)) - ds = xr.Dataset({'vals': (['a'], vals)}) - return ds.to_netcdf('file.nc', engine='netcdf4', compute=False) + ds = xr.Dataset({"vals": (["a"], vals)}) + return ds.to_netcdf("file.nc", engine="netcdf4", compute=False) class IOWriteNetCDFDask: diff --git a/asv_bench/benchmarks/indexing.py b/asv_bench/benchmarks/indexing.py index 92f9351753a..c9e367dc696 100644 --- a/asv_bench/benchmarks/indexing.py +++ b/asv_bench/benchmarks/indexing.py @@ -12,81 +12,87 @@ nt = 1000 basic_indexes = { - '1slice': {'x': slice(0, 3)}, - '1slice-1scalar': {'x': 0, 'y': slice(None, None, 3)}, - '2slicess-1scalar': {'x': slice(3, -3, 3), 'y': 1, 't': slice(None, -3, 3)} + "1slice": {"x": slice(0, 3)}, + "1slice-1scalar": {"x": 0, "y": slice(None, None, 3)}, + "2slicess-1scalar": {"x": slice(3, -3, 3), "y": 1, "t": slice(None, -3, 3)}, } basic_assignment_values = { - '1slice': xr.DataArray(randn((3, ny), frac_nan=0.1), dims=['x', 'y']), - '1slice-1scalar': xr.DataArray(randn(int(ny / 3) + 1, frac_nan=0.1), - dims=['y']), - '2slicess-1scalar': xr.DataArray(randn(int((nx - 6) / 3), frac_nan=0.1), - dims=['x']) + "1slice": xr.DataArray(randn((3, ny), frac_nan=0.1), dims=["x", "y"]), + "1slice-1scalar": xr.DataArray(randn(int(ny / 3) + 1, frac_nan=0.1), dims=["y"]), + "2slicess-1scalar": xr.DataArray( + randn(int((nx - 6) / 3), frac_nan=0.1), dims=["x"] + ), } outer_indexes = { - '1d': {'x': randint(0, nx, 400)}, - '2d': {'x': randint(0, nx, 500), 'y': randint(0, ny, 400)}, - '2d-1scalar': {'x': randint(0, nx, 100), 'y': 1, 't': randint(0, nt, 400)} + "1d": {"x": randint(0, nx, 400)}, + "2d": {"x": randint(0, nx, 500), "y": randint(0, ny, 400)}, + 
"2d-1scalar": {"x": randint(0, nx, 100), "y": 1, "t": randint(0, nt, 400)}, } outer_assignment_values = { - '1d': xr.DataArray(randn((400, ny), frac_nan=0.1), dims=['x', 'y']), - '2d': xr.DataArray(randn((500, 400), frac_nan=0.1), dims=['x', 'y']), - '2d-1scalar': xr.DataArray(randn(100, frac_nan=0.1), dims=['x']) + "1d": xr.DataArray(randn((400, ny), frac_nan=0.1), dims=["x", "y"]), + "2d": xr.DataArray(randn((500, 400), frac_nan=0.1), dims=["x", "y"]), + "2d-1scalar": xr.DataArray(randn(100, frac_nan=0.1), dims=["x"]), } vectorized_indexes = { - '1-1d': {'x': xr.DataArray(randint(0, nx, 400), dims='a')}, - '2-1d': {'x': xr.DataArray(randint(0, nx, 400), dims='a'), - 'y': xr.DataArray(randint(0, ny, 400), dims='a')}, - '3-2d': {'x': xr.DataArray(randint(0, nx, 400).reshape(4, 100), - dims=['a', 'b']), - 'y': xr.DataArray(randint(0, ny, 400).reshape(4, 100), - dims=['a', 'b']), - 't': xr.DataArray(randint(0, nt, 400).reshape(4, 100), - dims=['a', 'b'])}, + "1-1d": {"x": xr.DataArray(randint(0, nx, 400), dims="a")}, + "2-1d": { + "x": xr.DataArray(randint(0, nx, 400), dims="a"), + "y": xr.DataArray(randint(0, ny, 400), dims="a"), + }, + "3-2d": { + "x": xr.DataArray(randint(0, nx, 400).reshape(4, 100), dims=["a", "b"]), + "y": xr.DataArray(randint(0, ny, 400).reshape(4, 100), dims=["a", "b"]), + "t": xr.DataArray(randint(0, nt, 400).reshape(4, 100), dims=["a", "b"]), + }, } vectorized_assignment_values = { - '1-1d': xr.DataArray(randn((400, 2000)), dims=['a', 'y'], - coords={'a': randn(400)}), - '2-1d': xr.DataArray(randn(400), dims=['a', ], coords={'a': randn(400)}), - '3-2d': xr.DataArray(randn((4, 100)), dims=['a', 'b'], - coords={'a': randn(4), 'b': randn(100)}) + "1-1d": xr.DataArray(randn((400, 2000)), dims=["a", "y"], coords={"a": randn(400)}), + "2-1d": xr.DataArray(randn(400), dims=["a"], coords={"a": randn(400)}), + "3-2d": xr.DataArray( + randn((4, 100)), dims=["a", "b"], coords={"a": randn(4), "b": randn(100)} + ), } class Base: def setup(self, key): self.ds = xr.Dataset( - {'var1': (('x', 'y'), randn((nx, ny), frac_nan=0.1)), - 'var2': (('x', 't'), randn((nx, nt))), - 'var3': (('t', ), randn(nt))}, - coords={'x': np.arange(nx), - 'y': np.linspace(0, 1, ny), - 't': pd.date_range('1970-01-01', periods=nt, freq='D'), - 'x_coords': ('x', np.linspace(1.1, 2.1, nx))}) + { + "var1": (("x", "y"), randn((nx, ny), frac_nan=0.1)), + "var2": (("x", "t"), randn((nx, nt))), + "var3": (("t",), randn(nt)), + }, + coords={ + "x": np.arange(nx), + "y": np.linspace(0, 1, ny), + "t": pd.date_range("1970-01-01", periods=nt, freq="D"), + "x_coords": ("x", np.linspace(1.1, 2.1, nx)), + }, + ) class Indexing(Base): def time_indexing_basic(self, key): self.ds.isel(**basic_indexes[key]).load() - time_indexing_basic.param_names = ['key'] + time_indexing_basic.param_names = ["key"] time_indexing_basic.params = [list(basic_indexes.keys())] def time_indexing_outer(self, key): self.ds.isel(**outer_indexes[key]).load() - time_indexing_outer.param_names = ['key'] + time_indexing_outer.param_names = ["key"] time_indexing_outer.params = [list(outer_indexes.keys())] def time_indexing_vectorized(self, key): self.ds.isel(**vectorized_indexes[key]).load() - time_indexing_vectorized.param_names = ['key'] + time_indexing_vectorized.param_names = ["key"] time_indexing_vectorized.params = [list(vectorized_indexes.keys())] @@ -94,28 +100,25 @@ class Assignment(Base): def time_assignment_basic(self, key): ind = basic_indexes[key] val = basic_assignment_values[key] - self.ds['var1'][ind.get('x', slice(None)), - 
ind.get('y', slice(None))] = val + self.ds["var1"][ind.get("x", slice(None)), ind.get("y", slice(None))] = val - time_assignment_basic.param_names = ['key'] + time_assignment_basic.param_names = ["key"] time_assignment_basic.params = [list(basic_indexes.keys())] def time_assignment_outer(self, key): ind = outer_indexes[key] val = outer_assignment_values[key] - self.ds['var1'][ind.get('x', slice(None)), - ind.get('y', slice(None))] = val + self.ds["var1"][ind.get("x", slice(None)), ind.get("y", slice(None))] = val - time_assignment_outer.param_names = ['key'] + time_assignment_outer.param_names = ["key"] time_assignment_outer.params = [list(outer_indexes.keys())] def time_assignment_vectorized(self, key): ind = vectorized_indexes[key] val = vectorized_assignment_values[key] - self.ds['var1'][ind.get('x', slice(None)), - ind.get('y', slice(None))] = val + self.ds["var1"][ind.get("x", slice(None)), ind.get("y", slice(None))] = val - time_assignment_vectorized.param_names = ['key'] + time_assignment_vectorized.param_names = ["key"] time_assignment_vectorized.params = [list(vectorized_indexes.keys())] @@ -123,4 +126,4 @@ class IndexingDask(Indexing): def setup(self, key): requires_dask() super().setup(key) - self.ds = self.ds.chunk({'x': 100, 'y': 50, 't': 50}) + self.ds = self.ds.chunk({"x": 100, "y": 50, "t": 50}) diff --git a/asv_bench/benchmarks/interp.py b/asv_bench/benchmarks/interp.py index 31a64a02295..b62717d7ceb 100644 --- a/asv_bench/benchmarks/interp.py +++ b/asv_bench/benchmarks/interp.py @@ -15,8 +15,8 @@ randn_xy = randn((nx, ny), frac_nan=0.1) randn_xt = randn((nx, nt)) -randn_t = randn((nt, )) -randn_long = randn((long_nx, ), frac_nan=0.1) +randn_t = randn((nt,)) +randn_long = randn((long_nx,), frac_nan=0.1) new_x_short = np.linspace(0.3 * nx, 0.7 * nx, 100) @@ -27,22 +27,25 @@ class Interpolation: def setup(self, *args, **kwargs): self.ds = xr.Dataset( - {'var1': (('x', 'y'), randn_xy), - 'var2': (('x', 't'), randn_xt), - 'var3': (('t', ), randn_t)}, - coords={'x': np.arange(nx), - 'y': np.linspace(0, 1, ny), - 't': pd.date_range('1970-01-01', periods=nt, freq='D'), - 'x_coords': ('x', np.linspace(1.1, 2.1, nx))}) - - @parameterized(['method', 'is_short'], - (['linear', 'cubic'], [True, False])) + { + "var1": (("x", "y"), randn_xy), + "var2": (("x", "t"), randn_xt), + "var3": (("t",), randn_t), + }, + coords={ + "x": np.arange(nx), + "y": np.linspace(0, 1, ny), + "t": pd.date_range("1970-01-01", periods=nt, freq="D"), + "x_coords": ("x", np.linspace(1.1, 2.1, nx)), + }, + ) + + @parameterized(["method", "is_short"], (["linear", "cubic"], [True, False])) def time_interpolation(self, method, is_short): new_x = new_x_short if is_short else new_x_long self.ds.interp(x=new_x, method=method).load() - @parameterized(['method'], - (['linear', 'nearest'])) + @parameterized(["method"], (["linear", "nearest"])) def time_interpolation_2d(self, method): self.ds.interp(x=new_x_long, y=new_y_long, method=method).load() @@ -51,4 +54,4 @@ class InterpolationDask(Interpolation): def setup(self, *args, **kwargs): requires_dask() super().setup(**kwargs) - self.ds = self.ds.chunk({'t': 50}) + self.ds = self.ds.chunk({"t": 50}) diff --git a/asv_bench/benchmarks/reindexing.py b/asv_bench/benchmarks/reindexing.py index da00df37d19..ceed186fcc8 100644 --- a/asv_bench/benchmarks/reindexing.py +++ b/asv_bench/benchmarks/reindexing.py @@ -10,35 +10,41 @@ class Reindex: def setup(self): data = np.random.RandomState(0).randn(1000, 100, 100) - self.ds = xr.Dataset({'temperature': (('time', 'x', 'y'), 
data)}, - coords={'time': np.arange(1000), - 'x': np.arange(100), - 'y': np.arange(100)}) + self.ds = xr.Dataset( + {"temperature": (("time", "x", "y"), data)}, + coords={"time": np.arange(1000), "x": np.arange(100), "y": np.arange(100)}, + ) def time_1d_coarse(self): self.ds.reindex(time=np.arange(0, 1000, 5)).load() def time_1d_fine_all_found(self): - self.ds.reindex(time=np.arange(0, 1000, 0.5), method='nearest').load() + self.ds.reindex(time=np.arange(0, 1000, 0.5), method="nearest").load() def time_1d_fine_some_missing(self): - self.ds.reindex(time=np.arange(0, 1000, 0.5), method='nearest', - tolerance=0.1).load() + self.ds.reindex( + time=np.arange(0, 1000, 0.5), method="nearest", tolerance=0.1 + ).load() def time_2d_coarse(self): self.ds.reindex(x=np.arange(0, 100, 2), y=np.arange(0, 100, 2)).load() def time_2d_fine_all_found(self): - self.ds.reindex(x=np.arange(0, 100, 0.5), y=np.arange(0, 100, 0.5), - method='nearest').load() + self.ds.reindex( + x=np.arange(0, 100, 0.5), y=np.arange(0, 100, 0.5), method="nearest" + ).load() def time_2d_fine_some_missing(self): - self.ds.reindex(x=np.arange(0, 100, 0.5), y=np.arange(0, 100, 0.5), - method='nearest', tolerance=0.1).load() + self.ds.reindex( + x=np.arange(0, 100, 0.5), + y=np.arange(0, 100, 0.5), + method="nearest", + tolerance=0.1, + ).load() class ReindexDask(Reindex): def setup(self): requires_dask() super().setup() - self.ds = self.ds.chunk({'time': 100}) + self.ds = self.ds.chunk({"time": 100}) diff --git a/asv_bench/benchmarks/rolling.py b/asv_bench/benchmarks/rolling.py index efc76c01ed3..268bad7e738 100644 --- a/asv_bench/benchmarks/rolling.py +++ b/asv_bench/benchmarks/rolling.py @@ -15,30 +15,34 @@ randn_xy = randn((nx, ny), frac_nan=0.1) randn_xt = randn((nx, nt)) -randn_t = randn((nt, )) -randn_long = randn((long_nx, ), frac_nan=0.1) +randn_t = randn((nt,)) +randn_long = randn((long_nx,), frac_nan=0.1) class Rolling: def setup(self, *args, **kwargs): self.ds = xr.Dataset( - {'var1': (('x', 'y'), randn_xy), - 'var2': (('x', 't'), randn_xt), - 'var3': (('t', ), randn_t)}, - coords={'x': np.arange(nx), - 'y': np.linspace(0, 1, ny), - 't': pd.date_range('1970-01-01', periods=nt, freq='D'), - 'x_coords': ('x', np.linspace(1.1, 2.1, nx))}) - self.da_long = xr.DataArray(randn_long, dims='x', - coords={'x': np.arange(long_nx) * 0.1}) + { + "var1": (("x", "y"), randn_xy), + "var2": (("x", "t"), randn_xt), + "var3": (("t",), randn_t), + }, + coords={ + "x": np.arange(nx), + "y": np.linspace(0, 1, ny), + "t": pd.date_range("1970-01-01", periods=nt, freq="D"), + "x_coords": ("x", np.linspace(1.1, 2.1, nx)), + }, + ) + self.da_long = xr.DataArray( + randn_long, dims="x", coords={"x": np.arange(long_nx) * 0.1} + ) - @parameterized(['func', 'center'], - (['mean', 'count'], [True, False])) + @parameterized(["func", "center"], (["mean", "count"], [True, False])) def time_rolling(self, func, center): getattr(self.ds.rolling(x=window, center=center), func)().load() - @parameterized(['func', 'pandas'], - (['mean', 'count'], [True, False])) + @parameterized(["func", "pandas"], (["mean", "count"], [True, False])) def time_rolling_long(self, func, pandas): if pandas: se = self.da_long.to_series() @@ -46,23 +50,22 @@ def time_rolling_long(self, func, pandas): else: getattr(self.da_long.rolling(x=window), func)().load() - @parameterized(['window_', 'min_periods'], - ([20, 40], [5, None])) + @parameterized(["window_", "min_periods"], ([20, 40], [5, None])) def time_rolling_np(self, window_, min_periods): - self.ds.rolling(x=window_, center=False, 
- min_periods=min_periods).reduce( - getattr(np, 'nanmean')).load() + self.ds.rolling(x=window_, center=False, min_periods=min_periods).reduce( + getattr(np, "nanmean") + ).load() - @parameterized(['center', 'stride'], - ([True, False], [1, 200])) + @parameterized(["center", "stride"], ([True, False], [1, 200])) def time_rolling_construct(self, center, stride): self.ds.rolling(x=window, center=center).construct( - 'window_dim', stride=stride).mean(dim='window_dim').load() + "window_dim", stride=stride + ).mean(dim="window_dim").load() class RollingDask(Rolling): def setup(self, *args, **kwargs): requires_dask() super().setup(**kwargs) - self.ds = self.ds.chunk({'x': 100, 'y': 50, 't': 50}) - self.da_long = self.da_long.chunk({'x': 10000}) + self.ds = self.ds.chunk({"x": 100, "y": 50, "t": 50}) + self.da_long = self.da_long.chunk({"x": 10000}) diff --git a/asv_bench/benchmarks/unstacking.py b/asv_bench/benchmarks/unstacking.py index 2968ee3f2bb..7b529373902 100644 --- a/asv_bench/benchmarks/unstacking.py +++ b/asv_bench/benchmarks/unstacking.py @@ -10,17 +10,17 @@ class Unstacking: def setup(self): data = np.random.RandomState(0).randn(1, 1000, 500) - self.ds = xr.DataArray(data).stack(flat_dim=['dim_1', 'dim_2']) + self.ds = xr.DataArray(data).stack(flat_dim=["dim_1", "dim_2"]) def time_unstack_fast(self): - self.ds.unstack('flat_dim') + self.ds.unstack("flat_dim") def time_unstack_slow(self): - self.ds[:, ::-1].unstack('flat_dim') + self.ds[:, ::-1].unstack("flat_dim") class UnstackingDask(Unstacking): def setup(self, *args, **kwargs): requires_dask() super().setup(**kwargs) - self.ds = self.ds.chunk({'flat_dim': 50}) + self.ds = self.ds.chunk({"flat_dim": 50}) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index ccf8392c5a7..82d3b70704b 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -46,7 +46,7 @@ jobs: steps: - template: ci/azure/unit-tests.yml -- job: Lint +- job: LintFlake8 pool: vmImage: 'ubuntu-16.04' steps: @@ -56,6 +56,16 @@ jobs: - bash: flake8 displayName: flake8 lint checks +- job: FormattingBlack + pool: + vmImage: 'ubuntu-16.04' + steps: + - task: UsePythonVersion@0 + - bash: python -m pip install black + displayName: Install black + - bash: black --check . 
+ displayName: black formatting check + - job: TypeChecking variables: conda_env: py37 diff --git a/conftest.py b/conftest.py index 177e689591f..25dc284975e 100644 --- a/conftest.py +++ b/conftest.py @@ -5,17 +5,19 @@ def pytest_addoption(parser): """Add command-line flags for pytest.""" - parser.addoption("--run-flaky", action="store_true", - help="runs flaky tests") - parser.addoption("--run-network-tests", action="store_true", - help="runs tests requiring a network connection") + parser.addoption("--run-flaky", action="store_true", help="runs flaky tests") + parser.addoption( + "--run-network-tests", + action="store_true", + help="runs tests requiring a network connection", + ) def pytest_runtest_setup(item): # based on https://stackoverflow.com/questions/47559524 - if 'flaky' in item.keywords and not item.config.getoption("--run-flaky"): + if "flaky" in item.keywords and not item.config.getoption("--run-flaky"): pytest.skip("set --run-flaky option to run flaky tests") - if ('network' in item.keywords - and not item.config.getoption("--run-network-tests")): - pytest.skip("set --run-network-tests to run test requiring an " - "internet connection") + if "network" in item.keywords and not item.config.getoption("--run-network-tests"): + pytest.skip( + "set --run-network-tests to run test requiring an " "internet connection" + ) diff --git a/doc/conf.py b/doc/conf.py index 00d587807fa..e181c3e14c2 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -26,158 +26,164 @@ print("python exec:", sys.executable) print("sys.path:", sys.path) -if 'conda' in sys.executable: - print('conda environment:') - subprocess.run(['conda', 'list']) +if "conda" in sys.executable: + print("conda environment:") + subprocess.run(["conda", "list"]) else: - print('pip environment:') - subprocess.run(['pip', 'list']) + print("pip environment:") + subprocess.run(["pip", "list"]) print("xarray: %s, %s" % (xarray.__version__, xarray.__file__)) with suppress(ImportError): import matplotlib - matplotlib.use('Agg') + + matplotlib.use("Agg") try: import rasterio except ImportError: - allowed_failures.update(['gallery/plot_rasterio_rgb.py', - 'gallery/plot_rasterio.py']) + allowed_failures.update( + ["gallery/plot_rasterio_rgb.py", "gallery/plot_rasterio.py"] + ) try: import cartopy except ImportError: - allowed_failures.update(['gallery/plot_cartopy_facetgrid.py', - 'gallery/plot_rasterio_rgb.py', - 'gallery/plot_rasterio.py']) + allowed_failures.update( + [ + "gallery/plot_cartopy_facetgrid.py", + "gallery/plot_rasterio_rgb.py", + "gallery/plot_rasterio.py", + ] + ) # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. 
extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.intersphinx', - 'sphinx.ext.extlinks', - 'sphinx.ext.mathjax', - 'sphinx.ext.napoleon', - 'numpydoc', - 'IPython.sphinxext.ipython_directive', - 'IPython.sphinxext.ipython_console_highlighting', - 'sphinx_gallery.gen_gallery', + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", + "sphinx.ext.extlinks", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", + "numpydoc", + "IPython.sphinxext.ipython_directive", + "IPython.sphinxext.ipython_console_highlighting", + "sphinx_gallery.gen_gallery", ] -extlinks = {'issue': ('https://github.com/pydata/xarray/issues/%s', 'GH'), - 'pull': ('https://github.com/pydata/xarray/pull/%s', 'PR'), - } +extlinks = { + "issue": ("https://github.com/pydata/xarray/issues/%s", "GH"), + "pull": ("https://github.com/pydata/xarray/pull/%s", "PR"), +} -sphinx_gallery_conf = {'examples_dirs': 'gallery', - 'gallery_dirs': 'auto_gallery', - 'backreferences_dir': False, - 'expected_failing_examples': list(allowed_failures) - } +sphinx_gallery_conf = { + "examples_dirs": "gallery", + "gallery_dirs": "auto_gallery", + "backreferences_dir": False, + "expected_failing_examples": list(allowed_failures), +} autosummary_generate = True -autodoc_typehints = 'none' +autodoc_typehints = "none" napoleon_use_param = True -napoleon_use_rtype = True +napoleon_use_rtype = True numpydoc_class_members_toctree = True numpydoc_show_class_members = False # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = 'xarray' -copyright = '2014-%s, xarray Developers' % datetime.datetime.now().year +project = "xarray" +copyright = "2014-%s, xarray Developers" % datetime.datetime.now().year # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = xarray.__version__.split('+')[0] +version = xarray.__version__.split("+")[0] # The full version, including alpha/beta/rc tags. release = xarray.__version__ # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. -#language = None +# language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -today_fmt = '%Y-%m-%d' +today_fmt = "%Y-%m-%d" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] # The reST default role (used for this markup: `text`) to use for all # documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). 
-#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False +# keep_warnings = False # -- Options for HTML output ---------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -html_theme_options = { - 'logo_only': True, -} +html_theme_options = {"logo_only": True} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +# html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. @@ -186,25 +192,26 @@ # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -html_favicon = '_static/favicon.ico' +html_favicon = "_static/favicon.ico" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Sometimes the savefig directory doesn't exist and needs to be created # https://github.com/ipython/ipython/issues/8733 # becomes obsolete when we can pin ipython>=5.2; see doc/environment.yml -ipython_savefig_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), - '_build','html','_static') +ipython_savefig_dir = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "_build", "html", "_static" +) if not os.path.exists(ipython_savefig_dir): os.makedirs(ipython_savefig_dir) # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. -#html_extra_path = [] +# html_extra_path = [] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. @@ -212,98 +219,92 @@ # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. 
# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'xarraydoc' +htmlhelp_basename = "xarraydoc" # -- Options for LaTeX output --------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', + # The paper size ('letterpaper' or 'a4paper'). + #'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + #'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - ('index', 'xarray.tex', 'xarray Documentation', - 'xarray Developers', 'manual'), + ("index", "xarray.tex", "xarray Documentation", "xarray Developers", "manual") ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output --------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - ('index', 'xarray', 'xarray Documentation', - ['xarray Developers'], 1) -] +man_pages = [("index", "xarray", "xarray Documentation", ["xarray Developers"], 1)] # If true, show URL addresses after external links. 
-#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ------------------------------------------- @@ -312,30 +313,36 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'xarray', 'xarray Documentation', - 'xarray Developers', 'xarray', 'N-D labeled arrays and datasets in Python.', - 'Miscellaneous'), + ( + "index", + "xarray", + "xarray Documentation", + "xarray Developers", + "xarray", + "N-D labeled arrays and datasets in Python.", + "Miscellaneous", + ) ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. -#texinfo_no_detailmenu = False +# texinfo_no_detailmenu = False # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { - 'python': ('https://docs.python.org/3/', None), - 'pandas': ('https://pandas.pydata.org/pandas-docs/stable/', None), - 'iris': ('http://scitools.org.uk/iris/docs/latest/', None), - 'numpy': ('https://docs.scipy.org/doc/numpy/', None), - 'numba': ('https://numba.pydata.org/numba-doc/latest/', None), - 'matplotlib': ('https://matplotlib.org/', None), + "python": ("https://docs.python.org/3/", None), + "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), + "iris": ("http://scitools.org.uk/iris/docs/latest/", None), + "numpy": ("https://docs.scipy.org/doc/numpy/", None), + "numba": ("https://numba.pydata.org/numba-doc/latest/", None), + "matplotlib": ("https://matplotlib.org/", None), } diff --git a/doc/contributing.rst b/doc/contributing.rst index 9017c3dde7c..14ecb65f295 100644 --- a/doc/contributing.rst +++ b/doc/contributing.rst @@ -340,23 +340,24 @@ do not make sudden changes to the code that could have the potential to break a lot of user code as a result, that is, we need it to be as *backwards compatible* as possible to avoid mass breakages. -Python (PEP8) -~~~~~~~~~~~~~ +Code Formatting +~~~~~~~~~~~~~~~ -*xarray* uses the `PEP8 `_ standard. -There are several tools to ensure you abide by this standard. Here are *some* of -the more common ``PEP8`` issues: +Xarray uses `Black `_ and +`Flake8 `_ to ensure a consistent code +format throughout the project. ``black`` and ``flake8`` can be installed with +``pip``:: - - we restrict line-length to 79 characters to promote readability - - passing arguments should have spaces after commas, e.g. ``foo(arg1, arg2, kw1='bar')`` + pip install black flake8 -:ref:`Continuous Integration ` will run -the `flake8 `_ tool -and report any stylistic errors in your code. Therefore, it is helpful before -submitting code to run the check yourself: +and then run from the root of the Xarray repository:: + black . flake8 +to auto-format your code. Additionally, many editors have plugins that will +apply ``black`` as you edit files. + Other recommended but optional tools for checking code quality (not currently enforced in CI): @@ -367,8 +368,35 @@ enforced in CI): incorrectly sorted imports. ``isort -y`` will automatically fix them. See also `flake8-isort `_. -Note that your code editor probably supports extensions that can show results -of these checks inline as you type. 
+Optionally, you may wish to set up `pre-commit hooks `_
+to automatically run ``black`` and ``flake8`` when you make a git commit. This
+can be done by installing ``pre-commit``::
+
+    pip install pre-commit
+
+and then running::
+
+    pre-commit install
+
+from the root of the Xarray repository. Now ``black`` and ``flake8`` will be run
+each time you commit changes. You can skip these checks with
+``git commit --no-verify``.
+
+.. note::
+
+    If you were working on a branch *prior* to the code being reformatted with black,
+    you will likely face some merge conflicts. These steps can eliminate many of those
+    conflicts. Because they have had limited testing, please reach out to the core devs
+    on your pull request if you face any issues, and we'll help with the merge:
+
+    - Merge the commit on master prior to the `black` commit into your branch with
+      `git merge f172c673`. Any conflicts here are real and require resolving
+    - Apply `black .` to your branch and commit
+      # TODO: insert after the black commit is on master
+    - Merge master at the `black` commit, resolving in favor of 'our' changes:
+      `git merge [master-at-black-commit] -X ours`. You shouldn't have any merge conflicts
+    - Merge current master with `git merge master`; any conflicts here are real and
+      again require resolving
 
 Backwards Compatibility
 ~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/doc/examples/_code/accessor_example.py b/doc/examples/_code/accessor_example.py
index d179d38fba9..ffbacb479e8 100644
--- a/doc/examples/_code/accessor_example.py
+++ b/doc/examples/_code/accessor_example.py
@@ -1,7 +1,7 @@
 import xarray as xr
 
 
-@xr.register_dataset_accessor('geo')
+@xr.register_dataset_accessor("geo")
 class GeoAccessor:
     def __init__(self, xarray_obj):
         self._obj = xarray_obj
@@ -20,4 +20,4 @@ def center(self):
 
     def plot(self):
         """Plot data on a map."""
-        return 'plotting!'
+        return "plotting!"
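The changes to ``doc/contributing.rst`` above tell contributors to run ``black .`` and
``flake8`` from the repository root. As a rough, illustrative sketch only (it is not part of
this patch, and the variable names are invented for the example), the snippet below shows the
kind of rewrite ``black`` applies throughout this diff: double quotes, an 88-character line
limit, and calls that would exceed that limit exploded to one argument per line with a
trailing comma::

    import numpy as np
    import pandas as pd
    import xarray as xr

    # Hypothetical example data, not taken from the xarray code base.
    times = pd.date_range("2000-01-01", "2000-01-10", name="time")
    values = np.random.randn(len(times), 3)

    # Before black, a call like this might be written with single quotes and a
    # hanging indent:
    #
    #     ds = xr.Dataset({'temp': (('time', 'location'), values)},
    #                     {'time': times, 'location': ['IA', 'IN', 'IL']})
    #
    # black rewrites it with double quotes; because the joined call would be
    # longer than 88 characters, each argument goes on its own line and the
    # last argument gets a trailing comma:
    ds = xr.Dataset(
        {"temp": (("time", "location"), values)},
        {"time": times, "location": ["IA", "IN", "IL"]},
    )

Running ``black .`` as described above should reproduce exactly this layout, and ``flake8``
with the ``max-line-length=88`` setting from ``setup.cfg`` should then pass it without
warnings.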
diff --git a/doc/examples/_code/weather_data_setup.py b/doc/examples/_code/weather_data_setup.py index d3a3e2d065a..385f5366ef7 100644 --- a/doc/examples/_code/weather_data_setup.py +++ b/doc/examples/_code/weather_data_setup.py @@ -6,13 +6,17 @@ np.random.seed(123) -times = pd.date_range('2000-01-01', '2001-12-31', name='time') +times = pd.date_range("2000-01-01", "2001-12-31", name="time") annual_cycle = np.sin(2 * np.pi * (times.dayofyear.values / 365.25 - 0.28)) base = 10 + 15 * annual_cycle.reshape(-1, 1) tmin_values = base + 3 * np.random.randn(annual_cycle.size, 3) tmax_values = base + 10 + 3 * np.random.randn(annual_cycle.size, 3) -ds = xr.Dataset({'tmin': (('time', 'location'), tmin_values), - 'tmax': (('time', 'location'), tmax_values)}, - {'time': times, 'location': ['IA', 'IN', 'IL']}) +ds = xr.Dataset( + { + "tmin": (("time", "location"), tmin_values), + "tmax": (("time", "location"), tmax_values), + }, + {"time": times, "location": ["IA", "IN", "IL"]}, +) diff --git a/doc/gallery/plot_cartopy_facetgrid.py b/doc/gallery/plot_cartopy_facetgrid.py index 3eded115263..cfa9d3b07ec 100644 --- a/doc/gallery/plot_cartopy_facetgrid.py +++ b/doc/gallery/plot_cartopy_facetgrid.py @@ -22,16 +22,19 @@ import xarray as xr # Load the data -ds = xr.tutorial.load_dataset('air_temperature') +ds = xr.tutorial.load_dataset("air_temperature") air = ds.air.isel(time=[0, 724]) - 273.15 # This is the map projection we want to plot *onto* map_proj = ccrs.LambertConformal(central_longitude=-95, central_latitude=45) -p = air.plot(transform=ccrs.PlateCarree(), # the data's projection - col='time', col_wrap=1, # multiplot settings - aspect=ds.dims['lon'] / ds.dims['lat'], # for a sensible figsize - subplot_kws={'projection': map_proj}) # the plot's projection +p = air.plot( + transform=ccrs.PlateCarree(), # the data's projection + col="time", + col_wrap=1, # multiplot settings + aspect=ds.dims["lon"] / ds.dims["lat"], # for a sensible figsize + subplot_kws={"projection": map_proj}, # the plot's projection +) # We have to set the map's options on all four axes for ax in p.axes.flat: @@ -39,6 +42,6 @@ ax.set_extent([-160, -30, 5, 75]) # Without this aspect attributes the maps will look chaotic and the # "extent" attribute above will be ignored - ax.set_aspect('equal', 'box-forced') + ax.set_aspect("equal", "box-forced") plt.show() diff --git a/doc/gallery/plot_colorbar_center.py b/doc/gallery/plot_colorbar_center.py index 4818b737632..8227dc5ba0c 100644 --- a/doc/gallery/plot_colorbar_center.py +++ b/doc/gallery/plot_colorbar_center.py @@ -13,31 +13,31 @@ import xarray as xr # Load the data -ds = xr.tutorial.load_dataset('air_temperature') +ds = xr.tutorial.load_dataset("air_temperature") air = ds.air.isel(time=0) f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 6)) # The first plot (in kelvins) chooses "viridis" and uses the data's min/max -air.plot(ax=ax1, cbar_kwargs={'label': 'K'}) -ax1.set_title('Kelvins: default') -ax2.set_xlabel('') +air.plot(ax=ax1, cbar_kwargs={"label": "K"}) +ax1.set_title("Kelvins: default") +ax2.set_xlabel("") # The second plot (in celsius) now chooses "BuRd" and centers min/max around 0 airc = air - 273.15 -airc.plot(ax=ax2, cbar_kwargs={'label': '°C'}) -ax2.set_title('Celsius: default') -ax2.set_xlabel('') -ax2.set_ylabel('') +airc.plot(ax=ax2, cbar_kwargs={"label": "°C"}) +ax2.set_title("Celsius: default") +ax2.set_xlabel("") +ax2.set_ylabel("") # The center doesn't have to be 0 -air.plot(ax=ax3, center=273.15, cbar_kwargs={'label': 'K'}) -ax3.set_title('Kelvins: 
center=273.15') +air.plot(ax=ax3, center=273.15, cbar_kwargs={"label": "K"}) +ax3.set_title("Kelvins: center=273.15") # Or it can be ignored -airc.plot(ax=ax4, center=False, cbar_kwargs={'label': '°C'}) -ax4.set_title('Celsius: center=False') -ax4.set_ylabel('') +airc.plot(ax=ax4, center=False, cbar_kwargs={"label": "°C"}) +ax4.set_title("Celsius: center=False") +ax4.set_ylabel("") # Mke it nice plt.tight_layout() diff --git a/doc/gallery/plot_control_colorbar.py b/doc/gallery/plot_control_colorbar.py index 5802a57cf31..bd1f2c69a44 100644 --- a/doc/gallery/plot_control_colorbar.py +++ b/doc/gallery/plot_control_colorbar.py @@ -12,7 +12,7 @@ import xarray as xr # Load the data -air_temp = xr.tutorial.load_dataset('air_temperature') +air_temp = xr.tutorial.load_dataset("air_temperature") air2d = air_temp.air.isel(time=500) # Prepare the figure @@ -23,9 +23,10 @@ # Plot data air2d.plot(ax=ax1, levels=levels) -air2d.plot(ax=ax2, levels=levels, cbar_kwargs={'ticks': levels}) -air2d.plot(ax=ax3, levels=levels, cbar_kwargs={'ticks': levels, - 'spacing': 'proportional'}) +air2d.plot(ax=ax2, levels=levels, cbar_kwargs={"ticks": levels}) +air2d.plot( + ax=ax3, levels=levels, cbar_kwargs={"ticks": levels, "spacing": "proportional"} +) # Show plots plt.tight_layout() diff --git a/doc/gallery/plot_lines_from_2d.py b/doc/gallery/plot_lines_from_2d.py index 93d7770238e..2aebda2f323 100644 --- a/doc/gallery/plot_lines_from_2d.py +++ b/doc/gallery/plot_lines_from_2d.py @@ -17,7 +17,7 @@ import xarray as xr # Load the data -ds = xr.tutorial.load_dataset('air_temperature') +ds = xr.tutorial.load_dataset("air_temperature") air = ds.air - 273.15 # to celsius # Prepare the figure @@ -27,12 +27,12 @@ isel_lats = [10, 15, 20] # Temperature vs longitude plot - illustrates the "hue" kwarg -air.isel(time=0, lat=isel_lats).plot.line(ax=ax1, hue='lat') -ax1.set_ylabel('°C') +air.isel(time=0, lat=isel_lats).plot.line(ax=ax1, hue="lat") +ax1.set_ylabel("°C") # Temperature vs time plot - illustrates the "x" and "add_legend" kwargs -air.isel(lon=30, lat=isel_lats).plot.line(ax=ax2, x='time', add_legend=False) -ax2.set_ylabel('') +air.isel(lon=30, lat=isel_lats).plot.line(ax=ax2, x="time", add_legend=False) +ax2.set_ylabel("") # Show plt.tight_layout() diff --git a/doc/gallery/plot_rasterio.py b/doc/gallery/plot_rasterio.py index 82d5ce61284..d5cbb0700cc 100644 --- a/doc/gallery/plot_rasterio.py +++ b/doc/gallery/plot_rasterio.py @@ -24,27 +24,32 @@ import xarray as xr # Read the data -url = 'https://github.com/mapbox/rasterio/raw/master/tests/data/RGB.byte.tif' +url = "https://github.com/mapbox/rasterio/raw/master/tests/data/RGB.byte.tif" da = xr.open_rasterio(url) # Compute the lon/lat coordinates with rasterio.warp.transform -ny, nx = len(da['y']), len(da['x']) -x, y = np.meshgrid(da['x'], da['y']) +ny, nx = len(da["y"]), len(da["x"]) +x, y = np.meshgrid(da["x"], da["y"]) # Rasterio works with 1D arrays -lon, lat = transform(da.crs, {'init': 'EPSG:4326'}, - x.flatten(), y.flatten()) +lon, lat = transform(da.crs, {"init": "EPSG:4326"}, x.flatten(), y.flatten()) lon = np.asarray(lon).reshape((ny, nx)) lat = np.asarray(lat).reshape((ny, nx)) -da.coords['lon'] = (('y', 'x'), lon) -da.coords['lat'] = (('y', 'x'), lat) +da.coords["lon"] = (("y", "x"), lon) +da.coords["lat"] = (("y", "x"), lat) # Compute a greyscale out of the rgb image -greyscale = da.mean(dim='band') +greyscale = da.mean(dim="band") # Plot on a map ax = plt.subplot(projection=ccrs.PlateCarree()) -greyscale.plot(ax=ax, x='lon', y='lat', 
transform=ccrs.PlateCarree(), - cmap='Greys_r', add_colorbar=False) -ax.coastlines('10m', color='r') +greyscale.plot( + ax=ax, + x="lon", + y="lat", + transform=ccrs.PlateCarree(), + cmap="Greys_r", + add_colorbar=False, +) +ax.coastlines("10m", color="r") plt.show() diff --git a/doc/gallery/plot_rasterio_rgb.py b/doc/gallery/plot_rasterio_rgb.py index 23a56d5a291..4b5b30ea793 100644 --- a/doc/gallery/plot_rasterio_rgb.py +++ b/doc/gallery/plot_rasterio_rgb.py @@ -19,15 +19,15 @@ import xarray as xr # Read the data -url = 'https://github.com/mapbox/rasterio/raw/master/tests/data/RGB.byte.tif' +url = "https://github.com/mapbox/rasterio/raw/master/tests/data/RGB.byte.tif" da = xr.open_rasterio(url) # The data is in UTM projection. We have to set it manually until # https://github.com/SciTools/cartopy/issues/813 is implemented -crs = ccrs.UTM('18N') +crs = ccrs.UTM("18N") # Plot on a map ax = plt.subplot(projection=crs) -da.plot.imshow(ax=ax, rgb='band', transform=crs) -ax.coastlines('10m', color='r') +da.plot.imshow(ax=ax, rgb="band", transform=crs) +ax.coastlines("10m", color="r") plt.show() diff --git a/properties/test_encode_decode.py b/properties/test_encode_decode.py index 4b9aa8928b4..b8f52e3de7a 100644 --- a/properties/test_encode_decode.py +++ b/properties/test_encode_decode.py @@ -17,9 +17,7 @@ an_array = npst.arrays( dtype=st.one_of( - npst.unsigned_integer_dtypes(), - npst.integer_dtypes(), - npst.floating_dtypes(), + npst.unsigned_integer_dtypes(), npst.integer_dtypes(), npst.floating_dtypes() ), shape=npst.array_shapes(max_side=3), # max_side specified for performance ) @@ -27,8 +25,11 @@ @given(st.data(), an_array) def test_CFMask_coder_roundtrip(data, arr): - names = data.draw(st.lists(st.text(), min_size=arr.ndim, - max_size=arr.ndim, unique=True).map(tuple)) + names = data.draw( + st.lists(st.text(), min_size=arr.ndim, max_size=arr.ndim, unique=True).map( + tuple + ) + ) original = xr.Variable(names, arr) coder = xr.coding.variables.CFMaskCoder() roundtripped = coder.decode(coder.encode(original)) @@ -37,8 +38,11 @@ def test_CFMask_coder_roundtrip(data, arr): @given(st.data(), an_array) def test_CFScaleOffset_coder_roundtrip(data, arr): - names = data.draw(st.lists(st.text(), min_size=arr.ndim, - max_size=arr.ndim, unique=True).map(tuple)) + names = data.draw( + st.lists(st.text(), min_size=arr.ndim, max_size=arr.ndim, unique=True).map( + tuple + ) + ) original = xr.Variable(names, arr) coder = xr.coding.variables.CFScaleOffsetCoder() roundtripped = coder.decode(coder.encode(original)) diff --git a/setup.cfg b/setup.cfg index 128550071cc..68fc3b6e06c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,22 +15,26 @@ markers = slow: slow tests [flake8] -max-line-length=79 +max-line-length=88 ignore= + # whitespace before ':' - doesn't work well with black + E203 E402 + # do not assign a lambda expression, use a def E731 - E741 + # line break before binary operator W503 - W504 # Unused imports; TODO: Allow typing to work without triggering errors F401 exclude= doc [isort] -default_section=THIRDPARTY -known_first_party=xarray -multi_line_output=4 +multi_line_output=3 +include_trailing_comma=True +force_grid_wrap=0 +use_parentheses=True +line_length=88 # Most of the numerical computing stack doesn't have type annotations yet. 
[mypy-affine.*] diff --git a/setup.py b/setup.py index 977ad2e1bd8..b829f6e1f98 100644 --- a/setup.py +++ b/setup.py @@ -4,29 +4,29 @@ import versioneer from setuptools import find_packages, setup -DISTNAME = 'xarray' -LICENSE = 'Apache' -AUTHOR = 'xarray Developers' -AUTHOR_EMAIL = 'xarray@googlegroups.com' -URL = 'https://github.com/pydata/xarray' +DISTNAME = "xarray" +LICENSE = "Apache" +AUTHOR = "xarray Developers" +AUTHOR_EMAIL = "xarray@googlegroups.com" +URL = "https://github.com/pydata/xarray" CLASSIFIERS = [ - 'Development Status :: 5 - Production/Stable', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: OS Independent', - 'Intended Audience :: Science/Research', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Topic :: Scientific/Engineering', + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Intended Audience :: Science/Research", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Topic :: Scientific/Engineering", ] -PYTHON_REQUIRES = '>=3.5.3' -INSTALL_REQUIRES = ['numpy >= 1.12', 'pandas >= 0.19.2'] -needs_pytest = {'pytest', 'test', 'ptr'}.intersection(sys.argv) -SETUP_REQUIRES = ['pytest-runner >= 4.2'] if needs_pytest else [] -TESTS_REQUIRE = ['pytest >= 2.7.1'] +PYTHON_REQUIRES = ">=3.5.3" +INSTALL_REQUIRES = ["numpy >= 1.12", "pandas >= 0.19.2"] +needs_pytest = {"pytest", "test", "ptr"}.intersection(sys.argv) +SETUP_REQUIRES = ["pytest-runner >= 4.2"] if needs_pytest else [] +TESTS_REQUIRE = ["pytest >= 2.7.1"] DESCRIPTION = "N-D labeled arrays and datasets in Python" LONG_DESCRIPTION = """ @@ -89,19 +89,21 @@ """ # noqa -setup(name=DISTNAME, - version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), - license=LICENSE, - author=AUTHOR, - author_email=AUTHOR_EMAIL, - classifiers=CLASSIFIERS, - description=DESCRIPTION, - long_description=LONG_DESCRIPTION, - python_requires=PYTHON_REQUIRES, - install_requires=INSTALL_REQUIRES, - setup_requires=SETUP_REQUIRES, - tests_require=TESTS_REQUIRE, - url=URL, - packages=find_packages(), - package_data={'xarray': ['py.typed', 'tests/data/*']}) +setup( + name=DISTNAME, + version=versioneer.get_version(), + cmdclass=versioneer.get_cmdclass(), + license=LICENSE, + author=AUTHOR, + author_email=AUTHOR_EMAIL, + classifiers=CLASSIFIERS, + description=DESCRIPTION, + long_description=LONG_DESCRIPTION, + python_requires=PYTHON_REQUIRES, + install_requires=INSTALL_REQUIRES, + setup_requires=SETUP_REQUIRES, + tests_require=TESTS_REQUIRE, + url=URL, + packages=find_packages(), + package_data={"xarray": ["py.typed", "tests/data/*"]}, +) diff --git a/versioneer.py b/versioneer.py index e369108b439..ea714e561b7 100644 --- a/versioneer.py +++ b/versioneer.py @@ -311,11 +311,13 @@ def get_root(): setup_py = os.path.join(root, "setup.py") versioneer_py = os.path.join(root, "versioneer.py") if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ("Versioneer was unable to run the project root directory. 
" - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND').") + err = ( + "Versioneer was unable to run the project root directory. " + "Versioneer requires setup.py to be executed from " + "its immediate directory (like 'python setup.py COMMAND'), " + "or in a way that lets it use sys.argv[0] to find the root " + "(like 'python path/to/setup.py COMMAND')." + ) raise VersioneerBadRootError(err) try: # Certain runtime workflows (setup.py install/develop in a setuptools @@ -328,8 +330,10 @@ def get_root(): me_dir = os.path.normcase(os.path.splitext(me)[0]) vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) if me_dir != vsr_dir: - print("Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(me), versioneer_py)) + print( + "Warning: build in %s is using versioneer.py from %s" + % (os.path.dirname(me), versioneer_py) + ) except NameError: pass return root @@ -351,6 +355,7 @@ def get(parser, name): if parser.has_option("versioneer", name): return parser.get("versioneer", name) return None + cfg = VersioneerConfig() cfg.VCS = VCS cfg.style = get(parser, "style") or "" @@ -375,17 +380,18 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs, method): # decorator """Decorator to mark a method as the handler for a particular VCS.""" + def decorate(f): """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f + return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) p = None @@ -393,10 +399,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, try: dispcmd = str([c] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + p = subprocess.Popen( + [c] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + ) break except OSError: e = sys.exc_info()[1] @@ -421,7 +430,9 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, return stdout, p.returncode -LONG_VERSION_PY['git'] = r''' +LONG_VERSION_PY[ + "git" +] = r''' # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -996,7 +1007,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -1005,7 +1016,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". 
- tags = {r for r in refs if re.search(r'\d', r)} + tags = {r for r in refs if re.search(r"\d", r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -1013,19 +1024,26 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") @@ -1040,8 +1058,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -1049,10 +1066,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) + describe_out, rc = run_command( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + "%s*" % tag_prefix, + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -1075,17 +1101,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparseable. Maybe git-describe is misbehaving? 
- pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -1094,10 +1119,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -1108,13 +1135,13 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) pieces["distance"] = int(count_out) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root)[0].strip() + date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ + 0 + ].strip() pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces @@ -1170,16 +1197,22 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): for i in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } else: rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -1208,11 +1241,13 @@ def versions_from_file(filename): contents = f.read() except OSError: raise NotThisMethod("unable to read _version.py") - mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) + mo = re.search( + r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S + ) if not mo: - mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) + mo = re.search( + r"version_json = '''\r\n(.*)''' # END VERSION_JSON", contents, re.M | re.S + ) if not mo: raise NotThisMethod("no version_json in _version.py") return json.loads(mo.group(1)) @@ -1221,8 +1256,7 @@ def versions_from_file(filename): def write_to_version_file(filename, versions): """Write the given version number to the given _version.py file.""" os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, - indent=1, separators=(",", ": ")) + contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": ")) with open(filename, "w") as f: f.write(SHORT_VERSION_PY % contents) @@ -1254,8 +1288,7 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % 
(pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -1369,11 +1402,13 @@ def render_git_describe_long(pieces): def render(pieces, style): """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -1393,9 +1428,13 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } class VersioneerBadRootError(Exception): @@ -1418,8 +1457,9 @@ def get_versions(verbose=False): handlers = HANDLERS.get(cfg.VCS) assert handlers, "unrecognized VCS '%s'" % cfg.VCS verbose = verbose or cfg.verbose - assert cfg.versionfile_source is not None, \ - "please set versioneer.versionfile_source" + assert ( + cfg.versionfile_source is not None + ), "please set versioneer.versionfile_source" assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" versionfile_abs = os.path.join(root, cfg.versionfile_source) @@ -1473,9 +1513,13 @@ def get_versions(verbose=False): if verbose: print("unable to compute version") - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, "error": "unable to compute version", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } def get_version(): @@ -1524,6 +1568,7 @@ def run(self): print(" date: %s" % vers.get("date")) if vers["error"]: print(" error: %s" % vers["error"]) + cmds["version"] = cmd_version # we override "build_py" in both distutils and setuptools @@ -1556,14 +1601,15 @@ def run(self): # now locate _version.py in the new build/ directory and replace # it with an updated value if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) + target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) + cmds["build_py"] = cmd_build_py if "cx_Freeze" in sys.modules: # cx_freeze enabled? from cx_Freeze.dist import build_exe as _build_exe + # nczeczulin reports that py2exe won't like the pep440-style string # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. # setup(console=[{ @@ -1584,17 +1630,21 @@ def run(self): os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + cmds["build_exe"] = cmd_build_exe del cmds["build_py"] - if 'py2exe' in sys.modules: # py2exe enabled? + if "py2exe" in sys.modules: # py2exe enabled? 
try: from py2exe.distutils_buildexe import py2exe as _py2exe # py3 except ImportError: @@ -1613,13 +1663,17 @@ def run(self): os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + cmds["py2exe"] = cmd_py2exe # we override different "sdist" commands for both environments @@ -1646,8 +1700,10 @@ def make_release_tree(self, base_dir, files): # updated value target_versionfile = os.path.join(base_dir, cfg.versionfile_source) print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, - self._versioneer_generated_versions) + write_to_version_file( + target_versionfile, self._versioneer_generated_versions + ) + cmds["sdist"] = cmd_sdist return cmds @@ -1704,8 +1760,7 @@ def do_setup(): cfg = get_config_from_root(root) except (OSError, configparser.NoSectionError, configparser.NoOptionError) as e: if isinstance(e, (EnvironmentError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", - file=sys.stderr) + print("Adding sample versioneer config to setup.cfg", file=sys.stderr) with open(os.path.join(root, "setup.cfg"), "a") as f: f.write(SAMPLE_CONFIG) print(CONFIG_ERROR, file=sys.stderr) @@ -1714,15 +1769,18 @@ def do_setup(): print(" creating %s" % cfg.versionfile_source) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), - "__init__.py") + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + + ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py") if os.path.exists(ipy): try: with open(ipy, "r") as f: @@ -1764,8 +1822,10 @@ def do_setup(): else: print(" 'versioneer.py' already in MANIFEST.in") if cfg.versionfile_source not in simple_includes: - print(" appending versionfile_source ('%s') to MANIFEST.in" % - cfg.versionfile_source) + print( + " appending versionfile_source ('%s') to MANIFEST.in" + % cfg.versionfile_source + ) with open(manifest_in, "a") as f: f.write("include %s\n" % cfg.versionfile_source) else: diff --git a/xarray/__init__.py b/xarray/__init__.py index c2b78fe9dd4..a3df034f7c7 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -2,7 +2,8 @@ # flake8: noqa from ._version import get_versions -__version__ = get_versions()['version'] + +__version__ = get_versions()["version"] del get_versions from .core.alignment import align, broadcast, broadcast_arrays @@ -10,16 +11,21 @@ from .core.concat import concat from .core.combine import combine_by_coords, combine_nested, auto_combine from .core.computation import apply_ufunc, dot, where -from .core.extensions import (register_dataarray_accessor, - register_dataset_accessor) +from .core.extensions import register_dataarray_accessor, register_dataset_accessor from .core.variable import as_variable, Variable, 
IndexVariable, Coordinate from .core.dataset import Dataset from .core.dataarray import DataArray from .core.merge import merge, MergeError from .core.options import set_options -from .backends.api import (open_dataset, open_dataarray, open_mfdataset, - save_mfdataset, load_dataset, load_dataarray) +from .backends.api import ( + open_dataset, + open_dataarray, + open_mfdataset, + save_mfdataset, + load_dataset, + load_dataarray, +) from .backends.rasterio_ import open_rasterio from .backends.zarr import open_zarr diff --git a/xarray/_version.py b/xarray/_version.py index 442e56a04b0..826bf470ca7 100644 --- a/xarray/_version.py +++ b/xarray/_version.py @@ -1,4 +1,3 @@ - # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -58,17 +57,18 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs, method): # decorator """Decorator to mark a method as the handler for a particular VCS.""" + def decorate(f): """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f + return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) p = None @@ -76,10 +76,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, try: dispcmd = str([c] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + p = subprocess.Popen( + [c] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + ) break except OSError: e = sys.exc_info()[1] @@ -116,16 +119,22 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): for i in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } else: rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -181,7 +190,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -190,7 +199,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". 
- tags = {r for r in refs if re.search(r'\d', r)} + tags = {r for r in refs if re.search(r"\d", r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -198,19 +207,26 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") @@ -225,8 +241,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -234,10 +249,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) + describe_out, rc = run_command( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + "%s*" % tag_prefix, + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -260,17 +284,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparseable. Maybe git-describe is misbehaving? 
- pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -279,10 +302,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -293,13 +318,13 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) pieces["distance"] = int(count_out) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root)[0].strip() + date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ + 0 + ].strip() pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces @@ -330,8 +355,7 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -445,11 +469,13 @@ def render_git_describe_long(pieces): def render(pieces, style): """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -469,9 +495,13 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } def get_versions(): @@ -485,8 +515,7 @@ def get_versions(): verbose = cfg.verbose try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) except NotThisMethod: pass @@ -495,13 +524,16 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. 
- for i in cfg.versionfile_source.split('/'): + for i in cfg.versionfile_source.split("/"): root = os.path.dirname(root) except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None, + } try: pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) @@ -515,6 +547,10 @@ def get_versions(): except NotThisMethod: pass - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } diff --git a/xarray/backends/__init__.py b/xarray/backends/__init__.py index 292a6d68523..2a769b1335e 100644 --- a/xarray/backends/__init__.py +++ b/xarray/backends/__init__.py @@ -16,17 +16,17 @@ from .zarr import ZarrStore __all__ = [ - 'AbstractDataStore', - 'FileManager', - 'CachingFileManager', - 'CfGribDataStore', - 'DummyFileManager', - 'InMemoryDataStore', - 'NetCDF4DataStore', - 'PydapDataStore', - 'NioDataStore', - 'ScipyDataStore', - 'H5NetCDFStore', - 'ZarrStore', - 'PseudoNetCDFDataStore', + "AbstractDataStore", + "FileManager", + "CachingFileManager", + "CfGribDataStore", + "DummyFileManager", + "InMemoryDataStore", + "NetCDF4DataStore", + "PydapDataStore", + "NioDataStore", + "ScipyDataStore", + "H5NetCDFStore", + "ZarrStore", + "PseudoNetCDFDataStore", ] diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 292373d2a33..8d5f7f05a9f 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -5,8 +5,16 @@ from numbers import Number from pathlib import Path from textwrap import dedent -from typing import (Callable, Dict, Hashable, Iterable, Mapping, Tuple, Union, - TYPE_CHECKING) +from typing import ( + Callable, + Dict, + Hashable, + Iterable, + Mapping, + Tuple, + Union, + TYPE_CHECKING, +) import numpy as np @@ -16,7 +24,7 @@ from ..core.combine import ( combine_by_coords, _nested_combine, - _infer_concat_order_from_positions + _infer_concat_order_from_positions, ) from ..core.utils import close_on_error, is_grib_path, is_remote_uri from .common import ArrayWriter, AbstractDataStore @@ -29,21 +37,25 @@ Delayed = None -DATAARRAY_NAME = '__xarray_dataarray_name__' -DATAARRAY_VARIABLE = '__xarray_dataarray_variable__' +DATAARRAY_NAME = "__xarray_dataarray_name__" +DATAARRAY_VARIABLE = "__xarray_dataarray_variable__" def _get_default_engine_remote_uri(): try: import netCDF4 # noqa - engine = 'netcdf4' + + engine = "netcdf4" except ImportError: # pragma: no cover try: import pydap # noqa - engine = 'pydap' + + engine = "pydap" except ImportError: - raise ValueError('netCDF4 or pydap is required for accessing ' - 'remote datasets via OPeNDAP') + raise ValueError( + "netCDF4 or pydap is required for accessing " + "remote datasets via OPeNDAP" + ) return engine @@ -51,41 +63,47 @@ def _get_default_engine_grib(): msgs = [] try: import Nio # noqa + msgs += ["set engine='pynio' to access GRIB files with PyNIO"] except ImportError: # pragma: no cover pass try: import cfgrib # noqa + msgs += ["set engine='cfgrib' to access GRIB files with cfgrib"] except ImportError: # pragma: no cover pass if msgs: - raise ValueError(' or\n'.join(msgs)) + raise ValueError(" or\n".join(msgs)) else: - raise ValueError('PyNIO or cfgrib is required for accessing ' - 
'GRIB files') + raise ValueError("PyNIO or cfgrib is required for accessing " "GRIB files") def _get_default_engine_gz(): try: import scipy # noqa - engine = 'scipy' + + engine = "scipy" except ImportError: # pragma: no cover - raise ValueError('scipy is required for accessing .gz files') + raise ValueError("scipy is required for accessing .gz files") return engine def _get_default_engine_netcdf(): try: import netCDF4 # noqa - engine = 'netcdf4' + + engine = "netcdf4" except ImportError: # pragma: no cover try: import scipy.io.netcdf # noqa - engine = 'scipy' + + engine = "scipy" except ImportError: - raise ValueError('cannot read or write netCDF files without ' - 'netCDF4-python or scipy installed') + raise ValueError( + "cannot read or write netCDF files without " + "netCDF4-python or scipy installed" + ) return engine @@ -95,25 +113,30 @@ def _get_engine_from_magic_number(filename_or_obj): magic_number = filename_or_obj[:8] else: if filename_or_obj.tell() != 0: - raise ValueError("file-like object read/write pointer not at zero " - "please close and reopen, or use a context " - "manager") + raise ValueError( + "file-like object read/write pointer not at zero " + "please close and reopen, or use a context " + "manager" + ) magic_number = filename_or_obj.read(8) filename_or_obj.seek(0) - if magic_number.startswith(b'CDF'): - engine = 'scipy' - elif magic_number.startswith(b'\211HDF\r\n\032\n'): - engine = 'h5netcdf' + if magic_number.startswith(b"CDF"): + engine = "scipy" + elif magic_number.startswith(b"\211HDF\r\n\032\n"): + engine = "h5netcdf" if isinstance(filename_or_obj, bytes): - raise ValueError("can't open netCDF4/HDF5 as bytes " - "try passing a path or file-like object") + raise ValueError( + "can't open netCDF4/HDF5 as bytes " + "try passing a path or file-like object" + ) else: if isinstance(filename_or_obj, bytes) and len(filename_or_obj) > 80: - filename_or_obj = filename_or_obj[:80] + b'...' - raise ValueError('{} is not a valid netCDF file ' - 'did you mean to pass a string for a path instead?' - .format(filename_or_obj)) + filename_or_obj = filename_or_obj[:80] + b"..." 
+ raise ValueError( + "{} is not a valid netCDF file " + "did you mean to pass a string for a path instead?".format(filename_or_obj) + ) return engine @@ -122,7 +145,7 @@ def _get_default_engine(path, allow_remote=False): engine = _get_default_engine_remote_uri() elif is_grib_path(path): engine = _get_default_engine_grib() - elif path.endswith('.gz'): + elif path.endswith(".gz"): engine = _get_default_engine_gz() else: engine = _get_default_engine_netcdf() @@ -138,15 +161,20 @@ def _normalize_path(path): def _validate_dataset_names(dataset): """DataArray.name and Dataset keys must be a string or None""" + def check_name(name): if isinstance(name, str): if not name: - raise ValueError('Invalid name for DataArray or Dataset key: ' - 'string must be length 1 or greater for ' - 'serialization to netCDF files') + raise ValueError( + "Invalid name for DataArray or Dataset key: " + "string must be length 1 or greater for " + "serialization to netCDF files" + ) elif name is not None: - raise TypeError('DataArray.name or Dataset key must be either a ' - 'string or None for serialization to netCDF files') + raise TypeError( + "DataArray.name or Dataset key must be either a " + "string or None for serialization to netCDF files" + ) for k in dataset.variables: check_name(k) @@ -156,22 +184,28 @@ def _validate_attrs(dataset): """`attrs` must have a string key and a value which is either: a number, a string, an ndarray or a list/tuple of numbers/strings. """ + def check_attr(name, value): if isinstance(name, str): if not name: - raise ValueError('Invalid name for attr: string must be ' - 'length 1 or greater for serialization to ' - 'netCDF files') + raise ValueError( + "Invalid name for attr: string must be " + "length 1 or greater for serialization to " + "netCDF files" + ) else: - raise TypeError("Invalid name for attr: {} must be a string for " - "serialization to netCDF files".format(name)) + raise TypeError( + "Invalid name for attr: {} must be a string for " + "serialization to netCDF files".format(name) + ) - if not isinstance(value, (str, Number, np.ndarray, np.number, - list, tuple)): - raise TypeError('Invalid value for attr: {} must be a number, ' - 'a string, an ndarray or a list/tuple of ' - 'numbers/strings for serialization to netCDF ' - 'files'.format(value)) + if not isinstance(value, (str, Number, np.ndarray, np.number, list, tuple)): + raise TypeError( + "Invalid value for attr: {} must be a number, " + "a string, an ndarray or a list/tuple of " + "numbers/strings for serialization to netCDF " + "files".format(value) + ) # Check attrs on the dataset itself for k, v in dataset.attrs.items(): @@ -218,8 +252,8 @@ def load_dataset(filename_or_obj, **kwargs): -------- open_dataset """ - if 'cache' in kwargs: - raise TypeError('cache has no effect in this context') + if "cache" in kwargs: + raise TypeError("cache has no effect in this context") with open_dataset(filename_or_obj, **kwargs) as ds: return ds.load() @@ -244,18 +278,30 @@ def load_dataarray(filename_or_obj, **kwargs): -------- open_dataarray """ - if 'cache' in kwargs: - raise TypeError('cache has no effect in this context') + if "cache" in kwargs: + raise TypeError("cache has no effect in this context") with open_dataarray(filename_or_obj, **kwargs) as da: return da.load() -def open_dataset(filename_or_obj, group=None, decode_cf=True, - mask_and_scale=None, decode_times=True, autoclose=None, - concat_characters=True, decode_coords=True, engine=None, - chunks=None, lock=None, cache=None, drop_variables=None, - 
backend_kwargs=None, use_cftime=None): +def open_dataset( + filename_or_obj, + group=None, + decode_cf=True, + mask_and_scale=None, + decode_times=True, + autoclose=None, + concat_characters=True, + decode_coords=True, + engine=None, + chunks=None, + lock=None, + cache=None, + drop_variables=None, + backend_kwargs=None, + use_cftime=None, +): """Open and decode a dataset from a file or file-like object. Parameters @@ -352,24 +398,35 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, -------- open_mfdataset """ - engines = [None, 'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio', - 'cfgrib', 'pseudonetcdf'] + engines = [ + None, + "netcdf4", + "scipy", + "pydap", + "h5netcdf", + "pynio", + "cfgrib", + "pseudonetcdf", + ] if engine not in engines: - raise ValueError('unrecognized engine for open_dataset: {}\n' - 'must be one of: {}' - .format(engine, engines)) + raise ValueError( + "unrecognized engine for open_dataset: {}\n" + "must be one of: {}".format(engine, engines) + ) if autoclose is not None: warnings.warn( - 'The autoclose argument is no longer used by ' - 'xarray.open_dataset() and is now ignored; it will be removed in ' - 'a future version of xarray. If necessary, you can control the ' - 'maximum number of simultaneous open files with ' - 'xarray.set_options(file_cache_maxsize=...).', - FutureWarning, stacklevel=2) + "The autoclose argument is no longer used by " + "xarray.open_dataset() and is now ignored; it will be removed in " + "a future version of xarray. If necessary, you can control the " + "maximum number of simultaneous open files with " + "xarray.set_options(file_cache_maxsize=...).", + FutureWarning, + stacklevel=2, + ) if mask_and_scale is None: - mask_and_scale = not engine == 'pseudonetcdf' + mask_and_scale = not engine == "pseudonetcdf" if not decode_cf: mask_and_scale = False @@ -385,26 +442,41 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, def maybe_decode_store(store, lock=False): ds = conventions.decode_cf( - store, mask_and_scale=mask_and_scale, decode_times=decode_times, - concat_characters=concat_characters, decode_coords=decode_coords, - drop_variables=drop_variables, use_cftime=use_cftime) + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + ) _protect_dataset_variables_inplace(ds, cache) if chunks is not None: from dask.base import tokenize + # if passed an actual file path, augment the token with # the file modification time - if (isinstance(filename_or_obj, str) and - not is_remote_uri(filename_or_obj)): + if isinstance(filename_or_obj, str) and not is_remote_uri(filename_or_obj): mtime = os.path.getmtime(filename_or_obj) else: mtime = None - token = tokenize(filename_or_obj, mtime, group, decode_cf, - mask_and_scale, decode_times, concat_characters, - decode_coords, engine, chunks, drop_variables, - use_cftime) - name_prefix = 'open_dataset-%s' % token + token = tokenize( + filename_or_obj, + mtime, + group, + decode_cf, + mask_and_scale, + decode_times, + concat_characters, + decode_coords, + engine, + chunks, + drop_variables, + use_cftime, + ) + name_prefix = "open_dataset-%s" % token ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token) ds2._file_obj = ds._file_obj else: @@ -422,56 +494,72 @@ def maybe_decode_store(store, lock=False): filename_or_obj = _normalize_path(filename_or_obj) if engine is None: - engine = _get_default_engine(filename_or_obj, - 
allow_remote=True) - if engine == 'netcdf4': + engine = _get_default_engine(filename_or_obj, allow_remote=True) + if engine == "netcdf4": store = backends.NetCDF4DataStore.open( - filename_or_obj, group=group, lock=lock, **backend_kwargs) - elif engine == 'scipy': + filename_or_obj, group=group, lock=lock, **backend_kwargs + ) + elif engine == "scipy": store = backends.ScipyDataStore(filename_or_obj, **backend_kwargs) - elif engine == 'pydap': - store = backends.PydapDataStore.open( - filename_or_obj, **backend_kwargs) - elif engine == 'h5netcdf': + elif engine == "pydap": + store = backends.PydapDataStore.open(filename_or_obj, **backend_kwargs) + elif engine == "h5netcdf": store = backends.H5NetCDFStore( - filename_or_obj, group=group, lock=lock, **backend_kwargs) - elif engine == 'pynio': - store = backends.NioDataStore( - filename_or_obj, lock=lock, **backend_kwargs) - elif engine == 'pseudonetcdf': + filename_or_obj, group=group, lock=lock, **backend_kwargs + ) + elif engine == "pynio": + store = backends.NioDataStore(filename_or_obj, lock=lock, **backend_kwargs) + elif engine == "pseudonetcdf": store = backends.PseudoNetCDFDataStore.open( - filename_or_obj, lock=lock, **backend_kwargs) - elif engine == 'cfgrib': + filename_or_obj, lock=lock, **backend_kwargs + ) + elif engine == "cfgrib": store = backends.CfGribDataStore( - filename_or_obj, lock=lock, **backend_kwargs) + filename_or_obj, lock=lock, **backend_kwargs + ) else: - if engine not in [None, 'scipy', 'h5netcdf']: - raise ValueError("can only read bytes or file-like objects " - "with engine='scipy' or 'h5netcdf'") + if engine not in [None, "scipy", "h5netcdf"]: + raise ValueError( + "can only read bytes or file-like objects " + "with engine='scipy' or 'h5netcdf'" + ) engine = _get_engine_from_magic_number(filename_or_obj) - if engine == 'scipy': + if engine == "scipy": store = backends.ScipyDataStore(filename_or_obj, **backend_kwargs) - elif engine == 'h5netcdf': - store = backends.H5NetCDFStore(filename_or_obj, group=group, - lock=lock, **backend_kwargs) + elif engine == "h5netcdf": + store = backends.H5NetCDFStore( + filename_or_obj, group=group, lock=lock, **backend_kwargs + ) with close_on_error(store): ds = maybe_decode_store(store) # Ensure source filename always stored in dataset object (GH issue #2550) - if 'source' not in ds.encoding: + if "source" not in ds.encoding: if isinstance(filename_or_obj, str): - ds.encoding['source'] = filename_or_obj + ds.encoding["source"] = filename_or_obj return ds -def open_dataarray(filename_or_obj, group=None, decode_cf=True, - mask_and_scale=None, decode_times=True, autoclose=None, - concat_characters=True, decode_coords=True, engine=None, - chunks=None, lock=None, cache=None, drop_variables=None, - backend_kwargs=None, use_cftime=None): +def open_dataarray( + filename_or_obj, + group=None, + decode_cf=True, + mask_and_scale=None, + decode_times=True, + autoclose=None, + concat_characters=True, + decode_coords=True, + engine=None, + chunks=None, + lock=None, + cache=None, + drop_variables=None, + backend_kwargs=None, + use_cftime=None, +): """Open an DataArray from a file or file-like object containing a single data variable. 
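For orientation on the reformatted ``open_dataarray`` signature above, a minimal usage sketch; the file name ``single_var.nc`` is invented, and the one-variable requirement it relies on is enforced by the ValueError a few hunks below.

    import xarray as xr

    # open_dataarray forwards its keyword arguments to open_dataset and then
    # unpacks the single data variable into a DataArray.
    da = xr.open_dataarray("single_var.nc", decode_times=True)

    # A file holding more than one data variable raises the ValueError shown
    # further down, which points the user at xr.open_dataset instead.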
@@ -565,20 +653,30 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, open_dataset """ - dataset = open_dataset(filename_or_obj, group=group, decode_cf=decode_cf, - mask_and_scale=mask_and_scale, - decode_times=decode_times, autoclose=autoclose, - concat_characters=concat_characters, - decode_coords=decode_coords, engine=engine, - chunks=chunks, lock=lock, cache=cache, - drop_variables=drop_variables, - backend_kwargs=backend_kwargs, - use_cftime=use_cftime) + dataset = open_dataset( + filename_or_obj, + group=group, + decode_cf=decode_cf, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + autoclose=autoclose, + concat_characters=concat_characters, + decode_coords=decode_coords, + engine=engine, + chunks=chunks, + lock=lock, + cache=cache, + drop_variables=drop_variables, + backend_kwargs=backend_kwargs, + use_cftime=use_cftime, + ) if len(dataset.data_vars) != 1: - raise ValueError('Given file dataset contains more than one data ' - 'variable. Please read with xarray.open_dataset and ' - 'then select the variable you want.') + raise ValueError( + "Given file dataset contains more than one data " + "variable. Please read with xarray.open_dataset and " + "then select the variable you want." + ) else: data_array, = dataset.data_vars.values() @@ -605,11 +703,22 @@ def close(self): f.close() -def open_mfdataset(paths, chunks=None, concat_dim='_not_supplied', - compat='no_conflicts', preprocess=None, engine=None, - lock=None, data_vars='all', coords='different', - combine='_old_auto', autoclose=None, parallel=False, - join='outer', **kwargs): +def open_mfdataset( + paths, + chunks=None, + concat_dim="_not_supplied", + compat="no_conflicts", + preprocess=None, + engine=None, + lock=None, + data_vars="all", + coords="different", + combine="_old_auto", + autoclose=None, + parallel=False, + join="outer", + **kwargs +): """Open multiple files as a single dataset. If combine='by_coords' then the function ``combine_by_coords`` is used to @@ -744,35 +853,37 @@ def open_mfdataset(paths, chunks=None, concat_dim='_not_supplied', if isinstance(paths, str): if is_remote_uri(paths): raise ValueError( - 'cannot do wild-card matching for paths that are remote URLs: ' - '{!r}. Instead, supply paths as an explicit list of strings.' - .format(paths)) + "cannot do wild-card matching for paths that are remote URLs: " + "{!r}. Instead, supply paths as an explicit list of strings.".format( + paths + ) + ) paths = sorted(glob(paths)) else: paths = [str(p) if isinstance(p, Path) else p for p in paths] if not paths: - raise OSError('no files to open') + raise OSError("no files to open") # If combine='by_coords' then this is unnecessary, but quick. 
# If combine='nested' then this creates a flat list which is easier to # iterate over, while saving the originally-supplied structure as "ids" - if combine == 'nested': - if str(concat_dim) == '_not_supplied': - raise ValueError("Must supply concat_dim when using " - "combine='nested'") + if combine == "nested": + if str(concat_dim) == "_not_supplied": + raise ValueError("Must supply concat_dim when using " "combine='nested'") else: if isinstance(concat_dim, (str, DataArray)) or concat_dim is None: concat_dim = [concat_dim] combined_ids_paths = _infer_concat_order_from_positions(paths) - ids, paths = ( - list(combined_ids_paths.keys()), list(combined_ids_paths.values())) + ids, paths = (list(combined_ids_paths.keys()), list(combined_ids_paths.values())) - open_kwargs = dict(engine=engine, chunks=chunks or {}, lock=lock, - autoclose=autoclose, **kwargs) + open_kwargs = dict( + engine=engine, chunks=chunks or {}, lock=lock, autoclose=autoclose, **kwargs + ) if parallel: import dask + # wrap the open_dataset, getattr, and preprocess with delayed open_ = dask.delayed(open_dataset) getattr_ = dask.delayed(getattr) @@ -783,7 +894,7 @@ def open_mfdataset(paths, chunks=None, concat_dim='_not_supplied', getattr_ = getattr datasets = [open_(p, **open_kwargs) for p in paths] - file_objs = [getattr_(ds, '_file_obj') for ds in datasets] + file_objs = [getattr_(ds, "_file_obj") for ds in datasets] if preprocess is not None: datasets = [preprocess(ds) for ds in datasets] @@ -794,37 +905,52 @@ def open_mfdataset(paths, chunks=None, concat_dim='_not_supplied', # Combine all datasets, closing them in case of a ValueError try: - if combine == '_old_auto': + if combine == "_old_auto": # Use the old auto_combine for now # Remove this after deprecation cycle from #2616 is complete - basic_msg = dedent("""\ + basic_msg = dedent( + """\ In xarray version 0.13 the default behaviour of `open_mfdataset` will change. To retain the existing behavior, pass combine='nested'. To use future default behavior, pass combine='by_coords'. 
See http://xarray.pydata.org/en/stable/combining.html#combining-multi - """) + """ + ) warnings.warn(basic_msg, FutureWarning, stacklevel=2) - combined = auto_combine(datasets, concat_dim=concat_dim, - compat=compat, data_vars=data_vars, - coords=coords, join=join, - from_openmfds=True) - elif combine == 'nested': + combined = auto_combine( + datasets, + concat_dim=concat_dim, + compat=compat, + data_vars=data_vars, + coords=coords, + join=join, + from_openmfds=True, + ) + elif combine == "nested": # Combined nested list by successive concat and merge operations # along each dimension, using structure given by "ids" - combined = _nested_combine(datasets, concat_dims=concat_dim, - compat=compat, data_vars=data_vars, - coords=coords, ids=ids, join=join) - elif combine == 'by_coords': + combined = _nested_combine( + datasets, + concat_dims=concat_dim, + compat=compat, + data_vars=data_vars, + coords=coords, + ids=ids, + join=join, + ) + elif combine == "by_coords": # Redo ordering from coordinates, ignoring how they were ordered # previously - combined = combine_by_coords(datasets, compat=compat, - data_vars=data_vars, coords=coords, - join=join) + combined = combine_by_coords( + datasets, compat=compat, data_vars=data_vars, coords=coords, join=join + ) else: - raise ValueError("{} is an invalid option for the keyword argument" - " ``combine``".format(combine)) + raise ValueError( + "{} is an invalid option for the keyword argument" + " ``combine``".format(combine) + ) except ValueError: for ds in datasets: ds.close() @@ -836,24 +962,24 @@ def open_mfdataset(paths, chunks=None, concat_dim='_not_supplied', WRITEABLE_STORES = { - 'netcdf4': backends.NetCDF4DataStore.open, - 'scipy': backends.ScipyDataStore, - 'h5netcdf': backends.H5NetCDFStore + "netcdf4": backends.NetCDF4DataStore.open, + "scipy": backends.ScipyDataStore, + "h5netcdf": backends.H5NetCDFStore, } # type: Dict[str, Callable] def to_netcdf( dataset: Dataset, path_or_file=None, - mode: str = 'w', + mode: str = "w", format: str = None, group: str = None, engine: str = None, encoding: Mapping = None, unlimited_dims: Iterable[Hashable] = None, compute: bool = True, - multifile: bool = False -) -> Union[Tuple[ArrayWriter, AbstractDataStore], bytes, 'Delayed', None]: + multifile: bool = False, +) -> Union[Tuple[ArrayWriter, AbstractDataStore], bytes, "Delayed", None]: """This function creates an appropriate datastore for writing a dataset to disk as a netCDF file @@ -869,21 +995,24 @@ def to_netcdf( if path_or_file is None: if engine is None: - engine = 'scipy' - elif engine != 'scipy': - raise ValueError('invalid engine for creating bytes with ' - 'to_netcdf: %r. Only the default engine ' - "or engine='scipy' is supported" % engine) + engine = "scipy" + elif engine != "scipy": + raise ValueError( + "invalid engine for creating bytes with " + "to_netcdf: %r. 
Only the default engine " + "or engine='scipy' is supported" % engine + ) if not compute: raise NotImplementedError( - 'to_netcdf() with compute=False is not yet implemented when ' - 'returning bytes') + "to_netcdf() with compute=False is not yet implemented when " + "returning bytes" + ) elif isinstance(path_or_file, str): if engine is None: engine = _get_default_engine(path_or_file) path_or_file = _normalize_path(path_or_file) else: # file-like object - engine = 'scipy' + engine = "scipy" # validate Dataset keys, DataArray names, and attr keys/values _validate_dataset_names(dataset) @@ -892,7 +1021,7 @@ def to_netcdf( try: store_open = WRITEABLE_STORES[engine] except KeyError: - raise ValueError('unrecognized engine for to_netcdf: %r' % engine) + raise ValueError("unrecognized engine for to_netcdf: %r" % engine) if format is not None: format = format.upper() @@ -901,21 +1030,22 @@ def to_netcdf( scheduler = _get_scheduler() have_chunks = any(v.chunks for v in dataset.variables.values()) - autoclose = have_chunks and scheduler in ['distributed', 'multiprocessing'] - if autoclose and engine == 'scipy': - raise NotImplementedError("Writing netCDF files with the %s backend " - "is not currently supported with dask's %s " - "scheduler" % (engine, scheduler)) + autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"] + if autoclose and engine == "scipy": + raise NotImplementedError( + "Writing netCDF files with the %s backend " + "is not currently supported with dask's %s " + "scheduler" % (engine, scheduler) + ) target = path_or_file if path_or_file is not None else BytesIO() kwargs = dict(autoclose=True) if autoclose else {} store = store_open(target, mode, format, group, **kwargs) if unlimited_dims is None: - unlimited_dims = dataset.encoding.get('unlimited_dims', None) + unlimited_dims = dataset.encoding.get("unlimited_dims", None) if unlimited_dims is not None: - if (isinstance(unlimited_dims, str) - or not isinstance(unlimited_dims, Iterable)): + if isinstance(unlimited_dims, str) or not isinstance(unlimited_dims, Iterable): unlimited_dims = [unlimited_dims] else: unlimited_dims = list(unlimited_dims) @@ -927,8 +1057,9 @@ def to_netcdf( try: # TODO: allow this work (setting up the file for writing array data) # to be parallelized with dask - dump_to_store(dataset, store, writer, encoding=encoding, - unlimited_dims=unlimited_dims) + dump_to_store( + dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims + ) if autoclose: store.close() @@ -946,12 +1077,14 @@ def to_netcdf( if not compute: import dask + return dask.delayed(_finalize_store)(writes, store) return None -def dump_to_store(dataset, store, writer=None, encoder=None, - encoding=None, unlimited_dims=None): +def dump_to_store( + dataset, store, writer=None, encoder=None, encoding=None, unlimited_dims=None +): """Store dataset contents to a backends.*DataStore object.""" if writer is None: writer = ArrayWriter() @@ -971,12 +1104,12 @@ def dump_to_store(dataset, store, writer=None, encoder=None, if encoder: variables, attrs = encoder(variables, attrs) - store.store(variables, attrs, check_encoding, writer, - unlimited_dims=unlimited_dims) + store.store(variables, attrs, check_encoding, writer, unlimited_dims=unlimited_dims) -def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, - engine=None, compute=True): +def save_mfdataset( + datasets, paths, mode="w", format=None, groups=None, engine=None, compute=True +): """Write multiple datasets to disk as netCDF files simultaneously. 
This function is intended for use with datasets consisting of dask.array @@ -1039,27 +1172,36 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, >>> paths = ['%s.nc' % y for y in years] >>> xr.save_mfdataset(datasets, paths) """ - if mode == 'w' and len(set(paths)) < len(paths): - raise ValueError("cannot use mode='w' when writing multiple " - 'datasets to the same path') + if mode == "w" and len(set(paths)) < len(paths): + raise ValueError( + "cannot use mode='w' when writing multiple " "datasets to the same path" + ) for obj in datasets: if not isinstance(obj, Dataset): - raise TypeError('save_mfdataset only supports writing Dataset ' - 'objects, received type %s' % type(obj)) + raise TypeError( + "save_mfdataset only supports writing Dataset " + "objects, received type %s" % type(obj) + ) if groups is None: groups = [None] * len(datasets) if len({len(datasets), len(paths), len(groups)}) > 1: - raise ValueError('must supply lists of the same length for the ' - 'datasets, paths and groups arguments to ' - 'save_mfdataset') - - writers, stores = zip(*[ - to_netcdf(ds, path, mode, format, group, engine, compute=compute, - multifile=True) - for ds, path, group in zip(datasets, paths, groups)]) + raise ValueError( + "must supply lists of the same length for the " + "datasets, paths and groups arguments to " + "save_mfdataset" + ) + + writers, stores = zip( + *[ + to_netcdf( + ds, path, mode, format, group, engine, compute=compute, multifile=True + ) + for ds, path, group in zip(datasets, paths, groups) + ] + ) try: writes = [w.sync(compute=compute) for w in writers] @@ -1070,27 +1212,36 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, if not compute: import dask - return dask.delayed([dask.delayed(_finalize_store)(w, s) - for w, s in zip(writes, stores)]) + + return dask.delayed( + [dask.delayed(_finalize_store)(w, s) for w, s in zip(writes, stores)] + ) def _validate_datatypes_for_zarr_append(dataset): """DataArray.name and Dataset keys must be a string or None""" + def check_dtype(var): - if (not np.issubdtype(var.dtype, np.number) - and not coding.strings.is_unicode_dtype(var.dtype) - and not var.dtype == object): + if ( + not np.issubdtype(var.dtype, np.number) + and not coding.strings.is_unicode_dtype(var.dtype) + and not var.dtype == object + ): # and not re.match('^bytes[1-9]+$', var.dtype.name)): - raise ValueError('Invalid dtype for data variable: {} ' - 'dtype must be a subtype of number, ' - 'a fixed sized string, a fixed size ' - 'unicode string or an object'.format(var)) + raise ValueError( + "Invalid dtype for data variable: {} " + "dtype must be a subtype of number, " + "a fixed sized string, a fixed size " + "unicode string or an object".format(var) + ) + for k in dataset.data_vars.values(): check_dtype(k) -def _validate_append_dim_and_encoding(ds_to_append, store, append_dim, - encoding, **open_kwargs): +def _validate_append_dim_and_encoding( + ds_to_append, store, append_dim, encoding, **open_kwargs +): try: ds = backends.zarr.open_zarr(store, **open_kwargs) except ValueError: # store empty @@ -1114,8 +1265,17 @@ def _validate_append_dim_and_encoding(ds_to_append, store, append_dim, ) -def to_zarr(dataset, store=None, mode=None, synchronizer=None, group=None, - encoding=None, compute=True, consolidated=False, append_dim=None): +def to_zarr( + dataset, + store=None, + mode=None, + synchronizer=None, + group=None, + encoding=None, + compute=True, + consolidated=False, + append_dim=None, +): """This function creates an 
appropriate datastore for writing a dataset to a zarr ztore @@ -1130,17 +1290,24 @@ def to_zarr(dataset, store=None, mode=None, synchronizer=None, group=None, _validate_dataset_names(dataset) _validate_attrs(dataset) - if mode == 'a': + if mode == "a": _validate_datatypes_for_zarr_append(dataset) - _validate_append_dim_and_encoding(dataset, store, append_dim, - group=group, - consolidated=consolidated, - encoding=encoding) - - zstore = backends.ZarrStore.open_group(store=store, mode=mode, - synchronizer=synchronizer, - group=group, - consolidate_on_close=consolidated) + _validate_append_dim_and_encoding( + dataset, + store, + append_dim, + group=group, + consolidated=consolidated, + encoding=encoding, + ) + + zstore = backends.ZarrStore.open_group( + store=store, + mode=mode, + synchronizer=synchronizer, + group=group, + consolidate_on_close=consolidated, + ) zstore.append_dim = append_dim writer = ArrayWriter() # TODO: figure out how to properly handle unlimited_dims @@ -1151,6 +1318,7 @@ def to_zarr(dataset, store=None, mode=None, synchronizer=None, group=None, _finalize_store(writes, zstore) else: import dask + return dask.delayed(_finalize_store)(writes, zstore) return zstore diff --git a/xarray/backends/cfgrib_.py b/xarray/backends/cfgrib_.py index 51c3318e794..108406ee183 100644 --- a/xarray/backends/cfgrib_.py +++ b/xarray/backends/cfgrib_.py @@ -21,7 +21,8 @@ def __init__(self, datastore, array): def __getitem__(self, key): return indexing.explicit_indexing_adapter( - key, self.shape, indexing.IndexingSupport.OUTER, self._getitem) + key, self.shape, indexing.IndexingSupport.OUTER, self._getitem + ) def _getitem(self, key): with self.datastore.lock: @@ -32,8 +33,10 @@ class CfGribDataStore(AbstractDataStore): """ Implements the ``xr.AbstractDataStore`` read-only API for a GRIB file. 
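As a usage note for the cfgrib store being reformatted here, GRIB files are normally opened through the public ``engine`` keyword rather than by constructing ``CfGribDataStore`` directly; a minimal sketch, assuming a local ``example.grib`` and an installed cfgrib (``indexpath`` is just one illustrative backend option).

    import xarray as xr

    # engine="cfgrib" routes the open through CfGribDataStore; backend_kwargs
    # are forwarded to cfgrib when the store is created.
    ds = xr.open_dataset(
        "example.grib", engine="cfgrib", backend_kwargs={"indexpath": ""}
    )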
""" + def __init__(self, filename, lock=None, **backend_kwargs): import cfgrib + if lock is None: lock = ECCODES_LOCK self.lock = ensure_lock(lock) @@ -47,13 +50,14 @@ def open_store_variable(self, name, var): data = indexing.LazilyOuterIndexedArray(wrapped_array) encoding = self.ds.encoding.copy() - encoding['original_shape'] = var.data.shape + encoding["original_shape"] = var.data.shape return Variable(var.dimensions, data, var.attributes, encoding) def get_variables(self): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in self.ds.variables.items()) + return FrozenOrderedDict( + (k, self.open_store_variable(k, v)) for k, v in self.ds.variables.items() + ) def get_attrs(self): return Frozen(self.ds.attributes) @@ -63,7 +67,5 @@ def get_dimensions(self): def get_encoding(self): dims = self.get_dimensions() - encoding = { - 'unlimited_dims': {k for k, v in dims.items() if v is None}, - } + encoding = {"unlimited_dims": {k for k, v in dims.items() if v is None}} return encoding diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 7096bdc826c..7ee11052192 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) -NONE_VAR_NAME = '__values__' +NONE_VAR_NAME = "__values__" def _encode_variable_name(name): @@ -37,12 +37,11 @@ def find_root_and_group(ds): while ds.parent is not None: hierarchy = (ds.name,) + hierarchy ds = ds.parent - group = '/' + '/'.join(hierarchy) + group = "/" + "/".join(hierarchy) return ds, group -def robust_getitem(array, key, catch=Exception, max_retries=6, - initial_delay=500): +def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500): """ Robustly index an array, using retry logic with exponential backoff if any of the errors ``catch`` are raised. The initial_delay is measured in ms. @@ -59,22 +58,22 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, raise base_delay = initial_delay * 2 ** n next_delay = base_delay + np.random.randint(base_delay) - msg = ('getitem failed, waiting %s ms before trying again ' - '(%s tries remaining). Full traceback: %s' % - (next_delay, max_retries - n, traceback.format_exc())) + msg = ( + "getitem failed, waiting %s ms before trying again " + "(%s tries remaining). Full traceback: %s" + % (next_delay, max_retries - n, traceback.format_exc()) + ) logger.debug(msg) time.sleep(1e-3 * next_delay) class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed): - def __array__(self, dtype=None): key = indexing.BasicIndexer((slice(None),) * self.ndim) return np.asarray(self[key], dtype=dtype) class AbstractDataStore(Mapping): - def __iter__(self): return iter(self.variables) @@ -117,32 +116,42 @@ def load(self): This function will be called anytime variables or attributes are requested, so care should be taken to make sure its fast. 
""" - variables = FrozenOrderedDict((_decode_variable_name(k), v) - for k, v in self.get_variables().items()) + variables = FrozenOrderedDict( + (_decode_variable_name(k), v) for k, v in self.get_variables().items() + ) attributes = FrozenOrderedDict(self.get_attrs()) return variables, attributes @property def variables(self): # pragma: no cover - warnings.warn('The ``variables`` property has been deprecated and ' - 'will be removed in xarray v0.11.', - FutureWarning, stacklevel=2) + warnings.warn( + "The ``variables`` property has been deprecated and " + "will be removed in xarray v0.11.", + FutureWarning, + stacklevel=2, + ) variables, _ = self.load() return variables @property def attrs(self): # pragma: no cover - warnings.warn('The ``attrs`` property has been deprecated and ' - 'will be removed in xarray v0.11.', - FutureWarning, stacklevel=2) + warnings.warn( + "The ``attrs`` property has been deprecated and " + "will be removed in xarray v0.11.", + FutureWarning, + stacklevel=2, + ) _, attrs = self.load() return attrs @property def dimensions(self): # pragma: no cover - warnings.warn('The ``dimensions`` property has been deprecated and ' - 'will be removed in xarray v0.11.', - FutureWarning, stacklevel=2) + warnings.warn( + "The ``dimensions`` property has been deprecated and " + "will be removed in xarray v0.11.", + FutureWarning, + stacklevel=2, + ) return self.get_dimensions() def close(self): @@ -176,13 +185,19 @@ def add(self, source, target, region=None): def sync(self, compute=True): if self.sources: import dask.array as da + # TODO: consider wrapping targets with dask.delayed, if this makes # for any discernable difference in perforance, e.g., # targets = [dask.delayed(t) for t in self.targets] - delayed_store = da.store(self.sources, self.targets, - lock=self.lock, compute=compute, - flush=True, regions=self.regions) + delayed_store = da.store( + self.sources, + self.targets, + lock=self.lock, + compute=compute, + flush=True, + regions=self.regions, + ) self.sources = [] self.targets = [] self.regions = [] @@ -190,7 +205,6 @@ def sync(self, compute=True): class AbstractWritableDataStore(AbstractDataStore): - def encode(self, variables, attributes): """ Encode the variables and attributes in this store @@ -208,10 +222,12 @@ def encode(self, variables, attributes): attributes : dict-like """ - variables = OrderedDict([(k, self.encode_variable(v)) - for k, v in variables.items()]) - attributes = OrderedDict([(k, self.encode_attribute(v)) - for k, v in attributes.items()]) + variables = OrderedDict( + [(k, self.encode_variable(v)) for k, v in variables.items()] + ) + attributes = OrderedDict( + [(k, self.encode_attribute(v)) for k, v in attributes.items()] + ) return variables, attributes def encode_variable(self, v): @@ -240,8 +256,14 @@ def store_dataset(self, dataset): """ self.store(dataset, dataset.attrs) - def store(self, variables, attributes, check_encoding_set=frozenset(), - writer=None, unlimited_dims=None): + def store( + self, + variables, + attributes, + check_encoding_set=frozenset(), + writer=None, + unlimited_dims=None, + ): """ Top level method for putting data on this store, this method: - encodes variables/attributes @@ -269,8 +291,9 @@ def store(self, variables, attributes, check_encoding_set=frozenset(), self.set_attributes(attributes) self.set_dimensions(variables, unlimited_dims=unlimited_dims) - self.set_variables(variables, check_encoding_set, writer, - unlimited_dims=unlimited_dims) + self.set_variables( + variables, check_encoding_set, writer, 
unlimited_dims=unlimited_dims + ) def set_attributes(self, attributes): """ @@ -285,8 +308,7 @@ def set_attributes(self, attributes): for k, v in attributes.items(): self.set_attribute(k, v) - def set_variables(self, variables, check_encoding_set, writer, - unlimited_dims=None): + def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None): """ This provides a centralized method to set the variables on the data store. @@ -308,7 +330,8 @@ def set_variables(self, variables, check_encoding_set, writer, name = _encode_variable_name(vn) check = vn in check_encoding_set target, source = self.prepare_variable( - name, v, check, unlimited_dims=unlimited_dims) + name, v, check, unlimited_dims=unlimited_dims + ) writer.add(source, target) @@ -340,20 +363,22 @@ def set_dimensions(self, variables, unlimited_dims=None): if dim in existing_dims and length != existing_dims[dim]: raise ValueError( "Unable to update size for existing dimension" - "%r (%d != %d)" % (dim, length, existing_dims[dim])) + "%r (%d != %d)" % (dim, length, existing_dims[dim]) + ) elif dim not in existing_dims: is_unlimited = dim in unlimited_dims self.set_dimension(dim, length, is_unlimited) class WritableCFDataStore(AbstractWritableDataStore): - def encode(self, variables, attributes): # All NetCDF files get CF encoded by default, without this attempting # to write times, for example, would fail. variables, attributes = cf_encoder(variables, attributes) - variables = OrderedDict([(k, self.encode_variable(v)) - for k, v in variables.items()]) - attributes = OrderedDict([(k, self.encode_attribute(v)) - for k, v in attributes.items()]) + variables = OrderedDict( + [(k, self.encode_variable(v)) for k, v in variables.items()] + ) + attributes = OrderedDict( + [(k, self.encode_attribute(v)) for k, v in attributes.items()] + ) return variables, attributes diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index 525b54db5da..dfd38ff9f48 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -9,14 +9,13 @@ from .lru_cache import LRUCache # Global cache for storing open files. -FILE_CACHE = LRUCache( - OPTIONS['file_cache_maxsize'], on_evict=lambda k, v: v.close()) -assert FILE_CACHE.maxsize, 'file cache must be at least size one' +FILE_CACHE = LRUCache(OPTIONS["file_cache_maxsize"], on_evict=lambda k, v: v.close()) +assert FILE_CACHE.maxsize, "file cache must be at least size one" REF_COUNTS = {} # type: Dict[Any, int] -_DEFAULT_MODE = utils.ReprObject('') +_DEFAULT_MODE = utils.ReprObject("") class FileManager: @@ -73,8 +72,16 @@ class CachingFileManager(FileManager): """ - def __init__(self, opener, *args, mode=_DEFAULT_MODE, kwargs=None, - lock=None, cache=None, ref_counts=None): + def __init__( + self, + opener, + *args, + mode=_DEFAULT_MODE, + kwargs=None, + lock=None, + cache=None, + ref_counts=None + ): """Initialize a FileManager. 
The cache and ref_counts arguments exist solely to facilitate @@ -136,10 +143,12 @@ def __init__(self, opener, *args, mode=_DEFAULT_MODE, kwargs=None, def _make_key(self): """Make a key for caching files in the LRU cache.""" - value = (self._opener, - self._args, - 'a' if self._mode == 'w' else self._mode, - tuple(sorted(self._kwargs.items()))) + value = ( + self._opener, + self._args, + "a" if self._mode == "w" else self._mode, + tuple(sorted(self._kwargs.items())), + ) return _HashedSequence(value) @contextlib.contextmanager @@ -188,11 +197,11 @@ def _acquire_with_cache_info(self, needs_lock=True): kwargs = self._kwargs if self._mode is not _DEFAULT_MODE: kwargs = kwargs.copy() - kwargs['mode'] = self._mode + kwargs["mode"] = self._mode file = self._opener(*self._args, **kwargs) - if self._mode == 'w': + if self._mode == "w": # ensure file doesn't get overriden when opened again - self._mode = 'a' + self._mode = "a" self._cache[self._key] = file return file, False else: @@ -232,11 +241,13 @@ def __del__(self): finally: self._lock.release() - if OPTIONS['warn_for_unclosed_files']: + if OPTIONS["warn_for_unclosed_files"]: warnings.warn( - 'deallocating {}, but file is not already closed. ' - 'This may indicate a bug.' - .format(self), RuntimeWarning, stacklevel=2) + "deallocating {}, but file is not already closed. " + "This may indicate a bug.".format(self), + RuntimeWarning, + stacklevel=2, + ) def __getstate__(self): """State for pickling.""" @@ -251,15 +262,17 @@ def __setstate__(self, state): self.__init__(opener, *args, mode=mode, kwargs=kwargs, lock=lock) def __repr__(self): - args_string = ', '.join(map(repr, self._args)) + args_string = ", ".join(map(repr, self._args)) if self._mode is not _DEFAULT_MODE: - args_string += ', mode={!r}'.format(self._mode) - return '{}({!r}, {}, kwargs={})'.format( - type(self).__name__, self._opener, args_string, self._kwargs) + args_string += ", mode={!r}".format(self._mode) + return "{}({!r}, {}, kwargs={})".format( + type(self).__name__, self._opener, args_string, self._kwargs + ) class _RefCounter: """Class for keeping track of reference counts.""" + def __init__(self, counts): self._counts = counts self._lock = threading.Lock() @@ -299,6 +312,7 @@ def __hash__(self): class DummyFileManager(FileManager): """FileManager that simply wraps an open file in the FileManager interface. 
""" + def __init__(self, value): self._value = value diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index b560e479f8f..9a111184ece 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -10,12 +10,15 @@ from .file_manager import CachingFileManager from .locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock from .netCDF4_ import ( - BaseNetCDF4Array, _encode_nc4_variable, _extract_nc4_variable_encoding, - _get_datatype, _nc4_require_group) + BaseNetCDF4Array, + _encode_nc4_variable, + _extract_nc4_variable_encoding, + _get_datatype, + _nc4_require_group, +) class H5NetCDFArrayWrapper(BaseNetCDF4Array): - def get_array(self, needs_lock=True): ds = self.datastore._acquire(needs_lock) variable = ds.variables[self.variable_name] @@ -23,8 +26,8 @@ def get_array(self, needs_lock=True): def __getitem__(self, key): return indexing.explicit_indexing_adapter( - key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, - self._getitem) + key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem + ) def _getitem(self, key): # h5py requires using lists for fancy indexing: @@ -37,7 +40,7 @@ def _getitem(self, key): def maybe_decode_bytes(txt): if isinstance(txt, bytes): - return txt.decode('utf-8') + return txt.decode("utf-8") else: return txt @@ -48,15 +51,15 @@ def _read_attributes(h5netcdf_var): # bytes attributes to strings attrs = OrderedDict() for k, v in h5netcdf_var.attrs.items(): - if k not in ['_FillValue', 'missing_value']: + if k not in ["_FillValue", "missing_value"]: v = maybe_decode_bytes(v) attrs[k] = v return attrs _extract_h5nc_encoding = functools.partial( - _extract_nc4_variable_encoding, - lsd_okay=False, h5py_okay=True, backend='h5netcdf') + _extract_nc4_variable_encoding, lsd_okay=False, h5py_okay=True, backend="h5netcdf" +) def _h5netcdf_create_group(dataset, name): @@ -67,18 +70,18 @@ class H5NetCDFStore(WritableCFDataStore): """Store for reading and writing data via h5netcdf """ - def __init__(self, filename, mode='r', format=None, group=None, - lock=None, autoclose=False): + def __init__( + self, filename, mode="r", format=None, group=None, lock=None, autoclose=False + ): import h5netcdf - if format not in [None, 'NETCDF4']: - raise ValueError('invalid format for h5netcdf backend') + if format not in [None, "NETCDF4"]: + raise ValueError("invalid format for h5netcdf backend") - self._manager = CachingFileManager( - h5netcdf.File, filename, mode=mode) + self._manager = CachingFileManager(h5netcdf.File, filename, mode=mode) if lock is None: - if mode == 'r': + if mode == "r": lock = HDF5_LOCK else: lock = combine_locks([HDF5_LOCK, get_write_lock(filename)]) @@ -92,8 +95,9 @@ def __init__(self, filename, mode='r', format=None, group=None, def _acquire(self, needs_lock=True): with self._manager.acquire_context(needs_lock) as root: - ds = _nc4_require_group(root, self._group, self._mode, - create_group=_h5netcdf_create_group) + ds = _nc4_require_group( + root, self._group, self._mode, create_group=_h5netcdf_create_group + ) return ds @property @@ -104,43 +108,43 @@ def open_store_variable(self, name, var): import h5py dimensions = var.dimensions - data = indexing.LazilyOuterIndexedArray( - H5NetCDFArrayWrapper(name, self)) + data = indexing.LazilyOuterIndexedArray(H5NetCDFArrayWrapper(name, self)) attrs = _read_attributes(var) # netCDF4 specific encoding encoding = { - 'chunksizes': var.chunks, - 'fletcher32': var.fletcher32, - 'shuffle': var.shuffle, + "chunksizes": var.chunks, + "fletcher32": 
var.fletcher32, + "shuffle": var.shuffle, } # Convert h5py-style compression options to NetCDF4-Python # style, if possible - if var.compression == 'gzip': - encoding['zlib'] = True - encoding['complevel'] = var.compression_opts + if var.compression == "gzip": + encoding["zlib"] = True + encoding["complevel"] = var.compression_opts elif var.compression is not None: - encoding['compression'] = var.compression - encoding['compression_opts'] = var.compression_opts + encoding["compression"] = var.compression + encoding["compression_opts"] = var.compression_opts # save source so __repr__ can detect if it's local or not - encoding['source'] = self._filename - encoding['original_shape'] = var.shape + encoding["source"] = self._filename + encoding["original_shape"] = var.shape vlen_dtype = h5py.check_dtype(vlen=var.dtype) if vlen_dtype is str: - encoding['dtype'] = str + encoding["dtype"] = str elif vlen_dtype is not None: # pragma: no cover # xarray doesn't support writing arbitrary vlen dtypes yet. pass else: - encoding['dtype'] = var.dtype + encoding["dtype"] = var.dtype return Variable(dimensions, data, attrs, encoding) def get_variables(self): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in self.ds.variables.items()) + return FrozenOrderedDict( + (k, self.open_store_variable(k, v)) for k, v in self.ds.variables.items() + ) def get_attrs(self): return FrozenOrderedDict(_read_attributes(self.ds)) @@ -150,8 +154,9 @@ def get_dimensions(self): def get_encoding(self): encoding = {} - encoding['unlimited_dims'] = { - k for k, v in self.ds.dimensions.items() if v is None} + encoding["unlimited_dims"] = { + k for k, v in self.ds.dimensions.items() if v is None + } return encoding def set_dimension(self, name, length, is_unlimited=False): @@ -167,61 +172,71 @@ def set_attribute(self, key, value): def encode_variable(self, variable): return _encode_nc4_variable(variable) - def prepare_variable(self, name, variable, check_encoding=False, - unlimited_dims=None): + def prepare_variable( + self, name, variable, check_encoding=False, unlimited_dims=None + ): import h5py attrs = variable.attrs.copy() - dtype = _get_datatype( - variable, raise_on_invalid_encoding=check_encoding) + dtype = _get_datatype(variable, raise_on_invalid_encoding=check_encoding) - fillvalue = attrs.pop('_FillValue', None) + fillvalue = attrs.pop("_FillValue", None) if dtype is str and fillvalue is not None: raise NotImplementedError( - 'h5netcdf does not yet support setting a fill value for ' - 'variable-length strings ' - '(https://github.com/shoyer/h5netcdf/issues/37). ' + "h5netcdf does not yet support setting a fill value for " + "variable-length strings " + "(https://github.com/shoyer/h5netcdf/issues/37). " "Either remove '_FillValue' from encoding on variable %r " "or set {'dtype': 'S1'} in encoding to use the fixed width " - 'NC_CHAR type.' % name) + "NC_CHAR type." 
% name + ) if dtype is str: dtype = h5py.special_dtype(vlen=str) - encoding = _extract_h5nc_encoding(variable, - raise_on_invalid=check_encoding) + encoding = _extract_h5nc_encoding(variable, raise_on_invalid=check_encoding) kwargs = {} # Convert from NetCDF4-Python style compression settings to h5py style # If both styles are used together, h5py takes precedence # If set_encoding=True, raise ValueError in case of mismatch - if encoding.pop('zlib', False): - if (check_encoding and encoding.get('compression') - not in (None, 'gzip')): + if encoding.pop("zlib", False): + if check_encoding and encoding.get("compression") not in (None, "gzip"): raise ValueError("'zlib' and 'compression' encodings mismatch") - encoding.setdefault('compression', 'gzip') - - if (check_encoding and - 'complevel' in encoding and 'compression_opts' in encoding and - encoding['complevel'] != encoding['compression_opts']): - raise ValueError("'complevel' and 'compression_opts' encodings " - "mismatch") - complevel = encoding.pop('complevel', 0) + encoding.setdefault("compression", "gzip") + + if ( + check_encoding + and "complevel" in encoding + and "compression_opts" in encoding + and encoding["complevel"] != encoding["compression_opts"] + ): + raise ValueError("'complevel' and 'compression_opts' encodings " "mismatch") + complevel = encoding.pop("complevel", 0) if complevel != 0: - encoding.setdefault('compression_opts', complevel) + encoding.setdefault("compression_opts", complevel) - encoding['chunks'] = encoding.pop('chunksizes', None) + encoding["chunks"] = encoding.pop("chunksizes", None) # Do not apply compression, filters or chunking to scalars. if variable.shape: - for key in ['compression', 'compression_opts', 'shuffle', - 'chunks', 'fletcher32']: + for key in [ + "compression", + "compression_opts", + "shuffle", + "chunks", + "fletcher32", + ]: if key in encoding: kwargs[key] = encoding[key] if name not in self.ds: nc4_var = self.ds.create_variable( - name, dtype=dtype, dimensions=variable.dims, - fillvalue=fillvalue, **kwargs) + name, + dtype=dtype, + dimensions=variable.dims, + fillvalue=fillvalue, + **kwargs + ) else: nc4_var = self.ds[name] diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py index bb63186ce3a..1c5edc215fc 100644 --- a/xarray/backends/locks.py +++ b/xarray/backends/locks.py @@ -21,7 +21,9 @@ NETCDFC_LOCK = SerializableLock() -_FILE_LOCKS = weakref.WeakValueDictionary() # type: MutableMapping[Any, threading.Lock] # noqa +_FILE_LOCKS = ( + weakref.WeakValueDictionary() +) # type: MutableMapping[Any, threading.Lock] # noqa def _get_threaded_lock(key): @@ -41,9 +43,9 @@ def _get_multiprocessing_lock(key): _LOCK_MAKERS = { None: _get_threaded_lock, - 'threaded': _get_threaded_lock, - 'multiprocessing': _get_multiprocessing_lock, - 'distributed': DistributedLock, + "threaded": _get_threaded_lock, + "multiprocessing": _get_multiprocessing_lock, + "distributed": DistributedLock, } @@ -74,27 +76,31 @@ def _get_scheduler(get=None, collection=None): try: # dask 0.18.1 and later from dask.base import get_scheduler + actual_get = get_scheduler(get, collection) except ImportError: try: from dask.utils import effective_get + actual_get = effective_get(get, collection) except ImportError: return None try: from dask.distributed import Client + if isinstance(actual_get.__self__, Client): - return 'distributed' + return "distributed" except (ImportError, AttributeError): try: import dask.multiprocessing + if actual_get == dask.multiprocessing.get: - return 'multiprocessing' + return 
"multiprocessing" else: - return 'threaded' + return "threaded" except ImportError: - return 'threaded' + return "threaded" def get_write_lock(key): diff --git a/xarray/backends/lru_cache.py b/xarray/backends/lru_cache.py index 4be6efea7c0..69a81527d3f 100644 --- a/xarray/backends/lru_cache.py +++ b/xarray/backends/lru_cache.py @@ -17,6 +17,7 @@ class LRUCache(collections.abc.MutableMapping): The ``maxsize`` property can be used to view or adjust the capacity of the cache, e.g., ``cache.maxsize = new_size``. """ + def __init__(self, maxsize, on_evict=None): """ Parameters @@ -28,9 +29,9 @@ def __init__(self, maxsize, on_evict=None): evicted. """ if not isinstance(maxsize, int): - raise TypeError('maxsize must be an integer') + raise TypeError("maxsize must be an integer") if maxsize < 0: - raise ValueError('maxsize must be non-negative') + raise ValueError("maxsize must be non-negative") self._maxsize = maxsize self._on_evict = on_evict self._cache = collections.OrderedDict() @@ -84,7 +85,7 @@ def maxsize(self): def maxsize(self, size): """Resize the cache, evicting the oldest items if necessary.""" if size < 0: - raise ValueError('maxsize must be non-negative') + raise ValueError("maxsize must be non-negative") with self._lock: self._enforce_size_limit(size) self._maxsize = size diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index a93fba65d18..9866a2fe344 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -12,18 +12,18 @@ from ..core import indexing from ..core.utils import FrozenOrderedDict, close_on_error, is_remote_uri from .common import ( - BackendArray, WritableCFDataStore, find_root_and_group, robust_getitem) + BackendArray, + WritableCFDataStore, + find_root_and_group, + robust_getitem, +) from .file_manager import CachingFileManager, DummyFileManager -from .locks import ( - HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock, get_write_lock) +from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock, get_write_lock from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable # This lookup table maps from dtype.byteorder to a readable endian # string used by netCDF4. -_endian_lookup = {'=': 'native', - '>': 'big', - '<': 'little', - '|': 'native'} +_endian_lookup = {"=": "native", ">": "big", "<": "little", "|": "native"} NETCDF4_PYTHON_LOCK = combine_locks([NETCDFC_LOCK, HDF5_LOCK]) @@ -42,7 +42,7 @@ def __init__(self, variable_name, datastore): # use object dtype because that's the only way in numpy to # represent variable length strings; it also prevents automatic # string concatenation via conventions.decode_cf_variable - dtype = np.dtype('O') + dtype = np.dtype("O") self.dtype = dtype def __setitem__(self, key, value): @@ -54,7 +54,6 @@ def __setitem__(self, key, value): class NetCDF4ArrayWrapper(BaseNetCDF4Array): - def get_array(self, needs_lock=True): ds = self.datastore._acquire(needs_lock) variable = ds.variables[self.variable_name] @@ -66,8 +65,8 @@ def get_array(self, needs_lock=True): def __getitem__(self, key): return indexing.explicit_indexing_adapter( - key, self.shape, indexing.IndexingSupport.OUTER, - self._getitem) + key, self.shape, indexing.IndexingSupport.OUTER, self._getitem + ) def _getitem(self, key): if self.datastore.is_remote: # pragma: no cover @@ -83,16 +82,20 @@ def _getitem(self, key): # Catch IndexError in netCDF4 and return a more informative # error message. This is most often called when an unsorted # indexer is used before the data is loaded from disk. 
- msg = ('The indexing operation you are attempting to perform ' - 'is not valid on netCDF4.Variable object. Try loading ' - 'your data into memory first by calling .load().') + msg = ( + "The indexing operation you are attempting to perform " + "is not valid on netCDF4.Variable object. Try loading " + "your data into memory first by calling .load()." + ) raise IndexError(msg) return array def _encode_nc4_variable(var): - for coder in [coding.strings.EncodedStringCoder(allows_unicode=True), - coding.strings.CharacterArrayCoder()]: + for coder in [ + coding.strings.EncodedStringCoder(allows_unicode=True), + coding.strings.CharacterArrayCoder(), + ]: var = coder.encode(var) return var @@ -101,35 +104,36 @@ def _check_encoding_dtype_is_vlen_string(dtype): if dtype is not str: raise AssertionError( # pragma: no cover "unexpected dtype encoding %r. This shouldn't happen: please " - "file a bug report at github.com/pydata/xarray" % dtype) + "file a bug report at github.com/pydata/xarray" % dtype + ) -def _get_datatype(var, nc_format='NETCDF4', raise_on_invalid_encoding=False): - if nc_format == 'NETCDF4': +def _get_datatype(var, nc_format="NETCDF4", raise_on_invalid_encoding=False): + if nc_format == "NETCDF4": datatype = _nc4_dtype(var) else: - if 'dtype' in var.encoding: - encoded_dtype = var.encoding['dtype'] + if "dtype" in var.encoding: + encoded_dtype = var.encoding["dtype"] _check_encoding_dtype_is_vlen_string(encoded_dtype) if raise_on_invalid_encoding: raise ValueError( - 'encoding dtype=str for vlen strings is only supported ' - 'with format=\'NETCDF4\'.') + "encoding dtype=str for vlen strings is only supported " + "with format='NETCDF4'." + ) datatype = var.dtype return datatype def _nc4_dtype(var): - if 'dtype' in var.encoding: - dtype = var.encoding.pop('dtype') + if "dtype" in var.encoding: + dtype = var.encoding.pop("dtype") _check_encoding_dtype_is_vlen_string(dtype) elif coding.strings.is_unicode_dtype(var.dtype): dtype = str - elif var.dtype.kind in ['i', 'u', 'f', 'c', 'S']: + elif var.dtype.kind in ["i", "u", "f", "c", "S"]: dtype = var.dtype else: - raise ValueError('unsupported dtype for netCDF4 variable: {}' - .format(var.dtype)) + raise ValueError("unsupported dtype for netCDF4 variable: {}".format(var.dtype)) return dtype @@ -138,32 +142,32 @@ def _netcdf4_create_group(dataset, name): def _nc4_require_group(ds, group, mode, create_group=_netcdf4_create_group): - if group in {None, '', '/'}: + if group in {None, "", "/"}: # use the root group return ds else: # make sure it's a string if not isinstance(group, str): - raise ValueError('group must be a string or None') + raise ValueError("group must be a string or None") # support path-like syntax - path = group.strip('/').split('/') + path = group.strip("/").split("/") for key in path: try: ds = ds.groups[key] except KeyError as e: - if mode != 'r': + if mode != "r": ds = create_group(ds, key) else: # wrap error to provide slightly more helpful message - raise OSError('group not found: %s' % key, e) + raise OSError("group not found: %s" % key, e) return ds def _ensure_fill_value_valid(data, attributes): # work around for netCDF4/scipy issue where _FillValue has the wrong type: # https://github.com/Unidata/netcdf4-python/issues/271 - if data.dtype.kind == 'S' and '_FillValue' in attributes: - attributes['_FillValue'] = np.string_(attributes['_FillValue']) + if data.dtype.kind == "S" and "_FillValue" in attributes: + attributes["_FillValue"] = np.string_(attributes["_FillValue"]) def 
_force_native_endianness(var): @@ -173,57 +177,71 @@ def _force_native_endianness(var): # > big-endian # | not applicable # Below we check if the data type is not native or NA - if var.dtype.byteorder not in ['=', '|']: + if var.dtype.byteorder not in ["=", "|"]: # if endianness is specified explicitly, convert to the native type - data = var.data.astype(var.dtype.newbyteorder('=')) + data = var.data.astype(var.dtype.newbyteorder("=")) var = Variable(var.dims, data, var.attrs, var.encoding) # if endian exists, remove it from the encoding. - var.encoding.pop('endian', None) + var.encoding.pop("endian", None) # check to see if encoding has a value for endian its 'native' - if not var.encoding.get('endian', 'native') == 'native': - raise NotImplementedError("Attempt to write non-native endian type, " - "this is not supported by the netCDF4 " - "python library.") + if not var.encoding.get("endian", "native") == "native": + raise NotImplementedError( + "Attempt to write non-native endian type, " + "this is not supported by the netCDF4 " + "python library." + ) return var -def _extract_nc4_variable_encoding(variable, raise_on_invalid=False, - lsd_okay=True, h5py_okay=False, - backend='netCDF4', unlimited_dims=None): +def _extract_nc4_variable_encoding( + variable, + raise_on_invalid=False, + lsd_okay=True, + h5py_okay=False, + backend="netCDF4", + unlimited_dims=None, +): if unlimited_dims is None: unlimited_dims = () encoding = variable.encoding.copy() - safe_to_drop = {'source', 'original_shape'} + safe_to_drop = {"source", "original_shape"} valid_encodings = { - 'zlib', 'complevel', 'fletcher32', 'contiguous', - 'chunksizes', 'shuffle', '_FillValue', 'dtype' + "zlib", + "complevel", + "fletcher32", + "contiguous", + "chunksizes", + "shuffle", + "_FillValue", + "dtype", } if lsd_okay: - valid_encodings.add('least_significant_digit') + valid_encodings.add("least_significant_digit") if h5py_okay: - valid_encodings.add('compression') - valid_encodings.add('compression_opts') + valid_encodings.add("compression") + valid_encodings.add("compression_opts") - if not raise_on_invalid and encoding.get('chunksizes') is not None: + if not raise_on_invalid and encoding.get("chunksizes") is not None: # It's possible to get encoded chunksizes larger than a dimension size # if the original file had an unlimited dimension. This is problematic # if the new file no longer has an unlimited dimension. 
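To illustrate the check described in the comment above, a hedged sketch with invented sizes; it mirrors the reformatted logic that follows.

    # The source file chunked its unlimited "time" dimension at 512, but the
    # in-memory variable now spans only 5 entries and the target file has no
    # unlimited dimensions.
    encoding = {"chunksizes": (512,), "original_shape": (1000,)}
    variable_shape = (5,)

    chunks_too_big = any(
        c > d for c, d in zip(encoding["chunksizes"], variable_shape)
    )
    changed_shape = encoding["original_shape"] != variable_shape
    # Both flags are True, so _extract_nc4_variable_encoding drops "chunksizes"
    # instead of writing invalid chunk settings into the new file.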
- chunksizes = encoding['chunksizes'] + chunksizes = encoding["chunksizes"] chunks_too_big = any( c > d and dim not in unlimited_dims - for c, d, dim in zip(chunksizes, variable.shape, variable.dims)) - has_original_shape = 'original_shape' in encoding - changed_shape = (has_original_shape and - encoding.get('original_shape') != variable.shape) + for c, d, dim in zip(chunksizes, variable.shape, variable.dims) + ) + has_original_shape = "original_shape" in encoding + changed_shape = ( + has_original_shape and encoding.get("original_shape") != variable.shape + ) if chunks_too_big or changed_shape: - del encoding['chunksizes'] + del encoding["chunksizes"] var_has_unlim_dim = any(dim in unlimited_dims for dim in variable.dims) - if (not raise_on_invalid and var_has_unlim_dim - and 'contiguous' in encoding.keys()): - del encoding['contiguous'] + if not raise_on_invalid and var_has_unlim_dim and "contiguous" in encoding.keys(): + del encoding["contiguous"] for k in safe_to_drop: if k in encoding: @@ -233,8 +251,9 @@ def _extract_nc4_variable_encoding(variable, raise_on_invalid=False, invalid = [k for k in encoding if k not in valid_encodings] if invalid: raise ValueError( - 'unexpected encoding parameters for %r backend: %r. Valid ' - 'encodings are: %r' % (backend, invalid, valid_encodings)) + "unexpected encoding parameters for %r backend: %r. Valid " + "encodings are: %r" % (backend, invalid, valid_encodings) + ) else: for k in list(encoding): if k not in valid_encodings: @@ -244,8 +263,7 @@ def _extract_nc4_variable_encoding(variable, raise_on_invalid=False, def _is_list_of_strings(value): - if (np.asarray(value).dtype.kind in ['U', 'S'] and - np.asarray(value).size > 1): + if np.asarray(value).dtype.kind in ["U", "S"] and np.asarray(value).size > 1: return True else: return False @@ -260,9 +278,11 @@ def _set_nc_attribute(obj, key, value): # Inform users with old netCDF that does not support # NC_STRING that we can't serialize lists of strings # as attrs - msg = ('Attributes which are lists of strings are not ' - 'supported with this version of netCDF. Please ' - 'upgrade to netCDF4-python 1.2.4 or greater.') + msg = ( + "Attributes which are lists of strings are not " + "supported with this version of netCDF. Please " + "upgrade to netCDF4-python 1.2.4 or greater." + ) raise AttributeError(msg) else: obj.setncattr(key, value) @@ -274,8 +294,9 @@ class NetCDF4DataStore(WritableCFDataStore): This store supports NetCDF3, NetCDF4 and OpenDAP datasets. 
""" - def __init__(self, manager, group=None, mode=None, - lock=NETCDF4_PYTHON_LOCK, autoclose=False): + def __init__( + self, manager, group=None, mode=None, lock=NETCDF4_PYTHON_LOCK, autoclose=False + ): import netCDF4 if isinstance(manager, netCDF4.Dataset): @@ -284,8 +305,9 @@ def __init__(self, manager, group=None, mode=None, else: if not type(manager) is netCDF4.Dataset: raise ValueError( - 'must supply a root netCDF4.Dataset if the group ' - 'argument is provided') + "must supply a root netCDF4.Dataset if the group " + "argument is provided" + ) root = manager manager = DummyFileManager(root) @@ -299,42 +321,54 @@ def __init__(self, manager, group=None, mode=None, self.autoclose = autoclose @classmethod - def open(cls, filename, mode='r', format='NETCDF4', group=None, - clobber=True, diskless=False, persist=False, - lock=None, lock_maker=None, autoclose=False): + def open( + cls, + filename, + mode="r", + format="NETCDF4", + group=None, + clobber=True, + diskless=False, + persist=False, + lock=None, + lock_maker=None, + autoclose=False, + ): import netCDF4 - if (len(filename) == 88 and - LooseVersion(netCDF4.__version__) < "1.3.1"): + + if len(filename) == 88 and LooseVersion(netCDF4.__version__) < "1.3.1": warnings.warn( - 'A segmentation fault may occur when the ' - 'file path has exactly 88 characters as it does ' - 'in this case. The issue is known to occur with ' - 'version 1.2.4 of netCDF4 and can be addressed by ' - 'upgrading netCDF4 to at least version 1.3.1. ' - 'More details can be found here: ' - 'https://github.com/pydata/xarray/issues/1745') + "A segmentation fault may occur when the " + "file path has exactly 88 characters as it does " + "in this case. The issue is known to occur with " + "version 1.2.4 of netCDF4 and can be addressed by " + "upgrading netCDF4 to at least version 1.3.1. 
" + "More details can be found here: " + "https://github.com/pydata/xarray/issues/1745" + ) if format is None: - format = 'NETCDF4' + format = "NETCDF4" if lock is None: - if mode == 'r': + if mode == "r": if is_remote_uri(filename): lock = NETCDFC_LOCK else: lock = NETCDF4_PYTHON_LOCK else: - if format is None or format.startswith('NETCDF4'): + if format is None or format.startswith("NETCDF4"): base_lock = NETCDF4_PYTHON_LOCK else: base_lock = NETCDFC_LOCK lock = combine_locks([base_lock, get_write_lock(filename)]) - kwargs = dict(clobber=clobber, diskless=diskless, persist=persist, - format=format) + kwargs = dict( + clobber=clobber, diskless=diskless, persist=persist, format=format + ) manager = CachingFileManager( - netCDF4.Dataset, filename, mode=mode, kwargs=kwargs) - return cls(manager, group=group, mode=mode, lock=lock, - autoclose=autoclose) + netCDF4.Dataset, filename, mode=mode, kwargs=kwargs + ) + return cls(manager, group=group, mode=mode, lock=lock, autoclose=autoclose) def _acquire(self, needs_lock=True): with self._manager.acquire_context(needs_lock) as root: @@ -347,10 +381,8 @@ def ds(self): def open_store_variable(self, name, var): dimensions = var.dimensions - data = indexing.LazilyOuterIndexedArray( - NetCDF4ArrayWrapper(name, self)) - attributes = OrderedDict((k, var.getncattr(k)) - for k in var.ncattrs()) + data = indexing.LazilyOuterIndexedArray(NetCDF4ArrayWrapper(name, self)) + attributes = OrderedDict((k, var.getncattr(k)) for k in var.ncattrs()) _ensure_fill_value_valid(data, attributes) # netCDF4 specific encoding; save _FillValue for later encoding = {} @@ -359,43 +391,42 @@ def open_store_variable(self, name, var): encoding.update(filters) chunking = var.chunking() if chunking is not None: - if chunking == 'contiguous': - encoding['contiguous'] = True - encoding['chunksizes'] = None + if chunking == "contiguous": + encoding["contiguous"] = True + encoding["chunksizes"] = None else: - encoding['contiguous'] = False - encoding['chunksizes'] = tuple(chunking) + encoding["contiguous"] = False + encoding["chunksizes"] = tuple(chunking) # TODO: figure out how to round-trip "endian-ness" without raising # warnings from netCDF4 # encoding['endian'] = var.endian() - pop_to(attributes, encoding, 'least_significant_digit') + pop_to(attributes, encoding, "least_significant_digit") # save source so __repr__ can detect if it's local or not - encoding['source'] = self._filename - encoding['original_shape'] = var.shape - encoding['dtype'] = var.dtype + encoding["source"] = self._filename + encoding["original_shape"] = var.shape + encoding["dtype"] = var.dtype return Variable(dimensions, data, attributes, encoding) def get_variables(self): - dsvars = FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in - self.ds.variables.items()) + dsvars = FrozenOrderedDict( + (k, self.open_store_variable(k, v)) for k, v in self.ds.variables.items() + ) return dsvars def get_attrs(self): - attrs = FrozenOrderedDict((k, self.ds.getncattr(k)) - for k in self.ds.ncattrs()) + attrs = FrozenOrderedDict((k, self.ds.getncattr(k)) for k in self.ds.ncattrs()) return attrs def get_dimensions(self): - dims = FrozenOrderedDict((k, len(v)) - for k, v in self.ds.dimensions.items()) + dims = FrozenOrderedDict((k, len(v)) for k, v in self.ds.dimensions.items()) return dims def get_encoding(self): encoding = {} - encoding['unlimited_dims'] = { - k for k, v in self.ds.dimensions.items() if v.isunlimited()} + encoding["unlimited_dims"] = { + k for k, v in self.ds.dimensions.items() if 
v.isunlimited() + } return encoding def set_dimension(self, name, length, is_unlimited=False): @@ -403,38 +434,41 @@ def set_dimension(self, name, length, is_unlimited=False): self.ds.createDimension(name, size=dim_length) def set_attribute(self, key, value): - if self.format != 'NETCDF4': + if self.format != "NETCDF4": value = encode_nc3_attr_value(value) _set_nc_attribute(self.ds, key, value) def encode_variable(self, variable): variable = _force_native_endianness(variable) - if self.format == 'NETCDF4': + if self.format == "NETCDF4": variable = _encode_nc4_variable(variable) else: variable = encode_nc3_variable(variable) return variable - def prepare_variable(self, name, variable, check_encoding=False, - unlimited_dims=None): - datatype = _get_datatype(variable, self.format, - raise_on_invalid_encoding=check_encoding) + def prepare_variable( + self, name, variable, check_encoding=False, unlimited_dims=None + ): + datatype = _get_datatype( + variable, self.format, raise_on_invalid_encoding=check_encoding + ) attrs = variable.attrs.copy() - fill_value = attrs.pop('_FillValue', None) + fill_value = attrs.pop("_FillValue", None) if datatype is str and fill_value is not None: raise NotImplementedError( - 'netCDF4 does not yet support setting a fill value for ' - 'variable-length strings ' - '(https://github.com/Unidata/netcdf4-python/issues/730). ' + "netCDF4 does not yet support setting a fill value for " + "variable-length strings " + "(https://github.com/Unidata/netcdf4-python/issues/730). " "Either remove '_FillValue' from encoding on variable %r " "or set {'dtype': 'S1'} in encoding to use the fixed width " - 'NC_CHAR type.' % name) + "NC_CHAR type." % name + ) encoding = _extract_nc4_variable_encoding( - variable, raise_on_invalid=check_encoding, - unlimited_dims=unlimited_dims) + variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims + ) if name in self.ds.variables: nc4_var = self.ds.variables[name] @@ -443,16 +477,16 @@ def prepare_variable(self, name, variable, check_encoding=False, varname=name, datatype=datatype, dimensions=variable.dims, - zlib=encoding.get('zlib', False), - complevel=encoding.get('complevel', 4), - shuffle=encoding.get('shuffle', True), - fletcher32=encoding.get('fletcher32', False), - contiguous=encoding.get('contiguous', False), - chunksizes=encoding.get('chunksizes'), - endian='native', - least_significant_digit=encoding.get( - 'least_significant_digit'), - fill_value=fill_value) + zlib=encoding.get("zlib", False), + complevel=encoding.get("complevel", 4), + shuffle=encoding.get("shuffle", True), + fletcher32=encoding.get("fletcher32", False), + contiguous=encoding.get("contiguous", False), + chunksizes=encoding.get("chunksizes"), + endian="native", + least_significant_digit=encoding.get("least_significant_digit"), + fill_value=fill_value, + ) for k, v in attrs.items(): # set attributes one-by-one since netCDF4<1.0.10 can't handle diff --git a/xarray/backends/netcdf3.py b/xarray/backends/netcdf3.py index 4985e51f689..f09af0d9fb4 100644 --- a/xarray/backends/netcdf3.py +++ b/xarray/backends/netcdf3.py @@ -12,16 +12,26 @@ # The following are reserved names in CDL and may not be used as names of # variables, dimension, attributes _reserved_names = { - 'byte', 'char', 'short', 'ushort', 'int', 'uint', 'int64', 'uint64', - 'float' 'real', 'double', 'bool', 'string' + "byte", + "char", + "short", + "ushort", + "int", + "uint", + "int64", + "uint64", + "float" "real", + "double", + "bool", + "string", } # These data-types aren't 
supported by netCDF3, so they are automatically # coerced instead as indicated by the "coerce_nc3_dtype" function -_nc3_dtype_coercions = {'int64': 'int32', 'bool': 'int8'} +_nc3_dtype_coercions = {"int64": "int32", "bool": "int8"} # encode all strings as UTF-8 -STRING_ENCODING = 'utf-8' +STRING_ENCODING = "utf-8" def coerce_nc3_dtype(arr): @@ -40,8 +50,9 @@ def coerce_nc3_dtype(arr): # TODO: raise a warning whenever casting the data-type instead? cast_arr = arr.astype(new_dtype) if not (cast_arr == arr).all(): - raise ValueError('could not safely cast array from dtype %s to %s' - % (dtype, new_dtype)) + raise ValueError( + "could not safely cast array from dtype %s to %s" % (dtype, new_dtype) + ) arr = cast_arr return arr @@ -59,13 +70,14 @@ def encode_nc3_attr_value(value): def encode_nc3_attrs(attrs): - return OrderedDict([(k, encode_nc3_attr_value(v)) - for k, v in attrs.items()]) + return OrderedDict([(k, encode_nc3_attr_value(v)) for k, v in attrs.items()]) def encode_nc3_variable(var): - for coder in [coding.strings.EncodedStringCoder(allows_unicode=False), - coding.strings.CharacterArrayCoder()]: + for coder in [ + coding.strings.EncodedStringCoder(allows_unicode=False), + coding.strings.CharacterArrayCoder(), + ]: var = coder.encode(var) data = coerce_nc3_dtype(var.data) attrs = encode_nc3_attrs(var.attrs) @@ -78,7 +90,7 @@ def _isalnumMUTF8(c): Input is not checked! """ - return c.isalnum() or (len(c.encode('utf-8')) > 1) + return c.isalnum() or (len(c.encode("utf-8")) > 1) def is_valid_nc3_name(s): @@ -101,12 +113,14 @@ def is_valid_nc3_name(s): if not isinstance(s, str): return False if not isinstance(s, str): - s = s.decode('utf-8') - num_bytes = len(s.encode('utf-8')) - return ((unicodedata.normalize('NFC', s) == s) and - (s not in _reserved_names) and - (num_bytes >= 0) and - ('/' not in s) and - (s[-1] != ' ') and - (_isalnumMUTF8(s[0]) or (s[0] == '_')) and - all(_isalnumMUTF8(c) or c in _specialchars for c in s)) + s = s.decode("utf-8") + num_bytes = len(s.encode("utf-8")) + return ( + (unicodedata.normalize("NFC", s) == s) + and (s not in _reserved_names) + and (num_bytes >= 0) + and ("/" not in s) + and (s[-1] != " ") + and (_isalnumMUTF8(s[0]) or (s[0] == "_")) + and all(_isalnumMUTF8(c) or c in _specialchars for c in s) + ) diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index 34a61ae8108..1fcb0ab9b3a 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -14,7 +14,6 @@ class PncArrayWrapper(BackendArray): - def __init__(self, variable_name, datastore): self.datastore = datastore self.variable_name = variable_name @@ -28,8 +27,8 @@ def get_array(self, needs_lock=True): def __getitem__(self, key): return indexing.explicit_indexing_adapter( - key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, - self._getitem) + key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem + ) def _getitem(self, key): with self.datastore.lock: @@ -40,14 +39,15 @@ def _getitem(self, key): class PseudoNetCDFDataStore(AbstractDataStore): """Store for accessing datasets via PseudoNetCDF """ + @classmethod def open(cls, filename, lock=None, mode=None, **format_kwargs): from PseudoNetCDF import pncopen - keywords = {'kwargs': format_kwargs} + keywords = {"kwargs": format_kwargs} # only include mode if explicitly passed if mode is not None: - keywords['mode'] = mode + keywords["mode"] = mode if lock is None: lock = PNETCDF_LOCK @@ -64,15 +64,14 @@ def ds(self): return self._manager.acquire() def 
open_store_variable(self, name, var): - data = indexing.LazilyOuterIndexedArray( - PncArrayWrapper(name, self) - ) + data = indexing.LazilyOuterIndexedArray(PncArrayWrapper(name, self)) attrs = OrderedDict((k, getattr(var, k)) for k in var.ncattrs()) return Variable(var.dimensions, data, attrs) def get_variables(self): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in self.ds.variables.items()) + return FrozenOrderedDict( + (k, self.open_store_variable(k, v)) for k, v in self.ds.variables.items() + ) def get_attrs(self): return Frozen({k: getattr(self.ds, k) for k in self.ds.ncattrs()}) @@ -82,9 +81,8 @@ def get_dimensions(self): def get_encoding(self): return { - 'unlimited_dims': { - k for k in self.ds.dimensions - if self.ds.dimensions[k].isunlimited() + "unlimited_dims": { + k for k in self.ds.dimensions if self.ds.dimensions[k].isunlimited() } } diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py index d3901f0f763..b0b39144e19 100644 --- a/xarray/backends/pydap_.py +++ b/xarray/backends/pydap_.py @@ -21,16 +21,16 @@ def dtype(self): def __getitem__(self, key): return indexing.explicit_indexing_adapter( - key, self.shape, indexing.IndexingSupport.BASIC, self._getitem) + key, self.shape, indexing.IndexingSupport.BASIC, self._getitem + ) def _getitem(self, key): # pull the data from the array attribute if possible, to avoid # downloading coordinate data twice - array = getattr(self.array, 'array', self.array) + array = getattr(self.array, "array", self.array) result = robust_getitem(array, key, catch=ValueError) # in some cases, pydap doesn't squeeze axes automatically like numpy - axis = tuple(n for n, k in enumerate(key) - if isinstance(k, integer_types)) + axis = tuple(n for n, k in enumerate(key) if isinstance(k, integer_types)) if result.ndim + len(axis) != array.ndim and len(axis) > 0: result = np.squeeze(result, axis) @@ -40,15 +40,19 @@ def _getitem(self, key): def _fix_attributes(attributes): attributes = dict(attributes) for k in list(attributes): - if k.lower() == 'global' or k.lower().endswith('_global'): + if k.lower() == "global" or k.lower().endswith("_global"): # move global attributes to the top level, like the netcdf-C # DAP client attributes.update(attributes.pop(k)) elif is_dict_like(attributes[k]): # Make Hierarchical attributes to a single level with a # dot-separated key - attributes.update({'{}.{}'.format(k, k_child): v_child for - k_child, v_child in attributes.pop(k).items()}) + attributes.update( + { + "{}.{}".format(k, k_child): v_child + for k_child, v_child in attributes.pop(k).items() + } + ) return attributes @@ -70,17 +74,18 @@ def __init__(self, ds): @classmethod def open(cls, url, session=None): import pydap.client + ds = pydap.client.open_url(url, session=session) return cls(ds) def open_store_variable(self, var): data = indexing.LazilyOuterIndexedArray(PydapArrayWrapper(var)) - return Variable(var.dimensions, data, - _fix_attributes(var.attributes)) + return Variable(var.dimensions, data, _fix_attributes(var.attributes)) def get_variables(self): - return FrozenOrderedDict((k, self.open_store_variable(self.ds[k])) - for k in self.ds.keys()) + return FrozenOrderedDict( + (k, self.open_store_variable(self.ds[k])) for k in self.ds.keys() + ) def get_attrs(self): return Frozen(_fix_attributes(self.ds.attributes)) diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index 9c3946f657d..abba45cfecd 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -5,8 +5,7 @@ from ..core.utils 
import Frozen, FrozenOrderedDict from .common import AbstractDataStore, BackendArray from .file_manager import CachingFileManager -from .locks import ( - HDF5_LOCK, NETCDFC_LOCK, SerializableLock, combine_locks, ensure_lock) +from .locks import HDF5_LOCK, NETCDFC_LOCK, SerializableLock, combine_locks, ensure_lock # PyNIO can invoke netCDF libraries internally # Add a dedicated lock just in case NCL as well isn't thread-safe. @@ -15,7 +14,6 @@ class NioArrayWrapper(BackendArray): - def __init__(self, variable_name, datastore): self.datastore = datastore self.variable_name = variable_name @@ -29,7 +27,8 @@ def get_array(self, needs_lock=True): def __getitem__(self, key): return indexing.explicit_indexing_adapter( - key, self.shape, indexing.IndexingSupport.BASIC, self._getitem) + key, self.shape, indexing.IndexingSupport.BASIC, self._getitem + ) def _getitem(self, key): with self.datastore.lock: @@ -45,16 +44,18 @@ class NioDataStore(AbstractDataStore): """Store for accessing datasets via PyNIO """ - def __init__(self, filename, mode='r', lock=None, **kwargs): + def __init__(self, filename, mode="r", lock=None, **kwargs): import Nio + if lock is None: lock = PYNIO_LOCK self.lock = ensure_lock(lock) self._manager = CachingFileManager( - Nio.open_file, filename, lock=lock, mode=mode, kwargs=kwargs) + Nio.open_file, filename, lock=lock, mode=mode, kwargs=kwargs + ) # xarray provides its own support for FillValue, # so turn off PyNIO's support for the same. - self.ds.set_option('MaskedArrayMode', 'MaskedNever') + self.ds.set_option("MaskedArrayMode", "MaskedNever") @property def ds(self): @@ -65,8 +66,9 @@ def open_store_variable(self, name, var): return Variable(var.dimensions, data, var.attributes) def get_variables(self): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in self.ds.variables.items()) + return FrozenOrderedDict( + (k, self.open_store_variable(k, v)) for k, v in self.ds.variables.items() + ) def get_attrs(self): return Frozen(self.ds.attributes) @@ -76,10 +78,7 @@ def get_dimensions(self): def get_encoding(self): return { - 'unlimited_dims': { - k for k in self.ds.dimensions - if self.ds.unlimited(k) - } + "unlimited_dims": {k for k in self.ds.dimensions if self.ds.unlimited(k)} } def close(self): diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index cc533cc3ad0..1d832d4f671 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -15,9 +15,11 @@ # TODO: should this be GDAL_LOCK instead? RASTERIO_LOCK = SerializableLock() -_ERROR_MSG = ('The kind of indexing operation you are trying to do is not ' - 'valid on rasterio files. Try to load your data with ds.load()' - 'first.') +_ERROR_MSG = ( + "The kind of indexing operation you are trying to do is not " + "valid on rasterio files. Try to load your data with ds.load()" + "first." 
+) class RasterioArrayWrapper(BackendArray): @@ -25,6 +27,7 @@ class RasterioArrayWrapper(BackendArray): def __init__(self, manager, lock, vrt_params=None): from rasterio.vrt import WarpedVRT + self.manager = manager self.lock = lock @@ -37,7 +40,7 @@ def __init__(self, manager, lock, vrt_params=None): dtypes = riods.dtypes if not np.all(np.asarray(dtypes) == dtypes[0]): - raise ValueError('All bands should have the same dtype') + raise ValueError("All bands should have the same dtype") self._dtype = np.dtype(dtypes[0]) @property @@ -66,7 +69,7 @@ def _get_indexer(self, key): -------- indexing.decompose_indexer """ - assert len(key) == 3, 'rasterio datasets should always be 3D' + assert len(key) == 3, "rasterio datasets should always be 3D" # bands cannot be windowed but they can be listed band_key = key[0] @@ -91,7 +94,7 @@ def _get_indexer(self, key): elif is_scalar(k): # windowed operations will always return an array # we will have to squeeze it later - squeeze_axis.append(- (2 - i)) + squeeze_axis.append(-(2 - i)) start = k stop = k + 1 else: @@ -107,12 +110,12 @@ def _get_indexer(self, key): def _getitem(self, key): from rasterio.vrt import WarpedVRT + band_key, window, squeeze_axis, np_inds = self._get_indexer(key) if not band_key or any(start == stop for (start, stop) in window): # no need to do IO - shape = (len(band_key),) + tuple( - stop - start for (start, stop) in window) + shape = (len(band_key),) + tuple(stop - start for (start, stop) in window) out = np.zeros(shape, dtype=self.dtype) else: with self.lock: @@ -127,7 +130,8 @@ def _getitem(self, key): def __getitem__(self, key): return indexing.explicit_indexing_adapter( - key, self.shape, indexing.IndexingSupport.OUTER, self._getitem) + key, self.shape, indexing.IndexingSupport.OUTER, self._getitem + ) def _parse_envi(meta): @@ -150,19 +154,17 @@ def _parse_envi(meta): """ def parsevec(s): - return np.fromstring(s.strip('{}'), dtype='float', sep=',') + return np.fromstring(s.strip("{}"), dtype="float", sep=",") def default(s): - return s.strip('{}') + return s.strip("{}") - parse = {'wavelength': parsevec, - 'fwhm': parsevec} + parse = {"wavelength": parsevec, "fwhm": parsevec} parsed_meta = {k: parse.get(k, default)(v) for k, v in meta.items()} return parsed_meta -def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, - lock=None): +def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, lock=None): """Open a file with rasterio (experimental). 
This should work with any file that rasterio can open (most often: @@ -213,26 +215,29 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, """ import rasterio from rasterio.vrt import WarpedVRT + vrt_params = None if isinstance(filename, rasterio.io.DatasetReader): filename = filename.name elif isinstance(filename, rasterio.vrt.WarpedVRT): vrt = filename filename = vrt.src_dataset.name - vrt_params = dict(crs=vrt.crs.to_string(), - resampling=vrt.resampling, - src_nodata=vrt.src_nodata, - dst_nodata=vrt.dst_nodata, - tolerance=vrt.tolerance, - transform=vrt.transform, - width=vrt.width, - height=vrt.height, - warp_extras=vrt.warp_extras) + vrt_params = dict( + crs=vrt.crs.to_string(), + resampling=vrt.resampling, + src_nodata=vrt.src_nodata, + dst_nodata=vrt.dst_nodata, + tolerance=vrt.tolerance, + transform=vrt.transform, + width=vrt.width, + height=vrt.height, + warp_extras=vrt.warp_extras, + ) if lock is None: lock = RASTERIO_LOCK - manager = CachingFileManager(rasterio.open, filename, lock=lock, mode='r') + manager = CachingFileManager(rasterio.open, filename, lock=lock, mode="r") riods = manager.acquire() if vrt_params is not None: riods = WarpedVRT(riods, **vrt_params) @@ -244,11 +249,11 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, # Get bands if riods.count < 1: - raise ValueError('Unknown dims') - coords['band'] = np.asarray(riods.indexes) + raise ValueError("Unknown dims") + coords["band"] = np.asarray(riods.indexes) # Get coordinates - if LooseVersion(rasterio.__version__) < '1.0': + if LooseVersion(rasterio.__version__) < "1.0": transform = riods.affine else: transform = riods.transform @@ -260,8 +265,8 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, # xarray coordinates are pixel centered x, _ = (np.arange(nx) + 0.5, np.zeros(nx) + 0.5) * transform _, y = (np.zeros(ny) + 0.5, np.arange(ny) + 0.5) * transform - coords['y'] = y - coords['x'] = x + coords["y"] = y + coords["x"] = x else: # 2d coordinates parse = False if (parse_coordinates is None) else parse_coordinates @@ -271,7 +276,9 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, "rectilinear: xarray won't parse the coordinates " "in this case. Set `parse_coordinates=False` to " "suppress this warning.", - RuntimeWarning, stacklevel=3) + RuntimeWarning, + stacklevel=3, + ) # Attributes attrs = dict() @@ -280,42 +287,42 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, # For serialization store as tuple of 6 floats, the last row being # always (0, 0, 1) per definition (see # https://github.com/sgillies/affine) - attrs['transform'] = tuple(transform)[:6] - if hasattr(riods, 'crs') and riods.crs: + attrs["transform"] = tuple(transform)[:6] + if hasattr(riods, "crs") and riods.crs: # CRS is a dict-like object specific to rasterio # If CRS is not None, we convert it back to a PROJ4 string using # rasterio itself try: - attrs['crs'] = riods.crs.to_proj4() + attrs["crs"] = riods.crs.to_proj4() except AttributeError: - attrs['crs'] = riods.crs.to_string() - if hasattr(riods, 'res'): + attrs["crs"] = riods.crs.to_string() + if hasattr(riods, "res"): # (width, height) tuple of pixels in units of CRS - attrs['res'] = riods.res - if hasattr(riods, 'is_tiled'): + attrs["res"] = riods.res + if hasattr(riods, "is_tiled"): # Is the TIF tiled? 
(bool) # We cast it to an int for netCDF compatibility - attrs['is_tiled'] = np.uint8(riods.is_tiled) - if hasattr(riods, 'nodatavals'): + attrs["is_tiled"] = np.uint8(riods.is_tiled) + if hasattr(riods, "nodatavals"): # The nodata values for the raster bands - attrs['nodatavals'] = tuple( - np.nan if nodataval is None else nodataval - for nodataval in riods.nodatavals) - if hasattr(riods, 'scales'): + attrs["nodatavals"] = tuple( + np.nan if nodataval is None else nodataval for nodataval in riods.nodatavals + ) + if hasattr(riods, "scales"): # The scale values for the raster bands - attrs['scales'] = riods.scales - if hasattr(riods, 'offsets'): + attrs["scales"] = riods.scales + if hasattr(riods, "offsets"): # The offset values for the raster bands - attrs['offsets'] = riods.offsets - if hasattr(riods, 'descriptions') and any(riods.descriptions): + attrs["offsets"] = riods.offsets + if hasattr(riods, "descriptions") and any(riods.descriptions): # Descriptions for each dataset band - attrs['descriptions'] = riods.descriptions - if hasattr(riods, 'units') and any(riods.units): + attrs["descriptions"] = riods.descriptions + if hasattr(riods, "units") and any(riods.units): # A list of units string for each dataset band - attrs['units'] = riods.units + attrs["units"] = riods.units # Parse extra metadata from tags, if supported - parsers = {'ENVI': _parse_envi} + parsers = {"ENVI": _parse_envi} driver = riods.driver if driver in parsers: @@ -324,25 +331,25 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, for k, v in meta.items(): # Add values as coordinates if they match the band count, # as attributes otherwise - if (isinstance(v, (list, np.ndarray)) - and len(v) == riods.count): - coords[k] = ('band', np.asarray(v)) + if isinstance(v, (list, np.ndarray)) and len(v) == riods.count: + coords[k] = ("band", np.asarray(v)) else: attrs[k] = v data = indexing.LazilyOuterIndexedArray( - RasterioArrayWrapper(manager, lock, vrt_params)) + RasterioArrayWrapper(manager, lock, vrt_params) + ) # this lets you write arrays loaded with rasterio data = indexing.CopyOnWriteArray(data) if cache and chunks is None: data = indexing.MemoryCachedArray(data) - result = DataArray(data=data, dims=('band', 'y', 'x'), - coords=coords, attrs=attrs) + result = DataArray(data=data, dims=("band", "y", "x"), coords=coords, attrs=attrs) if chunks is not None: from dask.base import tokenize + # augment the token with the file modification time try: mtime = os.path.getmtime(filename) @@ -350,7 +357,7 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, # the filename is probably an s3 bucket rather than a regular file mtime = None token = tokenize(filename, mtime, chunks) - name_prefix = 'open_rasterio-%s' % token + name_prefix = "open_rasterio-%s" % token result = result.chunk(chunks, name_prefix=name_prefix, token=token) # Make the file closeable diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 1111f30c139..c4f9666f0c1 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -11,32 +11,30 @@ from .common import BackendArray, WritableCFDataStore from .file_manager import CachingFileManager, DummyFileManager from .locks import ensure_lock, get_write_lock -from .netcdf3 import ( - encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name) +from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name def _decode_string(s): if isinstance(s, bytes): - return s.decode('utf-8', 'replace') + return 
s.decode("utf-8", "replace") return s def _decode_attrs(d): # don't decode _FillValue from bytes -> unicode, because we want to ensure # that its type matches the data exactly - return OrderedDict((k, v if k == '_FillValue' else _decode_string(v)) - for (k, v) in d.items()) + return OrderedDict( + (k, v if k == "_FillValue" else _decode_string(v)) for (k, v) in d.items() + ) class ScipyArrayWrapper(BackendArray): - def __init__(self, variable_name, datastore): self.datastore = datastore self.variable_name = variable_name array = self.get_variable().data self.shape = array.shape - self.dtype = np.dtype(array.dtype.kind + - str(array.dtype.itemsize)) + self.dtype = np.dtype(array.dtype.kind + str(array.dtype.itemsize)) def get_variable(self, needs_lock=True): ds = self.datastore._manager.acquire(needs_lock) @@ -68,28 +66,29 @@ def _open_scipy_netcdf(filename, mode, mmap, version): import gzip # if the string ends with .gz, then gunzip and open as netcdf file - if isinstance(filename, str) and filename.endswith('.gz'): + if isinstance(filename, str) and filename.endswith(".gz"): try: - return scipy.io.netcdf_file(gzip.open(filename), mode=mode, - mmap=mmap, version=version) + return scipy.io.netcdf_file( + gzip.open(filename), mode=mode, mmap=mmap, version=version + ) except TypeError as e: # TODO: gzipped loading only works with NetCDF3 files. - if 'is not a valid NetCDF 3 file' in e.message: - raise ValueError('gzipped file loading only supports ' - 'NetCDF 3 files.') + if "is not a valid NetCDF 3 file" in e.message: + raise ValueError( + "gzipped file loading only supports " "NetCDF 3 files." + ) else: raise - if isinstance(filename, bytes) and filename.startswith(b'CDF'): + if isinstance(filename, bytes) and filename.startswith(b"CDF"): # it's a NetCDF3 bytestring filename = BytesIO(filename) try: - return scipy.io.netcdf_file(filename, mode=mode, mmap=mmap, - version=version) + return scipy.io.netcdf_file(filename, mode=mode, mmap=mmap, version=version) except TypeError as e: # netcdf3 message is obscure in this case errmsg = e.args[0] - if 'is not a valid NetCDF 3 file' in errmsg: + if "is not a valid NetCDF 3 file" in errmsg: msg = """ If this is a NetCDF4 file, you may need to install the netcdf4 library, e.g., @@ -111,44 +110,50 @@ class ScipyDataStore(WritableCFDataStore): It only supports the NetCDF3 file-format. """ - def __init__(self, filename_or_obj, mode='r', format=None, group=None, - mmap=None, lock=None): + def __init__( + self, filename_or_obj, mode="r", format=None, group=None, mmap=None, lock=None + ): import scipy import scipy.io - if (mode != 'r' and - scipy.__version__ < LooseVersion('0.13')): # pragma: no cover - warnings.warn('scipy %s detected; ' - 'the minimal recommended version is 0.13. ' - 'Older version of this library do not reliably ' - 'read and write files.' - % scipy.__version__, ImportWarning) + if mode != "r" and scipy.__version__ < LooseVersion("0.13"): # pragma: no cover + warnings.warn( + "scipy %s detected; " + "the minimal recommended version is 0.13. " + "Older version of this library do not reliably " + "read and write files." 
% scipy.__version__, + ImportWarning, + ) if group is not None: - raise ValueError('cannot save to a group with the ' - 'scipy.io.netcdf backend') + raise ValueError( + "cannot save to a group with the " "scipy.io.netcdf backend" + ) - if format is None or format == 'NETCDF3_64BIT': + if format is None or format == "NETCDF3_64BIT": version = 2 - elif format == 'NETCDF3_CLASSIC': + elif format == "NETCDF3_CLASSIC": version = 1 else: - raise ValueError('invalid format for scipy.io.netcdf backend: %r' - % format) + raise ValueError("invalid format for scipy.io.netcdf backend: %r" % format) - if (lock is None and mode != 'r' and - isinstance(filename_or_obj, str)): + if lock is None and mode != "r" and isinstance(filename_or_obj, str): lock = get_write_lock(filename_or_obj) self.lock = ensure_lock(lock) if isinstance(filename_or_obj, str): manager = CachingFileManager( - _open_scipy_netcdf, filename_or_obj, mode=mode, lock=lock, - kwargs=dict(mmap=mmap, version=version)) + _open_scipy_netcdf, + filename_or_obj, + mode=mode, + lock=lock, + kwargs=dict(mmap=mmap, version=version), + ) else: scipy_dataset = _open_scipy_netcdf( - filename_or_obj, mode=mode, mmap=mmap, version=version) + filename_or_obj, mode=mode, mmap=mmap, version=version + ) manager = DummyFileManager(scipy_dataset) self._manager = manager @@ -158,12 +163,16 @@ def ds(self): return self._manager.acquire() def open_store_variable(self, name, var): - return Variable(var.dimensions, ScipyArrayWrapper(name, self), - _decode_attrs(var._attributes)) + return Variable( + var.dimensions, + ScipyArrayWrapper(name, self), + _decode_attrs(var._attributes), + ) def get_variables(self): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in self.ds.variables.items()) + return FrozenOrderedDict( + (k, self.open_store_variable(k, v)) for k, v in self.ds.variables.items() + ) def get_attrs(self): return Frozen(_decode_attrs(self.ds._attributes)) @@ -173,14 +182,16 @@ def get_dimensions(self): def get_encoding(self): encoding = {} - encoding['unlimited_dims'] = { - k for k, v in self.ds.dimensions.items() if v is None} + encoding["unlimited_dims"] = { + k for k, v in self.ds.dimensions.items() if v is None + } return encoding def set_dimension(self, name, length, is_unlimited=False): if name in self.ds.dimensions: - raise ValueError('%s does not support modifying dimensions' - % type(self).__name__) + raise ValueError( + "%s does not support modifying dimensions" % type(self).__name__ + ) dim_length = length if not is_unlimited else None self.ds.createDimension(name, dim_length) @@ -197,12 +208,15 @@ def encode_variable(self, variable): variable = encode_nc3_variable(variable) return variable - def prepare_variable(self, name, variable, check_encoding=False, - unlimited_dims=None): + def prepare_variable( + self, name, variable, check_encoding=False, unlimited_dims=None + ): if check_encoding and variable.encoding: - if variable.encoding != {'_FillValue': None}: - raise ValueError('unexpected encoding for scipy backend: %r' - % list(variable.encoding)) + if variable.encoding != {"_FillValue": None}: + raise ValueError( + "unexpected encoding for scipy backend: %r" + % list(variable.encoding) + ) data = variable.data # nb. 
this still creates a numpy array in all memory, even though we diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index effacd8b4b7..31997d258c8 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -8,11 +8,10 @@ from ..core import indexing from ..core.pycompat import integer_types from ..core.utils import FrozenOrderedDict, HiddenKeyDict -from .common import AbstractWritableDataStore, BackendArray, \ - _encode_variable_name +from .common import AbstractWritableDataStore, BackendArray, _encode_variable_name # need some special secret attributes to tell us the dimensions -_DIMENSION_KEY = '_ARRAY_DIMENSIONS' +_DIMENSION_KEY = "_ARRAY_DIMENSIONS" # zarr attributes have to be serializable as json @@ -48,8 +47,9 @@ def __getitem__(self, key): if isinstance(key, indexing.BasicIndexer): return array[key.tuple] elif isinstance(key, indexing.VectorizedIndexer): - return array.vindex[indexing._arrayize_vectorized_indexer( - key.tuple, self.shape).tuple] + return array.vindex[ + indexing._arrayize_vectorized_indexer(key.tuple, self.shape).tuple + ] else: assert isinstance(key, indexing.OuterIndexer) return array.oindex[key.tuple] @@ -82,12 +82,14 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim): raise ValueError( "Zarr requires uniform chunk sizes except for final chunk." " Variable dask chunks %r are incompatible. Consider " - "rechunking using `chunk()`." % (var_chunks,)) + "rechunking using `chunk()`." % (var_chunks,) + ) if any((chunks[0] < chunks[-1]) for chunks in var_chunks): raise ValueError( "Final chunk of Zarr array must be the same size or smaller " "than the first. Variable Dask chunks %r are incompatible. " - "Consider rechunking using `chunk()`." % var_chunks) + "Consider rechunking using `chunk()`." % var_chunks + ) # return the first chunk for each dimension return tuple(chunk[0] for chunk in var_chunks) @@ -108,8 +110,10 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim): for x in enc_chunks_tuple: if not isinstance(x, int): - raise TypeError("zarr chunks must be an int or a tuple of ints. " - "Instead found %r" % (enc_chunks_tuple,)) + raise TypeError( + "zarr chunks must be an int or a tuple of ints. " + "Instead found %r" % (enc_chunks_tuple,) + ) # if there are chunks in encoding and the variable data is a numpy array, # we use the specified chunks @@ -134,18 +138,19 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim): "chunks %r. This is not implemented in xarray yet. " " Consider rechunking the data using " "`chunk()` or specifying different chunks in encoding." - % (enc_chunks_tuple, var_chunks)) + % (enc_chunks_tuple, var_chunks) + ) if dchunks[-1] > zchunk: raise ValueError( "Final chunk of Zarr array must be the same size or " "smaller than the first. The specified Zarr chunk " "encoding is %r, but %r in variable Dask chunks %r is " "incompatible. Consider rechunking using `chunk()`." - % (enc_chunks_tuple, dchunks, var_chunks)) + % (enc_chunks_tuple, dchunks, var_chunks) + ) return enc_chunks_tuple - raise AssertionError( - "We should never get here. Function logic must be wrong.") + raise AssertionError("We should never get here. Function logic must be wrong.") def _get_zarr_dims_and_attrs(zarr_obj, dimension_key): @@ -156,9 +161,10 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key): try: dimensions = zarr_obj.attrs[dimension_key] except KeyError: - raise KeyError("Zarr object is missing the attribute `%s`, which is " - "required for xarray to determine variable dimensions." 
- % (dimension_key)) + raise KeyError( + "Zarr object is missing the attribute `%s`, which is " + "required for xarray to determine variable dimensions." % (dimension_key) + ) attributes = HiddenKeyDict(zarr_obj.attrs, [dimension_key]) return dimensions, attributes @@ -166,21 +172,23 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key): def _extract_zarr_variable_encoding(variable, raise_on_invalid=False): encoding = variable.encoding.copy() - valid_encodings = {'chunks', 'compressor', 'filters', 'cache_metadata'} + valid_encodings = {"chunks", "compressor", "filters", "cache_metadata"} if raise_on_invalid: invalid = [k for k in encoding if k not in valid_encodings] if invalid: - raise ValueError('unexpected encoding parameters for zarr ' - 'backend: %r' % invalid) + raise ValueError( + "unexpected encoding parameters for zarr " "backend: %r" % invalid + ) else: for k in list(encoding): if k not in valid_encodings: del encoding[k] - chunks = _determine_zarr_chunks(encoding.get('chunks'), variable.chunks, - variable.ndim) - encoding['chunks'] = chunks + chunks = _determine_zarr_chunks( + encoding.get("chunks"), variable.chunks, variable.ndim + ) + encoding["chunks"] = chunks return encoding @@ -224,24 +232,35 @@ class ZarrStore(AbstractWritableDataStore): """ @classmethod - def open_group(cls, store, mode='r', synchronizer=None, group=None, - consolidated=False, consolidate_on_close=False): + def open_group( + cls, + store, + mode="r", + synchronizer=None, + group=None, + consolidated=False, + consolidate_on_close=False, + ): import zarr - min_zarr = '2.2' + + min_zarr = "2.2" if LooseVersion(zarr.__version__) < min_zarr: # pragma: no cover - raise NotImplementedError("Zarr version %s or greater is " - "required by xarray. See zarr " - "installation " - "http://zarr.readthedocs.io/en/stable/" - "#installation" % min_zarr) + raise NotImplementedError( + "Zarr version %s or greater is " + "required by xarray. See zarr " + "installation " + "http://zarr.readthedocs.io/en/stable/" + "#installation" % min_zarr + ) if consolidated or consolidate_on_close: - if LooseVersion( - zarr.__version__) <= '2.2.1.dev2': # pragma: no cover - raise NotImplementedError("Zarr version 2.2.1.dev2 or greater " - "is required by for consolidated " - "metadata.") + if LooseVersion(zarr.__version__) <= "2.2.1.dev2": # pragma: no cover + raise NotImplementedError( + "Zarr version 2.2.1.dev2 or greater " + "is required by for consolidated " + "metadata." 
+ ) open_kwargs = dict(mode=mode, synchronizer=synchronizer, path=group) if consolidated: @@ -261,22 +280,24 @@ def __init__(self, zarr_group, consolidate_on_close=False): def open_store_variable(self, name, zarr_array): data = indexing.LazilyOuterIndexedArray(ZarrArrayWrapper(name, self)) - dimensions, attributes = _get_zarr_dims_and_attrs(zarr_array, - _DIMENSION_KEY) + dimensions, attributes = _get_zarr_dims_and_attrs(zarr_array, _DIMENSION_KEY) attributes = OrderedDict(attributes) - encoding = {'chunks': zarr_array.chunks, - 'compressor': zarr_array.compressor, - 'filters': zarr_array.filters} + encoding = { + "chunks": zarr_array.chunks, + "compressor": zarr_array.compressor, + "filters": zarr_array.filters, + } # _FillValue needs to be in attributes, not encoding, so it will get # picked up by decode_cf - if getattr(zarr_array, 'fill_value') is not None: - attributes['_FillValue'] = zarr_array.fill_value + if getattr(zarr_array, "fill_value") is not None: + attributes["_FillValue"] = zarr_array.fill_value return Variable(dimensions, data, attributes, encoding) def get_variables(self): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in self.ds.arrays()) + return FrozenOrderedDict( + (k, self.open_store_variable(k, v)) for k, v in self.ds.arrays() + ) def get_attrs(self): attributes = OrderedDict(self.ds.attrs.asdict()) @@ -289,20 +310,24 @@ def get_dimensions(self): for d, s in zip(v.attrs[_DIMENSION_KEY], v.shape): if d in dimensions and dimensions[d] != s: raise ValueError( - 'found conflicting lengths for dimension %s ' - '(%d != %d)' % (d, s, dimensions[d])) + "found conflicting lengths for dimension %s " + "(%d != %d)" % (d, s, dimensions[d]) + ) dimensions[d] = s except KeyError: - raise KeyError("Zarr object is missing the attribute `%s`, " - "which is required for xarray to determine " - "variable dimensions." % (_DIMENSION_KEY)) + raise KeyError( + "Zarr object is missing the attribute `%s`, " + "which is required for xarray to determine " + "variable dimensions." 
% (_DIMENSION_KEY) + ) return dimensions def set_dimensions(self, variables, unlimited_dims=None): if unlimited_dims is not None: raise NotImplementedError( - "Zarr backend doesn't know how to handle unlimited dimensions") + "Zarr backend doesn't know how to handle unlimited dimensions" + ) def set_attributes(self, attributes): self.ds.attrs.put(attributes) @@ -314,8 +339,14 @@ def encode_variable(self, variable): def encode_attribute(self, a): return _encode_zarr_attr_value(a) - def store(self, variables, attributes, check_encoding_set=frozenset(), - writer=None, unlimited_dims=None): + def store( + self, + variables, + attributes, + check_encoding_set=frozenset(), + writer=None, + unlimited_dims=None, + ): """ Top level method for putting data on this store, this method: - encodes variables/attributes @@ -340,14 +371,15 @@ def store(self, variables, attributes, check_encoding_set=frozenset(), """ existing_variables = { - vn for vn in variables - if _encode_variable_name(vn) in self.ds + vn for vn in variables if _encode_variable_name(vn) in self.ds } new_variables = set(variables) - existing_variables - variables_without_encoding = OrderedDict([(vn, variables[vn]) - for vn in new_variables]) + variables_without_encoding = OrderedDict( + [(vn, variables[vn]) for vn in new_variables] + ) variables_encoded, attributes = self.encode( - variables_without_encoding, attributes) + variables_without_encoding, attributes + ) if len(existing_variables) > 0: # there are variables to append @@ -357,20 +389,19 @@ def store(self, variables, attributes, check_encoding_set=frozenset(), for vn in existing_variables: variables_with_encoding[vn] = variables[vn].copy(deep=False) variables_with_encoding[vn].encoding = ds[vn].encoding - variables_with_encoding, _ = self.encode(variables_with_encoding, - {}) + variables_with_encoding, _ = self.encode(variables_with_encoding, {}) variables_encoded.update(variables_with_encoding) self.set_attributes(attributes) self.set_dimensions(variables_encoded, unlimited_dims=unlimited_dims) - self.set_variables(variables_encoded, check_encoding_set, writer, - unlimited_dims=unlimited_dims) + self.set_variables( + variables_encoded, check_encoding_set, writer, unlimited_dims=unlimited_dims + ) def sync(self): pass - def set_variables(self, variables, check_encoding_set, writer, - unlimited_dims=None): + def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None): """ This provides a centralized method to set the variables on the data store. 
@@ -396,8 +427,8 @@ def set_variables(self, variables, check_encoding_set, writer, dtype = v.dtype shape = v.shape - fill_value = attrs.pop('_FillValue', None) - if v.encoding == {'_FillValue': None} and fill_value is None: + fill_value = attrs.pop("_FillValue", None) + if v.encoding == {"_FillValue": None} and fill_value is None: v.encoding = {} if name in self.ds: zarr_array = self.ds[name] @@ -408,17 +439,12 @@ def set_variables(self, variables, check_encoding_set, writer, new_shape = list(zarr_array.shape) new_shape[append_axis] += v.shape[append_axis] new_region = [slice(None)] * len(new_shape) - new_region[append_axis] = slice( - zarr_array.shape[append_axis], - None - ) + new_region[append_axis] = slice(zarr_array.shape[append_axis], None) zarr_array.resize(new_shape) - writer.add(v.data, zarr_array, - region=tuple(new_region)) + writer.add(v.data, zarr_array, region=tuple(new_region)) else: # new variable - encoding = _extract_zarr_variable_encoding( - v, raise_on_invalid=check) + encoding = _extract_zarr_variable_encoding(v, raise_on_invalid=check) encoded_attrs = OrderedDict() # the magic for storing the hidden dimension data encoded_attrs[_DIMENSION_KEY] = dims @@ -427,22 +453,34 @@ def set_variables(self, variables, check_encoding_set, writer, if coding.strings.check_vlen_dtype(dtype) == str: dtype = str - zarr_array = self.ds.create(name, shape=shape, dtype=dtype, - fill_value=fill_value, **encoding) + zarr_array = self.ds.create( + name, shape=shape, dtype=dtype, fill_value=fill_value, **encoding + ) zarr_array.attrs.put(encoded_attrs) writer.add(v.data, zarr_array) def close(self): if self._consolidate_on_close: import zarr + zarr.consolidate_metadata(self.ds.store) -def open_zarr(store, group=None, synchronizer=None, chunks='auto', - decode_cf=True, mask_and_scale=True, decode_times=True, - concat_characters=True, decode_coords=True, - drop_variables=None, consolidated=False, - overwrite_encoded_chunks=False, **kwargs): +def open_zarr( + store, + group=None, + synchronizer=None, + chunks="auto", + decode_cf=True, + mask_and_scale=True, + decode_times=True, + concat_characters=True, + decode_coords=True, + drop_variables=None, + consolidated=False, + overwrite_encoded_chunks=False, + **kwargs +): """Load and decode a dataset from a Zarr store. .. note:: Experimental @@ -514,24 +552,30 @@ def open_zarr(store, group=None, synchronizer=None, chunks='auto', ---------- http://zarr.readthedocs.io/ """ - if 'auto_chunk' in kwargs: - auto_chunk = kwargs.pop('auto_chunk') + if "auto_chunk" in kwargs: + auto_chunk = kwargs.pop("auto_chunk") if auto_chunk: - chunks = 'auto' # maintain backwards compatibility + chunks = "auto" # maintain backwards compatibility else: chunks = None - warnings.warn("auto_chunk is deprecated. Use chunks='auto' instead.", - FutureWarning, stacklevel=2) + warnings.warn( + "auto_chunk is deprecated. Use chunks='auto' instead.", + FutureWarning, + stacklevel=2, + ) if kwargs: - raise TypeError("open_zarr() got unexpected keyword arguments " + - ",".join(kwargs.keys())) + raise TypeError( + "open_zarr() got unexpected keyword arguments " + ",".join(kwargs.keys()) + ) if not isinstance(chunks, (int, dict)): - if chunks != 'auto' and chunks is not None: - raise ValueError("chunks must be an int, dict, 'auto', or None. " - "Instead found %s. " % chunks) + if chunks != "auto" and chunks is not None: + raise ValueError( + "chunks must be an int, dict, 'auto', or None. " + "Instead found %s. 
" % chunks + ) if not decode_cf: mask_and_scale = False @@ -541,9 +585,13 @@ def open_zarr(store, group=None, synchronizer=None, chunks='auto', def maybe_decode_store(store, lock=False): ds = conventions.decode_cf( - store, mask_and_scale=mask_and_scale, decode_times=decode_times, - concat_characters=concat_characters, decode_coords=decode_coords, - drop_variables=drop_variables) + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + ) # TODO: this is where we would apply caching @@ -551,10 +599,14 @@ def maybe_decode_store(store, lock=False): # Zarr supports a wide range of access modes, but for now xarray either # reads or writes from a store, never both. For open_zarr, we only read - mode = 'r' - zarr_store = ZarrStore.open_group(store, mode=mode, - synchronizer=synchronizer, - group=group, consolidated=consolidated) + mode = "r" + zarr_store = ZarrStore.open_group( + store, + mode=mode, + synchronizer=synchronizer, + group=group, + consolidated=consolidated, + ) ds = maybe_decode_store(zarr_store) # auto chunking needs to be here and not in ZarrStore because variable @@ -571,13 +623,13 @@ def maybe_decode_store(store, lock=False): chunks = dict(zip(ds.dims, chunks)) def get_chunk(name, var, chunks): - chunk_spec = dict(zip(var.dims, var.encoding.get('chunks'))) + chunk_spec = dict(zip(var.dims, var.encoding.get("chunks"))) # Coordinate labels aren't chunked if var.ndim == 1 and var.dims[0] == name: return chunk_spec - if chunks == 'auto': + if chunks == "auto": return chunk_spec for dim in var.dims: @@ -587,13 +639,15 @@ def get_chunk(name, var, chunks): spec = (spec,) if isinstance(spec, (tuple, list)) and chunk_spec[dim]: if any(s % chunk_spec[dim] for s in spec): - warnings.warn("Specified Dask chunks %r would " - "separate Zarr chunk shape %r for " - "dimension %r. This significantly " - "degrades performance. Consider " - "rechunking after loading instead." - % (chunks[dim], chunk_spec[dim], dim), - stacklevel=2) + warnings.warn( + "Specified Dask chunks %r would " + "separate Zarr chunk shape %r for " + "dimension %r. This significantly " + "degrades performance. Consider " + "rechunking after loading instead." + % (chunks[dim], chunk_spec[dim], dim), + stacklevel=2, + ) chunk_spec[dim] = chunks[dim] return chunk_spec @@ -605,14 +659,15 @@ def maybe_chunk(name, var, chunks): if (var.ndim > 0) and (chunk_spec is not None): # does this cause any data to be read? 
token2 = tokenize(name, var._data) - name2 = 'zarr-%s' % token2 + name2 = "zarr-%s" % token2 var = var.chunk(chunk_spec, name=name2, lock=None) if overwrite_encoded_chunks and var.chunks is not None: - var.encoding['chunks'] = tuple(x[0] for x in var.chunks) + var.encoding["chunks"] = tuple(x[0] for x in var.chunks) return var else: return var - variables = OrderedDict([(k, maybe_chunk(k, v, chunks)) - for k, v in ds.variables.items()]) + variables = OrderedDict( + [(k, maybe_chunk(k, v, chunks)) for k, v in ds.variables.items()] + ) return ds._replace_vars_and_dims(variables) diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py index 7187f1266bd..a26f13df924 100644 --- a/xarray/coding/cftime_offsets.py +++ b/xarray/coding/cftime_offsets.py @@ -56,19 +56,18 @@ def get_date_type(calendar): try: import cftime except ImportError: - raise ImportError( - 'cftime is required for dates with non-standard calendars') + raise ImportError("cftime is required for dates with non-standard calendars") else: calendars = { - 'noleap': cftime.DatetimeNoLeap, - '360_day': cftime.Datetime360Day, - '365_day': cftime.DatetimeNoLeap, - '366_day': cftime.DatetimeAllLeap, - 'gregorian': cftime.DatetimeGregorian, - 'proleptic_gregorian': cftime.DatetimeProlepticGregorian, - 'julian': cftime.DatetimeJulian, - 'all_leap': cftime.DatetimeAllLeap, - 'standard': cftime.DatetimeGregorian + "noleap": cftime.DatetimeNoLeap, + "360_day": cftime.Datetime360Day, + "365_day": cftime.DatetimeNoLeap, + "366_day": cftime.DatetimeAllLeap, + "gregorian": cftime.DatetimeGregorian, + "proleptic_gregorian": cftime.DatetimeProlepticGregorian, + "julian": cftime.DatetimeJulian, + "all_leap": cftime.DatetimeAllLeap, + "standard": cftime.DatetimeGregorian, } return calendars[calendar] @@ -81,7 +80,8 @@ def __init__(self, n=1): if not isinstance(n, int): raise TypeError( "The provided multiple 'n' must be an integer. 
" - "Instead a value of type {!r} was provided.".format(type(n))) + "Instead a value of type {!r} was provided.".format(type(n)) + ) self.n = n def rule_code(self): @@ -100,8 +100,7 @@ def __sub__(self, other): import cftime if isinstance(other, cftime.datetime): - raise TypeError('Cannot subtract a cftime.datetime ' - 'from a time offset.') + raise TypeError("Cannot subtract a cftime.datetime " "from a time offset.") elif type(other) == type(self): return type(self)(self.n - other.n) else: @@ -121,8 +120,7 @@ def __radd__(self, other): def __rsub__(self, other): if isinstance(other, BaseCFTimeOffset) and type(self) != type(other): - raise TypeError('Cannot subtract cftime offsets of differing ' - 'types') + raise TypeError("Cannot subtract cftime offsets of differing " "types") return -self + other def __apply__(self): @@ -147,7 +145,7 @@ def rollback(self, date): return date - type(self)() def __str__(self): - return '<{}: n={}>'.format(type(self).__name__, self.n) + return "<{}: n={}>".format(type(self).__name__, self.n) def __repr__(self): return str(self) @@ -175,9 +173,9 @@ def _get_day_of_month(other, day_option): """ - if day_option == 'start': + if day_option == "start": return 1 - elif day_option == 'end': + elif day_option == "end": days_in_month = _days_in_month(other) return days_in_month elif day_option is None: @@ -212,17 +210,15 @@ def _adjust_n_years(other, n, month, reference_day): """Adjust the number of times an annual offset is applied based on another date, and the reference day provided""" if n > 0: - if other.month < month or (other.month == month and - other.day < reference_day): + if other.month < month or (other.month == month and other.day < reference_day): n -= 1 else: - if other.month > month or (other.month == month and - other.day > reference_day): + if other.month > month or (other.month == month and other.day > reference_day): n += 1 return n -def _shift_month(date, months, day_option='start'): +def _shift_month(date, months, day_option="start"): """Shift the date to a month start or end a given number of months away. """ delta_year = (date.month + months) // 12 @@ -233,9 +229,9 @@ def _shift_month(date, months, day_option='start'): delta_year = delta_year - 1 year = date.year + delta_year - if day_option == 'start': + if day_option == "start": day = 1 - elif day_option == 'end': + elif day_option == "end": reference = type(date)(year, month, 1) day = _days_in_month(reference) else: @@ -274,15 +270,15 @@ def roll_qtrday(other, n, month, day_option, modby=3): if n > 0: if months_since < 0 or ( - months_since == 0 and - other.day < _get_day_of_month(other, day_option)): + months_since == 0 and other.day < _get_day_of_month(other, day_option) + ): # pretend to roll back if on same month but # before compare_day n -= 1 else: if months_since > 0 or ( - months_since == 0 and - other.day > _get_day_of_month(other, day_option)): + months_since == 0 and other.day > _get_day_of_month(other, day_option) + ): # make sure to roll forward, so negate n += 1 return n @@ -294,22 +290,26 @@ def _validate_month(month, default_month): else: result_month = month if not isinstance(result_month, int): - raise TypeError("'self.month' must be an integer value between 1 " - "and 12. Instead, it was set to a value of " - "{!r}".format(result_month)) + raise TypeError( + "'self.month' must be an integer value between 1 " + "and 12. 
Instead, it was set to a value of " + "{!r}".format(result_month) + ) elif not (1 <= result_month <= 12): - raise ValueError("'self.month' must be an integer value between 1 " - "and 12. Instead, it was set to a value of " - "{!r}".format(result_month)) + raise ValueError( + "'self.month' must be an integer value between 1 " + "and 12. Instead, it was set to a value of " + "{!r}".format(result_month) + ) return result_month class MonthBegin(BaseCFTimeOffset): - _freq = 'MS' + _freq = "MS" def __apply__(self, other): n = _adjust_n_months(other.day, self.n, 1) - return _shift_month(other, n, 'start') + return _shift_month(other, n, "start") def onOffset(self, date): """Check if the given date is in the set of possible dates created @@ -318,11 +318,11 @@ def onOffset(self, date): class MonthEnd(BaseCFTimeOffset): - _freq = 'M' + _freq = "M" def __apply__(self, other): n = _adjust_n_months(other.day, self.n, _days_in_month(other)) - return _shift_month(other, n, 'end') + return _shift_month(other, n, "end") def onOffset(self, date): """Check if the given date is in the set of possible dates created @@ -331,24 +331,25 @@ def onOffset(self, date): _MONTH_ABBREVIATIONS = { - 1: 'JAN', - 2: 'FEB', - 3: 'MAR', - 4: 'APR', - 5: 'MAY', - 6: 'JUN', - 7: 'JUL', - 8: 'AUG', - 9: 'SEP', - 10: 'OCT', - 11: 'NOV', - 12: 'DEC' + 1: "JAN", + 2: "FEB", + 3: "MAR", + 4: "APR", + 5: "MAY", + 6: "JUN", + 7: "JUL", + 8: "AUG", + 9: "SEP", + 10: "OCT", + 11: "NOV", + 12: "DEC", } class QuarterOffset(BaseCFTimeOffset): """Quarter representation copied off of pandas/tseries/offsets.py """ + _freq = None # type: ClassVar[str] _default_month = None # type: ClassVar[int] @@ -363,8 +364,9 @@ def __apply__(self, other): # self. `months_since` is the number of months to shift other.month # to get to this on-offset month. months_since = other.month % 3 - self.month % 3 - qtrs = roll_qtrday(other, self.n, self.month, - day_option=self._day_option, modby=3) + qtrs = roll_qtrday( + other, self.n, self.month, day_option=self._day_option, modby=3 + ) months = qtrs * 3 - months_since return _shift_month(other, months, self._day_option) @@ -378,7 +380,7 @@ def __sub__(self, other): import cftime if isinstance(other, cftime.datetime): - raise TypeError('Cannot subtract cftime.datetime from offset.') + raise TypeError("Cannot subtract cftime.datetime from offset.") elif type(other) == type(self) and other.month == self.month: return type(self)(self.n - other.n, month=self.month) else: @@ -388,11 +390,10 @@ def __mul__(self, other): return type(self)(n=other * self.n, month=self.month) def rule_code(self): - return '{}-{}'.format(self._freq, _MONTH_ABBREVIATIONS[self.month]) + return "{}-{}".format(self._freq, _MONTH_ABBREVIATIONS[self.month]) def __str__(self): - return '<{}: n={}, month={}>'.format( - type(self).__name__, self.n, self.month) + return "<{}: n={}, month={}>".format(type(self).__name__, self.n, self.month) class QuarterBegin(QuarterOffset): @@ -402,8 +403,8 @@ class QuarterBegin(QuarterOffset): # from the constructor, however, the default month is March. # We follow that behavior here. _default_month = 3 - _freq = 'QS' - _day_option = 'start' + _freq = "QS" + _day_option = "start" def rollforward(self, date): """Roll date forward to nearest start of quarter""" @@ -427,8 +428,8 @@ class QuarterEnd(QuarterOffset): # from the constructor, however, the default month is March. # We follow that behavior here. 
_default_month = 3 - _freq = 'Q' - _day_option = 'end' + _freq = "Q" + _day_option = "end" def rollforward(self, date): """Roll date forward to nearest end of quarter""" @@ -464,7 +465,7 @@ def __sub__(self, other): import cftime if isinstance(other, cftime.datetime): - raise TypeError('Cannot subtract cftime.datetime from offset.') + raise TypeError("Cannot subtract cftime.datetime from offset.") elif type(other) == type(self) and other.month == self.month: return type(self)(self.n - other.n, month=self.month) else: @@ -474,16 +475,15 @@ def __mul__(self, other): return type(self)(n=other * self.n, month=self.month) def rule_code(self): - return '{}-{}'.format(self._freq, _MONTH_ABBREVIATIONS[self.month]) + return "{}-{}".format(self._freq, _MONTH_ABBREVIATIONS[self.month]) def __str__(self): - return '<{}: n={}, month={}>'.format( - type(self).__name__, self.n, self.month) + return "<{}: n={}, month={}>".format(type(self).__name__, self.n, self.month) class YearBegin(YearOffset): - _freq = 'AS' - _day_option = 'start' + _freq = "AS" + _day_option = "start" _default_month = 1 def onOffset(self, date): @@ -507,8 +507,8 @@ def rollback(self, date): class YearEnd(YearOffset): - _freq = 'A' - _day_option = 'end' + _freq = "A" + _day_option = "end" _default_month = 12 def onOffset(self, date): @@ -532,7 +532,7 @@ def rollback(self, date): class Day(BaseCFTimeOffset): - _freq = 'D' + _freq = "D" def as_timedelta(self): return timedelta(days=self.n) @@ -542,7 +542,7 @@ def __apply__(self, other): class Hour(BaseCFTimeOffset): - _freq = 'H' + _freq = "H" def as_timedelta(self): return timedelta(hours=self.n) @@ -552,7 +552,7 @@ def __apply__(self, other): class Minute(BaseCFTimeOffset): - _freq = 'T' + _freq = "T" def as_timedelta(self): return timedelta(minutes=self.n) @@ -562,7 +562,7 @@ def __apply__(self, other): class Second(BaseCFTimeOffset): - _freq = 'S' + _freq = "S" def as_timedelta(self): return timedelta(seconds=self.n) @@ -572,73 +572,72 @@ def __apply__(self, other): _FREQUENCIES = { - 'A': YearEnd, - 'AS': YearBegin, - 'Y': YearEnd, - 'YS': YearBegin, - 'Q': partial(QuarterEnd, month=12), - 'QS': partial(QuarterBegin, month=1), - 'M': MonthEnd, - 'MS': MonthBegin, - 'D': Day, - 'H': Hour, - 'T': Minute, - 'min': Minute, - 'S': Second, - 'AS-JAN': partial(YearBegin, month=1), - 'AS-FEB': partial(YearBegin, month=2), - 'AS-MAR': partial(YearBegin, month=3), - 'AS-APR': partial(YearBegin, month=4), - 'AS-MAY': partial(YearBegin, month=5), - 'AS-JUN': partial(YearBegin, month=6), - 'AS-JUL': partial(YearBegin, month=7), - 'AS-AUG': partial(YearBegin, month=8), - 'AS-SEP': partial(YearBegin, month=9), - 'AS-OCT': partial(YearBegin, month=10), - 'AS-NOV': partial(YearBegin, month=11), - 'AS-DEC': partial(YearBegin, month=12), - 'A-JAN': partial(YearEnd, month=1), - 'A-FEB': partial(YearEnd, month=2), - 'A-MAR': partial(YearEnd, month=3), - 'A-APR': partial(YearEnd, month=4), - 'A-MAY': partial(YearEnd, month=5), - 'A-JUN': partial(YearEnd, month=6), - 'A-JUL': partial(YearEnd, month=7), - 'A-AUG': partial(YearEnd, month=8), - 'A-SEP': partial(YearEnd, month=9), - 'A-OCT': partial(YearEnd, month=10), - 'A-NOV': partial(YearEnd, month=11), - 'A-DEC': partial(YearEnd, month=12), - 'QS-JAN': partial(QuarterBegin, month=1), - 'QS-FEB': partial(QuarterBegin, month=2), - 'QS-MAR': partial(QuarterBegin, month=3), - 'QS-APR': partial(QuarterBegin, month=4), - 'QS-MAY': partial(QuarterBegin, month=5), - 'QS-JUN': partial(QuarterBegin, month=6), - 'QS-JUL': partial(QuarterBegin, month=7), - 
'QS-AUG': partial(QuarterBegin, month=8), - 'QS-SEP': partial(QuarterBegin, month=9), - 'QS-OCT': partial(QuarterBegin, month=10), - 'QS-NOV': partial(QuarterBegin, month=11), - 'QS-DEC': partial(QuarterBegin, month=12), - 'Q-JAN': partial(QuarterEnd, month=1), - 'Q-FEB': partial(QuarterEnd, month=2), - 'Q-MAR': partial(QuarterEnd, month=3), - 'Q-APR': partial(QuarterEnd, month=4), - 'Q-MAY': partial(QuarterEnd, month=5), - 'Q-JUN': partial(QuarterEnd, month=6), - 'Q-JUL': partial(QuarterEnd, month=7), - 'Q-AUG': partial(QuarterEnd, month=8), - 'Q-SEP': partial(QuarterEnd, month=9), - 'Q-OCT': partial(QuarterEnd, month=10), - 'Q-NOV': partial(QuarterEnd, month=11), - 'Q-DEC': partial(QuarterEnd, month=12) + "A": YearEnd, + "AS": YearBegin, + "Y": YearEnd, + "YS": YearBegin, + "Q": partial(QuarterEnd, month=12), + "QS": partial(QuarterBegin, month=1), + "M": MonthEnd, + "MS": MonthBegin, + "D": Day, + "H": Hour, + "T": Minute, + "min": Minute, + "S": Second, + "AS-JAN": partial(YearBegin, month=1), + "AS-FEB": partial(YearBegin, month=2), + "AS-MAR": partial(YearBegin, month=3), + "AS-APR": partial(YearBegin, month=4), + "AS-MAY": partial(YearBegin, month=5), + "AS-JUN": partial(YearBegin, month=6), + "AS-JUL": partial(YearBegin, month=7), + "AS-AUG": partial(YearBegin, month=8), + "AS-SEP": partial(YearBegin, month=9), + "AS-OCT": partial(YearBegin, month=10), + "AS-NOV": partial(YearBegin, month=11), + "AS-DEC": partial(YearBegin, month=12), + "A-JAN": partial(YearEnd, month=1), + "A-FEB": partial(YearEnd, month=2), + "A-MAR": partial(YearEnd, month=3), + "A-APR": partial(YearEnd, month=4), + "A-MAY": partial(YearEnd, month=5), + "A-JUN": partial(YearEnd, month=6), + "A-JUL": partial(YearEnd, month=7), + "A-AUG": partial(YearEnd, month=8), + "A-SEP": partial(YearEnd, month=9), + "A-OCT": partial(YearEnd, month=10), + "A-NOV": partial(YearEnd, month=11), + "A-DEC": partial(YearEnd, month=12), + "QS-JAN": partial(QuarterBegin, month=1), + "QS-FEB": partial(QuarterBegin, month=2), + "QS-MAR": partial(QuarterBegin, month=3), + "QS-APR": partial(QuarterBegin, month=4), + "QS-MAY": partial(QuarterBegin, month=5), + "QS-JUN": partial(QuarterBegin, month=6), + "QS-JUL": partial(QuarterBegin, month=7), + "QS-AUG": partial(QuarterBegin, month=8), + "QS-SEP": partial(QuarterBegin, month=9), + "QS-OCT": partial(QuarterBegin, month=10), + "QS-NOV": partial(QuarterBegin, month=11), + "QS-DEC": partial(QuarterBegin, month=12), + "Q-JAN": partial(QuarterEnd, month=1), + "Q-FEB": partial(QuarterEnd, month=2), + "Q-MAR": partial(QuarterEnd, month=3), + "Q-APR": partial(QuarterEnd, month=4), + "Q-MAY": partial(QuarterEnd, month=5), + "Q-JUN": partial(QuarterEnd, month=6), + "Q-JUL": partial(QuarterEnd, month=7), + "Q-AUG": partial(QuarterEnd, month=8), + "Q-SEP": partial(QuarterEnd, month=9), + "Q-OCT": partial(QuarterEnd, month=10), + "Q-NOV": partial(QuarterEnd, month=11), + "Q-DEC": partial(QuarterEnd, month=12), } -_FREQUENCY_CONDITION = '|'.join(_FREQUENCIES.keys()) -_PATTERN = r'^((?P\d+)|())(?P({}))$'.format( - _FREQUENCY_CONDITION) +_FREQUENCY_CONDITION = "|".join(_FREQUENCIES.keys()) +_PATTERN = r"^((?P\d+)|())(?P({}))$".format(_FREQUENCY_CONDITION) # pandas defines these offsets as "Tick" objects, which for instance have @@ -655,10 +654,10 @@ def to_offset(freq): try: freq_data = re.match(_PATTERN, freq).groupdict() except AttributeError: - raise ValueError('Invalid frequency string provided') + raise ValueError("Invalid frequency string provided") - freq = freq_data['freq'] - multiples = 
freq_data['multiple'] + freq = freq_data["freq"] + multiples = freq_data["multiple"] if multiples is None: multiples = 1 else: @@ -673,17 +672,19 @@ def to_cftime_datetime(date_str_or_date, calendar=None): if isinstance(date_str_or_date, str): if calendar is None: raise ValueError( - 'If converting a string to a cftime.datetime object, ' - 'a calendar type must be provided') - date, _ = _parse_iso8601_with_reso(get_date_type(calendar), - date_str_or_date) + "If converting a string to a cftime.datetime object, " + "a calendar type must be provided" + ) + date, _ = _parse_iso8601_with_reso(get_date_type(calendar), date_str_or_date) return date elif isinstance(date_str_or_date, cftime.datetime): return date_str_or_date else: - raise TypeError("date_str_or_date must be a string or a " - 'subclass of cftime.datetime. Instead got ' - '{!r}.'.format(date_str_or_date)) + raise TypeError( + "date_str_or_date must be a string or a " + "subclass of cftime.datetime. Instead got " + "{!r}.".format(date_str_or_date) + ) def normalize_date(date): @@ -705,11 +706,12 @@ def _generate_linear_range(start, end, periods): import cftime total_seconds = (end - start).total_seconds() - values = np.linspace(0., total_seconds, periods, endpoint=True) - units = 'seconds since {}'.format(format_cftime_datetime(start)) + values = np.linspace(0.0, total_seconds, periods, endpoint=True) + units = "seconds since {}".format(format_cftime_datetime(start)) calendar = start.calendar - return cftime.num2date(values, units=units, calendar=calendar, - only_use_cftime_datetimes=True) + return cftime.num2date( + values, units=units, calendar=calendar, only_use_cftime_datetimes=True + ) def _generate_range(start, end, periods, offset): @@ -756,8 +758,9 @@ def _generate_range(start, end, periods, offset): next_date = current + offset if next_date <= current: - raise ValueError('Offset {offset} did not increment date' - .format(offset=offset)) + raise ValueError( + "Offset {offset} did not increment date".format(offset=offset) + ) current = next_date else: while current >= end: @@ -765,8 +768,9 @@ def _generate_range(start, end, periods, offset): next_date = current + offset if next_date >= current: - raise ValueError('Offset {offset} did not decrement date' - .format(offset=offset)) + raise ValueError( + "Offset {offset} did not decrement date".format(offset=offset) + ) current = next_date @@ -775,9 +779,16 @@ def _count_not_none(*args): return sum([arg is not None for arg in args]) -def cftime_range(start=None, end=None, periods=None, freq='D', - normalize=False, name=None, closed=None, - calendar='standard'): +def cftime_range( + start=None, + end=None, + periods=None, + freq="D", + normalize=False, + name=None, + closed=None, + calendar="standard", +): """Return a fixed frequency CFTimeIndex. Parameters @@ -949,7 +960,8 @@ def cftime_range(start=None, end=None, periods=None, freq='D', if _count_not_none(start, end, periods, freq) != 3: raise ValueError( "Of the arguments 'start', 'end', 'periods', and 'freq', three " - "must be specified at a time.") + "must be specified at a time." 
+ ) if start is not None: start = to_cftime_datetime(start, calendar) @@ -970,18 +982,16 @@ def cftime_range(start=None, end=None, periods=None, freq='D', if closed is None: left_closed = True right_closed = True - elif closed == 'left': + elif closed == "left": left_closed = True - elif closed == 'right': + elif closed == "right": right_closed = True else: raise ValueError("Closed must be either 'left', 'right' or None") - if (not left_closed and len(dates) and - start is not None and dates[0] == start): + if not left_closed and len(dates) and start is not None and dates[0] == start: dates = dates[1:] - if (not right_closed and len(dates) and - end is not None and dates[-1] == end): + if not right_closed and len(dates) and end is not None and dates[-1] == end: dates = dates[:-1] return CFTimeIndex(dates, name=name) diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index cf10d6238aa..0150eff2c1e 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -53,34 +53,36 @@ def named(name, pattern): - return '(?P<' + name + '>' + pattern + ')' + return "(?P<" + name + ">" + pattern + ")" def optional(x): - return '(?:' + x + ')?' + return "(?:" + x + ")?" def trailing_optional(xs): if not xs: - return '' + return "" return xs[0] + optional(trailing_optional(xs[1:])) -def build_pattern(date_sep=r'\-', datetime_sep=r'T', time_sep=r'\:'): - pieces = [(None, 'year', r'\d{4}'), - (date_sep, 'month', r'\d{2}'), - (date_sep, 'day', r'\d{2}'), - (datetime_sep, 'hour', r'\d{2}'), - (time_sep, 'minute', r'\d{2}'), - (time_sep, 'second', r'\d{2}')] +def build_pattern(date_sep=r"\-", datetime_sep=r"T", time_sep=r"\:"): + pieces = [ + (None, "year", r"\d{4}"), + (date_sep, "month", r"\d{2}"), + (date_sep, "day", r"\d{2}"), + (datetime_sep, "hour", r"\d{2}"), + (time_sep, "minute", r"\d{2}"), + (time_sep, "second", r"\d{2}"), + ] pattern_list = [] for sep, name, sub_pattern in pieces: - pattern_list.append((sep if sep else '') + named(name, sub_pattern)) + pattern_list.append((sep if sep else "") + named(name, sub_pattern)) # TODO: allow timezone offsets? - return '^' + trailing_optional(pattern_list) + '$' + return "^" + trailing_optional(pattern_list) + "$" -_BASIC_PATTERN = build_pattern(date_sep='', time_sep='') +_BASIC_PATTERN = build_pattern(date_sep="", time_sep="") _EXTENDED_PATTERN = build_pattern() _PATTERNS = [_BASIC_PATTERN, _EXTENDED_PATTERN] @@ -90,7 +92,7 @@ def parse_iso8601(datetime_string): match = re.match(pattern, datetime_string) if match: return match.groupdict() - raise ValueError('no ISO-8601 match for string: %s' % datetime_string) + raise ValueError("no ISO-8601 match for string: %s" % datetime_string) def _parse_iso8601_with_reso(date_type, timestr): @@ -98,7 +100,7 @@ def _parse_iso8601_with_reso(date_type, timestr): result = parse_iso8601(timestr) replace = {} - for attr in ['year', 'month', 'day', 'hour', 'minute', 'second']: + for attr in ["year", "month", "day", "hour", "minute", "second"]: value = result.get(attr, None) if value is not None: # Note ISO8601 conventions allow for fractional seconds. @@ -110,7 +112,7 @@ def _parse_iso8601_with_reso(date_type, timestr): # the returned date object in versions of cftime between 1.0.2 and # 1.0.3.4. It can be removed for versions of cftime greater than # 1.0.3.4. 
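cftime_range is public API, and the ISO-8601 parsing in cftimeindex.py (the file whose diff begins just above) is what enables partial datetime-string selection on the resulting index. A small usage sketch, assuming cftime is installed; names and values are illustrative:

    import numpy as np
    import xarray as xr

    # A year of 6-hourly timestamps in the "noleap" calendar; the result is a
    # CFTimeIndex backed by cftime.DatetimeNoLeap objects.
    times = xr.cftime_range(start="2000-01-01", end="2000-12-31", freq="6H",
                            calendar="noleap")
    da = xr.DataArray(np.arange(times.size), coords=[("time", times)])

    # Partial datetime-string indexing, mirroring pandas.DatetimeIndex:
    february = da.sel(time="2000-02")
    print(february.sizes)  # all 6-hourly steps in February 2000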
- replace['dayofwk'] = -1 + replace["dayofwk"] = -1 return default.replace(**replace), resolution @@ -120,29 +122,39 @@ def _parsed_string_to_bounds(date_type, resolution, parsed): for use with non-standard calendars and cftime.datetime objects. """ - if resolution == 'year': - return (date_type(parsed.year, 1, 1), - date_type(parsed.year + 1, 1, 1) - timedelta(microseconds=1)) - elif resolution == 'month': + if resolution == "year": + return ( + date_type(parsed.year, 1, 1), + date_type(parsed.year + 1, 1, 1) - timedelta(microseconds=1), + ) + elif resolution == "month": if parsed.month == 12: end = date_type(parsed.year + 1, 1, 1) - timedelta(microseconds=1) else: - end = (date_type(parsed.year, parsed.month + 1, 1) - - timedelta(microseconds=1)) + end = date_type(parsed.year, parsed.month + 1, 1) - timedelta( + microseconds=1 + ) return date_type(parsed.year, parsed.month, 1), end - elif resolution == 'day': + elif resolution == "day": start = date_type(parsed.year, parsed.month, parsed.day) return start, start + timedelta(days=1, microseconds=-1) - elif resolution == 'hour': + elif resolution == "hour": start = date_type(parsed.year, parsed.month, parsed.day, parsed.hour) return start, start + timedelta(hours=1, microseconds=-1) - elif resolution == 'minute': - start = date_type(parsed.year, parsed.month, parsed.day, parsed.hour, - parsed.minute) + elif resolution == "minute": + start = date_type( + parsed.year, parsed.month, parsed.day, parsed.hour, parsed.minute + ) return start, start + timedelta(minutes=1, microseconds=-1) - elif resolution == 'second': - start = date_type(parsed.year, parsed.month, parsed.day, parsed.hour, - parsed.minute, parsed.second) + elif resolution == "second": + start = date_type( + parsed.year, + parsed.month, + parsed.day, + parsed.hour, + parsed.minute, + parsed.second, + ) return start, start + timedelta(seconds=1, microseconds=-1) else: raise KeyError @@ -153,7 +165,7 @@ def get_date_field(datetimes, field): return np.array([getattr(date, field) for date in datetimes]) -def _field_accessor(name, docstring=None, min_cftime_version='0.0'): +def _field_accessor(name, docstring=None, min_cftime_version="0.0"): """Adapted from pandas.tseries.index._field_accessor""" def f(self, min_cftime_version=min_cftime_version): @@ -164,10 +176,11 @@ def f(self, min_cftime_version=min_cftime_version): if LooseVersion(version) >= LooseVersion(min_cftime_version): return get_date_field(self._data, name) else: - raise ImportError('The {!r} accessor requires a minimum ' - 'version of cftime of {}. Found an ' - 'installed version of {}.'.format( - name, min_cftime_version, version)) + raise ImportError( + "The {!r} accessor requires a minimum " + "version of cftime of {}. Found an " + "installed version of {}.".format(name, min_cftime_version, version) + ) f.__name__ = name f.__doc__ = docstring @@ -189,12 +202,14 @@ def assert_all_valid_date_type(data): date_type = type(sample) if not isinstance(sample, cftime.datetime): raise TypeError( - 'CFTimeIndex requires cftime.datetime ' - 'objects. Got object of {}.'.format(date_type)) + "CFTimeIndex requires cftime.datetime " + "objects. Got object of {}.".format(date_type) + ) if not all(isinstance(value, date_type) for value in data): raise TypeError( - 'CFTimeIndex requires using datetime ' - 'objects of all the same type. Got\n{}.'.format(data)) + "CFTimeIndex requires using datetime " + "objects of all the same type. 
Got\n{}.".format(data) + ) class CFTimeIndex(pd.Index): @@ -213,28 +228,27 @@ class CFTimeIndex(pd.Index): -------- cftime_range """ - year = _field_accessor('year', 'The year of the datetime') - month = _field_accessor('month', 'The month of the datetime') - day = _field_accessor('day', 'The days of the datetime') - hour = _field_accessor('hour', 'The hours of the datetime') - minute = _field_accessor('minute', 'The minutes of the datetime') - second = _field_accessor('second', 'The seconds of the datetime') - microsecond = _field_accessor('microsecond', - 'The microseconds of the datetime') - dayofyear = _field_accessor('dayofyr', - 'The ordinal day of year of the datetime', - '1.0.2.1') - dayofweek = _field_accessor('dayofwk', 'The day of week of the datetime', - '1.0.2.1') + + year = _field_accessor("year", "The year of the datetime") + month = _field_accessor("month", "The month of the datetime") + day = _field_accessor("day", "The days of the datetime") + hour = _field_accessor("hour", "The hours of the datetime") + minute = _field_accessor("minute", "The minutes of the datetime") + second = _field_accessor("second", "The seconds of the datetime") + microsecond = _field_accessor("microsecond", "The microseconds of the datetime") + dayofyear = _field_accessor( + "dayofyr", "The ordinal day of year of the datetime", "1.0.2.1" + ) + dayofweek = _field_accessor("dayofwk", "The day of week of the datetime", "1.0.2.1") date_type = property(get_date_type) def __new__(cls, data, name=None): assert_all_valid_date_type(data) - if name is None and hasattr(data, 'name'): + if name is None and hasattr(data, "name"): name = data.name result = object.__new__(cls) - result._data = np.array(data, dtype='O') + result._data = np.array(data, dtype="O") result.name = name return result @@ -280,20 +294,21 @@ def _partial_date_slice(self, resolution, parsed): Coordinates: * time (time) datetime64[ns] 2001-01-01T01:00:00 """ - start, end = _parsed_string_to_bounds(self.date_type, resolution, - parsed) + start, end = _parsed_string_to_bounds(self.date_type, resolution, parsed) times = self._data if self.is_monotonic: - if (len(times) and ((start < times[0] and end < times[0]) or - (start > times[-1] and end > times[-1]))): + if len(times) and ( + (start < times[0] and end < times[0]) + or (start > times[-1] and end > times[-1]) + ): # we are out of range raise KeyError # a monotonic (sorted) series can be sliced - left = times.searchsorted(start, side='left') - right = times.searchsorted(end, side='right') + left = times.searchsorted(start, side="left") + right = times.searchsorted(end, side="right") return slice(left, right) lhs_mask = times >= start @@ -314,20 +329,17 @@ def get_loc(self, key, method=None, tolerance=None): if isinstance(key, str): return self._get_string_slice(key) else: - return pd.Index.get_loc(self, key, method=method, - tolerance=tolerance) + return pd.Index.get_loc(self, key, method=method, tolerance=tolerance) def _maybe_cast_slice_bound(self, label, side, kind): """Adapted from pandas.tseries.index.DatetimeIndex._maybe_cast_slice_bound""" if isinstance(label, str): - parsed, resolution = _parse_iso8601_with_reso(self.date_type, - label) - start, end = _parsed_string_to_bounds(self.date_type, resolution, - parsed) + parsed, resolution = _parse_iso8601_with_reso(self.date_type, label) + start, end = _parsed_string_to_bounds(self.date_type, resolution, parsed) if self.is_monotonic_decreasing and len(self) > 1: - return end if side == 'left' else start - return start if side == 'left' 
else end + return end if side == "left" else start + return start if side == "left" else end else: return label @@ -338,8 +350,7 @@ def get_value(self, series, key): if np.asarray(key).dtype == np.dtype(bool): return series.iloc[key] elif isinstance(key, slice): - return series.iloc[self.slice_indexer( - key.start, key.stop, key.step)] + return series.iloc[self.slice_indexer(key.start, key.stop, key.step)] else: return series.iloc[self.get_loc(key)] @@ -348,8 +359,11 @@ def __contains__(self, key): pandas.tseries.base.DatetimeIndexOpsMixin.__contains__""" try: result = self.get_loc(key) - return (is_scalar(result) or type(result) == slice or - (isinstance(result, np.ndarray) and result.size)) + return ( + is_scalar(result) + or type(result) == slice + or (isinstance(result, np.ndarray) and result.size) + ) except (KeyError, TypeError, ValueError): return False @@ -397,7 +411,8 @@ def shift(self, n, freq): else: raise TypeError( "'freq' must be of type " - "str or datetime.timedelta, got {}.".format(freq)) + "str or datetime.timedelta, got {}.".format(freq) + ) def __add__(self, other): if isinstance(other, pd.TimedeltaIndex): @@ -411,6 +426,7 @@ def __radd__(self, other): def __sub__(self, other): import cftime + if isinstance(other, (CFTimeIndex, cftime.datetime)): return pd.TimedeltaIndex(np.array(self) - np.array(other)) elif isinstance(other, pd.TimedeltaIndex): @@ -469,11 +485,14 @@ def to_datetimeindex(self, unsafe=False): calendar = infer_calendar_name(self) if calendar not in _STANDARD_CALENDARS and not unsafe: warnings.warn( - 'Converting a CFTimeIndex with dates from a non-standard ' - 'calendar, {!r}, to a pandas.DatetimeIndex, which uses dates ' - 'from the standard calendar. This may lead to subtle errors ' - 'in operations that depend on the length of time between ' - 'dates.'.format(calendar), RuntimeWarning, stacklevel=2) + "Converting a CFTimeIndex with dates from a non-standard " + "calendar, {!r}, to a pandas.DatetimeIndex, which uses dates " + "from the standard calendar. 
This may lead to subtle errors " + "in operations that depend on the length of time between " + "dates.".format(calendar), + RuntimeWarning, + stacklevel=2, + ) return pd.DatetimeIndex(nptimes) def strftime(self, date_format): @@ -528,5 +547,6 @@ def _parse_array_of_cftime_strings(strings, date_type): ------- np.array """ - return np.array([_parse_iso8601_without_reso(date_type, s) - for s in strings.ravel()]).reshape(strings.shape) + return np.array( + [_parse_iso8601_without_reso(date_type, s) for s in strings.ravel()] + ).reshape(strings.shape) diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py index 007bcb8a502..44d07929e35 100644 --- a/xarray/coding/strings.py +++ b/xarray/coding/strings.py @@ -7,28 +7,33 @@ from ..core.pycompat import dask_array_type from ..core.variable import Variable from .variables import ( - VariableCoder, lazy_elemwise_func, pop_to, safe_setitem, - unpack_for_decoding, unpack_for_encoding) + VariableCoder, + lazy_elemwise_func, + pop_to, + safe_setitem, + unpack_for_decoding, + unpack_for_encoding, +) def create_vlen_dtype(element_type): # based on h5py.special_dtype - return np.dtype('O', metadata={'element_type': element_type}) + return np.dtype("O", metadata={"element_type": element_type}) def check_vlen_dtype(dtype): - if dtype.kind != 'O' or dtype.metadata is None: + if dtype.kind != "O" or dtype.metadata is None: return None else: - return dtype.metadata.get('element_type') + return dtype.metadata.get("element_type") def is_unicode_dtype(dtype): - return dtype.kind == 'U' or check_vlen_dtype(dtype) == str + return dtype.kind == "U" or check_vlen_dtype(dtype) == str def is_bytes_dtype(dtype): - return dtype.kind == 'S' or check_vlen_dtype(dtype) == bytes + return dtype.kind == "S" or check_vlen_dtype(dtype) == bytes class EncodedStringCoder(VariableCoder): @@ -41,21 +46,21 @@ def encode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_encoding(variable) contains_unicode = is_unicode_dtype(data.dtype) - encode_as_char = encoding.get('dtype') == 'S1' + encode_as_char = encoding.get("dtype") == "S1" if encode_as_char: - del encoding['dtype'] # no longer relevant + del encoding["dtype"] # no longer relevant if contains_unicode and (encode_as_char or not self.allows_unicode): - if '_FillValue' in attrs: + if "_FillValue" in attrs: raise NotImplementedError( - 'variable {!r} has a _FillValue specified, but ' - '_FillValue is not yet supported on unicode strings: ' - 'https://github.com/pydata/xarray/issues/1647' - .format(name)) + "variable {!r} has a _FillValue specified, but " + "_FillValue is not yet supported on unicode strings: " + "https://github.com/pydata/xarray/issues/1647".format(name) + ) - string_encoding = encoding.pop('_Encoding', 'utf-8') - safe_setitem(attrs, '_Encoding', string_encoding, name=name) + string_encoding = encoding.pop("_Encoding", "utf-8") + safe_setitem(attrs, "_Encoding", string_encoding, name=name) # TODO: figure out how to handle this in a lazy way with dask data = encode_string_array(data, string_encoding) @@ -64,22 +69,22 @@ def encode(self, variable, name=None): def decode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_decoding(variable) - if '_Encoding' in attrs: - string_encoding = pop_to(attrs, encoding, '_Encoding') + if "_Encoding" in attrs: + string_encoding = pop_to(attrs, encoding, "_Encoding") func = partial(decode_bytes_array, encoding=string_encoding) data = lazy_elemwise_func(data, func, np.dtype(object)) return Variable(dims, data, attrs, 
encoding) -def decode_bytes_array(bytes_array, encoding='utf-8'): +def decode_bytes_array(bytes_array, encoding="utf-8"): # This is faster than using np.char.decode() or np.vectorize() bytes_array = np.asarray(bytes_array) decoded = [x.decode(encoding) for x in bytes_array.ravel()] return np.array(decoded, dtype=object).reshape(bytes_array.shape) -def encode_string_array(string_array, encoding='utf-8'): +def encode_string_array(string_array, encoding="utf-8"): string_array = np.asarray(string_array) encoded = [x.encode(encoding) for x in string_array.ravel()] return np.array(encoded, dtype=bytes).reshape(string_array.shape) @@ -101,20 +106,20 @@ def encode(self, variable, name=None): variable = ensure_fixed_length_bytes(variable) dims, data, attrs, encoding = unpack_for_encoding(variable) - if data.dtype.kind == 'S' and encoding.get('dtype') is not str: + if data.dtype.kind == "S" and encoding.get("dtype") is not str: data = bytes_to_char(data) - if 'char_dim_name' in encoding.keys(): - char_dim_name = encoding.pop('char_dim_name') + if "char_dim_name" in encoding.keys(): + char_dim_name = encoding.pop("char_dim_name") else: - char_dim_name = 'string%s' % data.shape[-1] + char_dim_name = "string%s" % data.shape[-1] dims = dims + (char_dim_name,) return Variable(dims, data, attrs, encoding) def decode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_decoding(variable) - if data.dtype == 'S1' and dims: - encoding['char_dim_name'] = dims[-1] + if data.dtype == "S1" and dims: + encoding["char_dim_name"] = dims[-1] dims = dims[:-1] data = char_to_bytes(data) return Variable(dims, data, attrs, encoding) @@ -122,15 +127,19 @@ def decode(self, variable, name=None): def bytes_to_char(arr): """Convert numpy/dask arrays from fixed width bytes to characters.""" - if arr.dtype.kind != 'S': - raise ValueError('argument must have a fixed-width bytes dtype') + if arr.dtype.kind != "S": + raise ValueError("argument must have a fixed-width bytes dtype") if isinstance(arr, dask_array_type): import dask.array as da - return da.map_blocks(_numpy_bytes_to_char, arr, - dtype='S1', - chunks=arr.chunks + ((arr.dtype.itemsize,)), - new_axis=[arr.ndim]) + + return da.map_blocks( + _numpy_bytes_to_char, + arr, + dtype="S1", + chunks=arr.chunks + ((arr.dtype.itemsize,)), + new_axis=[arr.ndim], + ) else: return _numpy_bytes_to_char(arr) @@ -139,13 +148,13 @@ def _numpy_bytes_to_char(arr): """Like netCDF4.stringtochar, but faster and more flexible. 
""" # ensure the array is contiguous - arr = np.array(arr, copy=False, order='C', dtype=np.string_) - return arr.reshape(arr.shape + (1,)).view('S1') + arr = np.array(arr, copy=False, order="C", dtype=np.string_) + return arr.reshape(arr.shape + (1,)).view("S1") def char_to_bytes(arr): """Convert numpy/dask arrays from characters to fixed width bytes.""" - if arr.dtype != 'S1': + if arr.dtype != "S1": raise ValueError("argument must have dtype='S1'") if not arr.ndim: @@ -162,15 +171,19 @@ def char_to_bytes(arr): import dask.array as da if len(arr.chunks[-1]) > 1: - raise ValueError('cannot stacked dask character array with ' - 'multiple chunks in the last dimension: {}' - .format(arr)) - - dtype = np.dtype('S' + str(arr.shape[-1])) - return da.map_blocks(_numpy_char_to_bytes, arr, - dtype=dtype, - chunks=arr.chunks[:-1], - drop_axis=[arr.ndim - 1]) + raise ValueError( + "cannot stacked dask character array with " + "multiple chunks in the last dimension: {}".format(arr) + ) + + dtype = np.dtype("S" + str(arr.shape[-1])) + return da.map_blocks( + _numpy_char_to_bytes, + arr, + dtype=dtype, + chunks=arr.chunks[:-1], + drop_axis=[arr.ndim - 1], + ) else: return StackedBytesArray(arr) @@ -179,8 +192,8 @@ def _numpy_char_to_bytes(arr): """Like netCDF4.chartostring, but faster and more flexible. """ # based on: http://stackoverflow.com/a/10984878/809705 - arr = np.array(arr, copy=False, order='C') - dtype = 'S' + str(arr.shape[-1]) + arr = np.array(arr, copy=False, order="C") + dtype = "S" + str(arr.shape[-1]) return arr.view(dtype).reshape(arr.shape[:-1]) @@ -200,25 +213,26 @@ def __init__(self, array): array : array-like Original array of values to wrap. """ - if array.dtype != 'S1': + if array.dtype != "S1": raise ValueError( - "can only use StackedBytesArray if argument has dtype='S1'") + "can only use StackedBytesArray if argument has dtype='S1'" + ) self.array = indexing.as_indexable(array) @property def dtype(self): - return np.dtype('S' + str(self.array.shape[-1])) + return np.dtype("S" + str(self.array.shape[-1])) @property def shape(self): return self.array.shape[:-1] def __repr__(self): - return ('%s(%r)' % (type(self).__name__, self.array)) + return "%s(%r)" % (type(self).__name__, self.array) def __getitem__(self, key): # require slicing the last dimension completely key = type(key)(indexing.expanded_indexer(key.tuple, self.array.ndim)) if key.tuple[-1] != slice(None): - raise IndexError('too many indices') + raise IndexError("too many indices") return _numpy_char_to_bytes(self.array[key]) diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 4930a77d022..7b5a7c56a53 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -12,8 +12,14 @@ from ..core.formatting import first_n_items, format_timestamp, last_item from ..core.variable import Variable from .variables import ( - SerializationWarning, VariableCoder, lazy_elemwise_func, pop_to, - safe_setitem, unpack_for_decoding, unpack_for_encoding) + SerializationWarning, + VariableCoder, + lazy_elemwise_func, + pop_to, + safe_setitem, + unpack_for_decoding, + unpack_for_encoding, +) try: from pandas.errors import OutOfBoundsDatetime @@ -23,24 +29,27 @@ # standard calendars recognized by cftime -_STANDARD_CALENDARS = {'standard', 'gregorian', 'proleptic_gregorian'} +_STANDARD_CALENDARS = {"standard", "gregorian", "proleptic_gregorian"} -_NS_PER_TIME_DELTA = {'us': int(1e3), - 'ms': int(1e6), - 's': int(1e9), - 'm': int(1e9) * 60, - 'h': int(1e9) * 60 * 60, - 'D': int(1e9) * 60 * 60 * 24} +_NS_PER_TIME_DELTA = { + 
"us": int(1e3), + "ms": int(1e6), + "s": int(1e9), + "m": int(1e9) * 60, + "h": int(1e9) * 60 * 60, + "D": int(1e9) * 60 * 60 * 24, +} -TIME_UNITS = frozenset(['days', 'hours', 'minutes', 'seconds', - 'milliseconds', 'microseconds']) +TIME_UNITS = frozenset( + ["days", "hours", "minutes", "seconds", "milliseconds", "microseconds"] +) def _import_cftime(): - ''' + """ helper function handle the transition to netcdftime/cftime as a stand-alone package - ''' + """ try: import cftime except ImportError: @@ -57,26 +66,34 @@ def _require_standalone_cftime(): try: import cftime # noqa: F401 except ImportError: - raise ImportError('Decoding times with non-standard calendars ' - 'or outside the pandas.Timestamp-valid range ' - 'requires the standalone cftime package.') + raise ImportError( + "Decoding times with non-standard calendars " + "or outside the pandas.Timestamp-valid range " + "requires the standalone cftime package." + ) def _netcdf_to_numpy_timeunit(units): units = units.lower() - if not units.endswith('s'): - units = '%ss' % units - return {'microseconds': 'us', 'milliseconds': 'ms', 'seconds': 's', - 'minutes': 'm', 'hours': 'h', 'days': 'D'}[units] + if not units.endswith("s"): + units = "%ss" % units + return { + "microseconds": "us", + "milliseconds": "ms", + "seconds": "s", + "minutes": "m", + "hours": "h", + "days": "D", + }[units] def _unpack_netcdf_time_units(units): # CF datetime units follow the format: "UNIT since DATE" # this parses out the unit and date allowing for extraneous # whitespace. - matches = re.match('(.+) since (.+)', units) + matches = re.match("(.+) since (.+)", units) if not matches: - raise ValueError('invalid time units: %s' % units) + raise ValueError("invalid time units: %s" % units) delta_units, ref_date = [s.strip() for s in matches.groups()] return delta_units, ref_date @@ -85,23 +102,24 @@ def _decode_cf_datetime_dtype(data, units, calendar, use_cftime): # Verify that at least the first and last date can be decoded # successfully. Otherwise, tracebacks end up swallowed by # Dataset.__repr__ when users try to view their lazily decoded array. - values = indexing.ImplicitToExplicitIndexingAdapter( - indexing.as_indexable(data)) - example_value = np.concatenate([first_n_items(values, 1) or [0], - last_item(values) or [0]]) + values = indexing.ImplicitToExplicitIndexingAdapter(indexing.as_indexable(data)) + example_value = np.concatenate( + [first_n_items(values, 1) or [0], last_item(values) or [0]] + ) try: - result = decode_cf_datetime(example_value, units, calendar, - use_cftime) + result = decode_cf_datetime(example_value, units, calendar, use_cftime) except Exception: - calendar_msg = ('the default calendar' if calendar is None - else 'calendar %r' % calendar) - msg = ('unable to decode time units %r with %s. Try ' - 'opening your dataset with decode_times=False.' - % (units, calendar_msg)) + calendar_msg = ( + "the default calendar" if calendar is None else "calendar %r" % calendar + ) + msg = ( + "unable to decode time units %r with %s. Try " + "opening your dataset with decode_times=False." 
% (units, calendar_msg) + ) raise ValueError(msg) else: - dtype = getattr(result, 'dtype', np.dtype('object')) + dtype = getattr(result, "dtype", np.dtype("object")) return dtype @@ -109,9 +127,10 @@ def _decode_cf_datetime_dtype(data, units, calendar, use_cftime): def _decode_datetime_with_cftime(num_dates, units, calendar): cftime = _import_cftime() - if cftime.__name__ == 'cftime': - return np.asarray(cftime.num2date(num_dates, units, calendar, - only_use_cftime_datetimes=True)) + if cftime.__name__ == "cftime": + return np.asarray( + cftime.num2date(num_dates, units, calendar, only_use_cftime_datetimes=True) + ) else: # Must be using num2date from an old version of netCDF4 which # does not have the only_use_cftime_datetimes option. @@ -121,8 +140,9 @@ def _decode_datetime_with_cftime(num_dates, units, calendar): def _decode_datetime_with_pandas(flat_num_dates, units, calendar): if calendar not in _STANDARD_CALENDARS: raise OutOfBoundsDatetime( - 'Cannot decode times from a non-standard calendar, {!r}, using ' - 'pandas.'.format(calendar)) + "Cannot decode times from a non-standard calendar, {!r}, using " + "pandas.".format(calendar) + ) delta, ref_date = _unpack_netcdf_time_units(units) delta = _netcdf_to_numpy_timeunit(delta) @@ -137,18 +157,18 @@ def _decode_datetime_with_pandas(flat_num_dates, units, calendar): # these lines check if the the lowest or the highest value in dates # cause an OutOfBoundsDatetime (Overflow) error with warnings.catch_warnings(): - warnings.filterwarnings('ignore', 'invalid value encountered', - RuntimeWarning) + warnings.filterwarnings("ignore", "invalid value encountered", RuntimeWarning) pd.to_timedelta(flat_num_dates.min(), delta) + ref_date pd.to_timedelta(flat_num_dates.max(), delta) + ref_date # Cast input dates to integers of nanoseconds because `pd.to_datetime` # works much faster when dealing with integers # make _NS_PER_TIME_DELTA an array to ensure type upcasting - flat_num_dates_ns_int = (flat_num_dates.astype(np.float64) * - _NS_PER_TIME_DELTA[delta]).astype(np.int64) + flat_num_dates_ns_int = ( + flat_num_dates.astype(np.float64) * _NS_PER_TIME_DELTA[delta] + ).astype(np.int64) - return (pd.to_timedelta(flat_num_dates_ns_int, 'ns') + ref_date).values + return (pd.to_timedelta(flat_num_dates_ns_int, "ns") + ref_date).values def decode_cf_datetime(num_dates, units, calendar=None, use_cftime=None): @@ -169,30 +189,36 @@ def decode_cf_datetime(num_dates, units, calendar=None, use_cftime=None): num_dates = np.asarray(num_dates) flat_num_dates = num_dates.ravel() if calendar is None: - calendar = 'standard' + calendar = "standard" if use_cftime is None: try: - dates = _decode_datetime_with_pandas(flat_num_dates, units, - calendar) + dates = _decode_datetime_with_pandas(flat_num_dates, units, calendar) except (OutOfBoundsDatetime, OverflowError): dates = _decode_datetime_with_cftime( - flat_num_dates.astype(np.float), units, calendar) + flat_num_dates.astype(np.float), units, calendar + ) - if (dates[np.nanargmin(num_dates)].year < 1678 or - dates[np.nanargmax(num_dates)].year >= 2262): + if ( + dates[np.nanargmin(num_dates)].year < 1678 + or dates[np.nanargmax(num_dates)].year >= 2262 + ): if calendar in _STANDARD_CALENDARS: warnings.warn( - 'Unable to decode time axis into full ' - 'numpy.datetime64 objects, continuing using ' - 'cftime.datetime objects instead, reason: dates out ' - 'of range', SerializationWarning, stacklevel=3) + "Unable to decode time axis into full " + "numpy.datetime64 objects, continuing using " + "cftime.datetime objects 
instead, reason: dates out " + "of range", + SerializationWarning, + stacklevel=3, + ) else: if calendar in _STANDARD_CALENDARS: dates = cftime_to_nptime(dates) elif use_cftime: dates = _decode_datetime_with_cftime( - flat_num_dates.astype(np.float), units, calendar) + flat_num_dates.astype(np.float), units, calendar + ) else: dates = _decode_datetime_with_pandas(flat_num_dates, units, calendar) @@ -200,20 +226,20 @@ def decode_cf_datetime(num_dates, units, calendar=None, use_cftime=None): def to_timedelta_unboxed(value, **kwargs): - if LooseVersion(pd.__version__) < '0.25.0': + if LooseVersion(pd.__version__) < "0.25.0": result = pd.to_timedelta(value, **kwargs, box=False) else: result = pd.to_timedelta(value, **kwargs).to_numpy() - assert result.dtype == 'timedelta64[ns]' + assert result.dtype == "timedelta64[ns]" return result def to_datetime_unboxed(value, **kwargs): - if LooseVersion(pd.__version__) < '0.25.0': + if LooseVersion(pd.__version__) < "0.25.0": result = pd.to_datetime(value, **kwargs, box=False) else: result = pd.to_datetime(value, **kwargs).to_numpy() - assert result.dtype == 'datetime64[ns]' + assert result.dtype == "datetime64[ns]" return result @@ -228,19 +254,19 @@ def decode_cf_timedelta(num_timedeltas, units): def _infer_time_units_from_diff(unique_timedeltas): - for time_unit in ['days', 'hours', 'minutes', 'seconds']: + for time_unit in ["days", "hours", "minutes", "seconds"]: delta_ns = _NS_PER_TIME_DELTA[_netcdf_to_numpy_timeunit(time_unit)] - unit_delta = np.timedelta64(delta_ns, 'ns') + unit_delta = np.timedelta64(delta_ns, "ns") diffs = unique_timedeltas / unit_delta if np.all(diffs == diffs.astype(int)): return time_unit - return 'seconds' + return "seconds" def infer_calendar_name(dates): """Given an array of datetimes, infer the CF calendar name""" - if np.asarray(dates).dtype == 'datetime64[ns]': - return 'proleptic_gregorian' + if np.asarray(dates).dtype == "datetime64[ns]": + return "proleptic_gregorian" else: return np.asarray(dates).ravel()[0].calendar @@ -252,30 +278,36 @@ def infer_datetime_units(dates): unique time deltas in `dates`) """ dates = np.asarray(dates).ravel() - if np.asarray(dates).dtype == 'datetime64[ns]': + if np.asarray(dates).dtype == "datetime64[ns]": dates = to_datetime_unboxed(dates) dates = dates[pd.notnull(dates)] - reference_date = dates[0] if len(dates) > 0 else '1970-01-01' + reference_date = dates[0] if len(dates) > 0 else "1970-01-01" reference_date = pd.Timestamp(reference_date) else: - reference_date = dates[0] if len(dates) > 0 else '1970-01-01' + reference_date = dates[0] if len(dates) > 0 else "1970-01-01" reference_date = format_cftime_datetime(reference_date) unique_timedeltas = np.unique(np.diff(dates)) - if unique_timedeltas.dtype == np.dtype('O'): + if unique_timedeltas.dtype == np.dtype("O"): # Convert to np.timedelta64 objects using pandas to work around a # NumPy casting bug: https://github.com/numpy/numpy/issues/11096 unique_timedeltas = to_timedelta_unboxed(unique_timedeltas) units = _infer_time_units_from_diff(unique_timedeltas) - return '%s since %s' % (units, reference_date) + return "%s since %s" % (units, reference_date) def format_cftime_datetime(date): """Converts a cftime.datetime object to a string with the format: YYYY-MM-DD HH:MM:SS.UUUUUU """ - return '{:04d}-{:02d}-{:02d} {:02d}:{:02d}:{:02d}.{:06d}'.format( - date.year, date.month, date.day, date.hour, date.minute, date.second, - date.microsecond) + return "{:04d}-{:02d}-{:02d} {:02d}:{:02d}:{:02d}.{:06d}".format( + date.year, + 
date.month, + date.day, + date.hour, + date.minute, + date.second, + date.microsecond, + ) def infer_timedelta_units(deltas): @@ -293,18 +325,21 @@ def cftime_to_nptime(times): """Given an array of cftime.datetime objects, return an array of numpy.datetime64 objects of the same size""" times = np.asarray(times) - new = np.empty(times.shape, dtype='M8[ns]') + new = np.empty(times.shape, dtype="M8[ns]") for i, t in np.ndenumerate(times): try: # Use pandas.Timestamp in place of datetime.datetime, because # NumPy casts it safely it np.datetime64[ns] for dates outside # 1678 to 2262 (this is not currently the case for # datetime.datetime). - dt = pd.Timestamp(t.year, t.month, t.day, t.hour, t.minute, - t.second, t.microsecond) + dt = pd.Timestamp( + t.year, t.month, t.day, t.hour, t.minute, t.second, t.microsecond + ) except ValueError as e: - raise ValueError('Cannot convert date {} to a date in the ' - 'standard calendar. Reason: {}.'.format(t, e)) + raise ValueError( + "Cannot convert date {} to a date in the " + "standard calendar. Reason: {}.".format(t, e) + ) new[i] = np.datetime64(dt) return new @@ -312,7 +347,7 @@ def cftime_to_nptime(times): def _cleanup_netcdf_time_units(units): delta, ref_date = _unpack_netcdf_time_units(units) try: - units = '%s since %s' % (delta, format_timestamp(ref_date)) + units = "%s since %s" % (delta, format_timestamp(ref_date)) except OutOfBoundsDatetime: # don't worry about reifying the units if they're out of bounds pass @@ -329,7 +364,7 @@ def _encode_datetime_with_cftime(dates, units, calendar): if np.issubdtype(dates.dtype, np.datetime64): # numpy's broken datetime conversion only works for us precision - dates = dates.astype('M8[us]').astype(datetime) + dates = dates.astype("M8[us]").astype(datetime) def encode_datetime(d): return np.nan if d is None else cftime.date2num(d, units, calendar) @@ -366,13 +401,13 @@ def encode_cf_datetime(dates, units=None, calendar=None): delta, ref_date = _unpack_netcdf_time_units(units) try: - if calendar not in _STANDARD_CALENDARS or dates.dtype.kind == 'O': + if calendar not in _STANDARD_CALENDARS or dates.dtype.kind == "O": # parse with cftime instead raise OutOfBoundsDatetime - assert dates.dtype == 'datetime64[ns]' + assert dates.dtype == "datetime64[ns]" delta_units = _netcdf_to_numpy_timeunit(delta) - time_delta = np.timedelta64(1, delta_units).astype('timedelta64[ns]') + time_delta = np.timedelta64(1, delta_units).astype("timedelta64[ns]") ref_date = pd.Timestamp(ref_date) # If the ref_date Timestamp is timezone-aware, convert to UTC and @@ -410,52 +445,52 @@ def __init__(self, use_cftime=None): def encode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_encoding(variable) - if (np.issubdtype(data.dtype, np.datetime64) or - contains_cftime_datetimes(variable)): + if np.issubdtype(data.dtype, np.datetime64) or contains_cftime_datetimes( + variable + ): (data, units, calendar) = encode_cf_datetime( - data, - encoding.pop('units', None), - encoding.pop('calendar', None)) - safe_setitem(attrs, 'units', units, name=name) - safe_setitem(attrs, 'calendar', calendar, name=name) + data, encoding.pop("units", None), encoding.pop("calendar", None) + ) + safe_setitem(attrs, "units", units, name=name) + safe_setitem(attrs, "calendar", calendar, name=name) return Variable(dims, data, attrs, encoding) def decode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_decoding(variable) - if 'units' in attrs and 'since' in attrs['units']: - units = pop_to(attrs, encoding, 'units') - calendar = 
pop_to(attrs, encoding, 'calendar') - dtype = _decode_cf_datetime_dtype(data, units, calendar, - self.use_cftime) + if "units" in attrs and "since" in attrs["units"]: + units = pop_to(attrs, encoding, "units") + calendar = pop_to(attrs, encoding, "calendar") + dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime) transform = partial( - decode_cf_datetime, units=units, calendar=calendar, - use_cftime=self.use_cftime) + decode_cf_datetime, + units=units, + calendar=calendar, + use_cftime=self.use_cftime, + ) data = lazy_elemwise_func(data, transform, dtype) return Variable(dims, data, attrs, encoding) class CFTimedeltaCoder(VariableCoder): - def encode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_encoding(variable) if np.issubdtype(data.dtype, np.timedelta64): - data, units = encode_cf_timedelta( - data, encoding.pop('units', None)) - safe_setitem(attrs, 'units', units, name=name) + data, units = encode_cf_timedelta(data, encoding.pop("units", None)) + safe_setitem(attrs, "units", units, name=name) return Variable(dims, data, attrs, encoding) def decode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_decoding(variable) - if 'units' in attrs and attrs['units'] in TIME_UNITS: - units = pop_to(attrs, encoding, 'units') + if "units" in attrs and attrs["units"] in TIME_UNITS: + units = pop_to(attrs, encoding, "units") transform = partial(decode_cf_timedelta, units=units) - dtype = np.dtype('timedelta64[ns]') + dtype = np.dtype("timedelta64[ns]") data = lazy_elemwise_func(data, transform, dtype=dtype) return Variable(dims, data, attrs, encoding) diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index cc173f78b92..f54ae7867d8 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -69,8 +69,12 @@ def __array__(self, dtype=None): return self.func(self.array) def __repr__(self): - return ("%s(%r, func=%r, dtype=%r)" % - (type(self).__name__, self.array, self.func, self.dtype)) + return "%s(%r, func=%r, dtype=%r)" % ( + type(self).__name__, + self.array, + self.func, + self.dtype, + ) def lazy_elemwise_func(array, func, dtype): @@ -105,12 +109,13 @@ def unpack_for_decoding(var): def safe_setitem(dest, key, value, name=None): if key in dest: - var_str = ' on variable {!r}'.format(name) if name else '' + var_str = " on variable {!r}".format(name) if name else "" raise ValueError( - 'failed to prevent overwriting existing key {} in attrs{}. ' - 'This is probably an encoding field used by xarray to describe ' - 'how a variable is serialized. To proceed, remove this key from ' - "the variable's attributes manually.".format(key, var_str)) + "failed to prevent overwriting existing key {} in attrs{}. " + "This is probably an encoding field used by xarray to describe " + "how a variable is serialized. 
To proceed, remove this key from " + "the variable's attributes manually.".format(key, var_str) + ) dest[key] = value @@ -127,10 +132,7 @@ def pop_to(source, dest, key, name=None): def _apply_mask( - data: np.ndarray, - encoded_fill_values: list, - decoded_fill_value: Any, - dtype: Any, + data: np.ndarray, encoded_fill_values: list, decoded_fill_value: Any, dtype: Any ) -> np.ndarray: """Mask all matching values in a NumPy arrays.""" data = np.asarray(data, dtype=dtype) @@ -146,21 +148,22 @@ class CFMaskCoder(VariableCoder): def encode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_encoding(variable) - fv = encoding.get('_FillValue') - mv = encoding.get('missing_value') + fv = encoding.get("_FillValue") + mv = encoding.get("missing_value") if fv is not None and mv is not None and not equivalent(fv, mv): - raise ValueError("Variable {!r} has multiple fill values {}. " - "Cannot encode data. " - .format(name, [fv, mv])) + raise ValueError( + "Variable {!r} has multiple fill values {}. " + "Cannot encode data. ".format(name, [fv, mv]) + ) if fv is not None: - fill_value = pop_to(encoding, attrs, '_FillValue', name=name) + fill_value = pop_to(encoding, attrs, "_FillValue", name=name) if not pd.isnull(fill_value): data = duck_array_ops.fillna(data, fill_value) if mv is not None: - fill_value = pop_to(encoding, attrs, 'missing_value', name=name) + fill_value = pop_to(encoding, attrs, "missing_value", name=name) if not pd.isnull(fill_value) and fv is None: data = duck_array_ops.fillna(data, fill_value) @@ -169,26 +172,35 @@ def encode(self, variable, name=None): def decode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_decoding(variable) - raw_fill_values = [pop_to(attrs, encoding, attr, name=name) - for attr in ('missing_value', '_FillValue')] + raw_fill_values = [ + pop_to(attrs, encoding, attr, name=name) + for attr in ("missing_value", "_FillValue") + ] if raw_fill_values: - encoded_fill_values = {fv for option in raw_fill_values - for fv in np.ravel(option) - if not pd.isnull(fv)} + encoded_fill_values = { + fv + for option in raw_fill_values + for fv in np.ravel(option) + if not pd.isnull(fv) + } if len(encoded_fill_values) > 1: - warnings.warn("variable {!r} has multiple fill values {}, " - "decoding all values to NaN." 
- .format(name, encoded_fill_values), - SerializationWarning, stacklevel=3) + warnings.warn( + "variable {!r} has multiple fill values {}, " + "decoding all values to NaN.".format(name, encoded_fill_values), + SerializationWarning, + stacklevel=3, + ) dtype, decoded_fill_value = dtypes.maybe_promote(data.dtype) if encoded_fill_values: - transform = partial(_apply_mask, - encoded_fill_values=encoded_fill_values, - decoded_fill_value=decoded_fill_value, - dtype=dtype) + transform = partial( + _apply_mask, + encoded_fill_values=encoded_fill_values, + decoded_fill_value=decoded_fill_value, + dtype=dtype, + ) data = lazy_elemwise_func(data, transform, dtype) return Variable(dims, data, attrs, encoding) @@ -232,34 +244,35 @@ class CFScaleOffsetCoder(VariableCoder): def encode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_encoding(variable) - if 'scale_factor' in encoding or 'add_offset' in encoding: - dtype = _choose_float_dtype(data.dtype, 'add_offset' in encoding) + if "scale_factor" in encoding or "add_offset" in encoding: + dtype = _choose_float_dtype(data.dtype, "add_offset" in encoding) data = data.astype(dtype=dtype, copy=True) - if 'add_offset' in encoding: - data -= pop_to(encoding, attrs, 'add_offset', name=name) - if 'scale_factor' in encoding: - data /= pop_to(encoding, attrs, 'scale_factor', name=name) + if "add_offset" in encoding: + data -= pop_to(encoding, attrs, "add_offset", name=name) + if "scale_factor" in encoding: + data /= pop_to(encoding, attrs, "scale_factor", name=name) return Variable(dims, data, attrs, encoding) def decode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_decoding(variable) - if 'scale_factor' in attrs or 'add_offset' in attrs: - scale_factor = pop_to(attrs, encoding, 'scale_factor', name=name) - add_offset = pop_to(attrs, encoding, 'add_offset', name=name) - dtype = _choose_float_dtype(data.dtype, 'add_offset' in attrs) - transform = partial(_scale_offset_decoding, - scale_factor=scale_factor, - add_offset=add_offset, - dtype=dtype) + if "scale_factor" in attrs or "add_offset" in attrs: + scale_factor = pop_to(attrs, encoding, "scale_factor", name=name) + add_offset = pop_to(attrs, encoding, "add_offset", name=name) + dtype = _choose_float_dtype(data.dtype, "add_offset" in attrs) + transform = partial( + _scale_offset_decoding, + scale_factor=scale_factor, + add_offset=add_offset, + dtype=dtype, + ) data = lazy_elemwise_func(data, transform, dtype) return Variable(dims, data, attrs, encoding) class UnsignedIntegerCoder(VariableCoder): - def encode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_encoding(variable) @@ -267,12 +280,12 @@ def encode(self, variable, name=None): # https://www.unidata.ucar.edu/software/netcdf/docs/BestPractices.html # "_Unsigned = "true" to indicate that # integer data should be treated as unsigned" - if encoding.get('_Unsigned', 'false') == 'true': - pop_to(encoding, attrs, '_Unsigned') - signed_dtype = np.dtype('i%s' % data.dtype.itemsize) - if '_FillValue' in attrs: - new_fill = signed_dtype.type(attrs['_FillValue']) - attrs['_FillValue'] = new_fill + if encoding.get("_Unsigned", "false") == "true": + pop_to(encoding, attrs, "_Unsigned") + signed_dtype = np.dtype("i%s" % data.dtype.itemsize) + if "_FillValue" in attrs: + new_fill = signed_dtype.type(attrs["_FillValue"]) + attrs["_FillValue"] = new_fill data = duck_array_ops.around(data).astype(signed_dtype) return Variable(dims, data, attrs, encoding) @@ -280,20 +293,23 @@ def encode(self, variable, 
name=None): def decode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_decoding(variable) - if '_Unsigned' in attrs: - unsigned = pop_to(attrs, encoding, '_Unsigned') + if "_Unsigned" in attrs: + unsigned = pop_to(attrs, encoding, "_Unsigned") - if data.dtype.kind == 'i': - if unsigned == 'true': - unsigned_dtype = np.dtype('u%s' % data.dtype.itemsize) + if data.dtype.kind == "i": + if unsigned == "true": + unsigned_dtype = np.dtype("u%s" % data.dtype.itemsize) transform = partial(np.asarray, dtype=unsigned_dtype) data = lazy_elemwise_func(data, transform, unsigned_dtype) - if '_FillValue' in attrs: - new_fill = unsigned_dtype.type(attrs['_FillValue']) - attrs['_FillValue'] = new_fill + if "_FillValue" in attrs: + new_fill = unsigned_dtype.type(attrs["_FillValue"]) + attrs["_FillValue"] = new_fill else: - warnings.warn("variable %r has _Unsigned attribute but is not " - "of integer type. Ignoring attribute." % name, - SerializationWarning, stacklevel=3) + warnings.warn( + "variable %r has _Unsigned attribute but is not " + "of integer type. Ignoring attribute." % name, + SerializationWarning, + stacklevel=3, + ) return Variable(dims, data, attrs, encoding) diff --git a/xarray/conventions.py b/xarray/conventions.py index 616e557efcd..c15e5c40e73 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -65,7 +65,7 @@ def __init__(self, array): @property def dtype(self): - return np.dtype('bool') + return np.dtype("bool") def __getitem__(self, key): return np.asarray(self.array[key], dtype=self.dtype) @@ -76,19 +76,23 @@ def _var_as_tuple(var): def maybe_encode_nonstring_dtype(var, name=None): - if ('dtype' in var.encoding and - var.encoding['dtype'] not in ('S1', str)): + if "dtype" in var.encoding and var.encoding["dtype"] not in ("S1", str): dims, data, attrs, encoding = _var_as_tuple(var) - dtype = np.dtype(encoding.pop('dtype')) + dtype = np.dtype(encoding.pop("dtype")) if dtype != var.dtype: if np.issubdtype(dtype, np.integer): - if (np.issubdtype(var.dtype, np.floating) and - '_FillValue' not in var.attrs and - 'missing_value' not in var.attrs): - warnings.warn('saving variable %s with floating ' - 'point data as an integer dtype without ' - 'any _FillValue to use for NaNs' % name, - SerializationWarning, stacklevel=10) + if ( + np.issubdtype(var.dtype, np.floating) + and "_FillValue" not in var.attrs + and "missing_value" not in var.attrs + ): + warnings.warn( + "saving variable %s with floating " + "point data as an integer dtype without " + "any _FillValue to use for NaNs" % name, + SerializationWarning, + stacklevel=10, + ) data = duck_array_ops.around(data)[...] 
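The CFMaskCoder and CFScaleOffsetCoder reformatted above are the coders that xarray.decode_cf ultimately applies. A minimal sketch of that decode path, assuming an in-memory Dataset stands in for on-disk data; the variable name and packing values are illustrative:

    import numpy as np
    import xarray as xr

    # Packed, on-disk style data: int16 values plus CF packing attributes
    # (illustrative values only).
    raw = xr.Dataset(
        {"t": ("x", np.array([0, 10, -9999], dtype="int16"),
               {"scale_factor": 0.5, "add_offset": 20.0, "_FillValue": -9999})}
    )
    decoded = xr.decode_cf(raw)
    # Fill values become NaN first, then scale_factor/add_offset are applied.
    print(decoded["t"].values)  # [20.0, 25.0, nan]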
data = data.astype(dtype=dtype) var = Variable(dims, data, attrs, encoding) @@ -97,19 +101,24 @@ def maybe_encode_nonstring_dtype(var, name=None): def maybe_default_fill_value(var): # make NaN the fill value for float types: - if ('_FillValue' not in var.attrs and - '_FillValue' not in var.encoding and - np.issubdtype(var.dtype, np.floating)): - var.attrs['_FillValue'] = var.dtype.type(np.nan) + if ( + "_FillValue" not in var.attrs + and "_FillValue" not in var.encoding + and np.issubdtype(var.dtype, np.floating) + ): + var.attrs["_FillValue"] = var.dtype.type(np.nan) return var def maybe_encode_bools(var): - if ((var.dtype == np.bool) and - ('dtype' not in var.encoding) and ('dtype' not in var.attrs)): + if ( + (var.dtype == np.bool) + and ("dtype" not in var.encoding) + and ("dtype" not in var.attrs) + ): dims, data, attrs, encoding = _var_as_tuple(var) - attrs['dtype'] = 'bool' - data = data.astype(dtype='i1', copy=True) + attrs["dtype"] = "bool" + data = data.astype(dtype="i1", copy=True) var = Variable(dims, data, attrs, encoding) return var @@ -118,8 +127,8 @@ def _infer_dtype(array, name=None): """Given an object array with no missing values, infer its dtype from its first element """ - if array.dtype.kind != 'O': - raise TypeError('infer_type must be called on a dtype=object array') + if array.dtype.kind != "O": + raise TypeError("infer_type must be called on a dtype=object array") if array.size == 0: return np.dtype(float) @@ -129,23 +138,24 @@ def _infer_dtype(array, name=None): return strings.create_vlen_dtype(type(element)) dtype = np.array(element).dtype - if dtype.kind != 'O': + if dtype.kind != "O": return dtype - raise ValueError('unable to infer dtype on variable {!r}; xarray ' - 'cannot serialize arbitrary Python objects' - .format(name)) + raise ValueError( + "unable to infer dtype on variable {!r}; xarray " + "cannot serialize arbitrary Python objects".format(name) + ) def ensure_not_multiindex(var, name=None): - if (isinstance(var, IndexVariable) and - isinstance(var.to_index(), pd.MultiIndex)): + if isinstance(var, IndexVariable) and isinstance(var.to_index(), pd.MultiIndex): raise NotImplementedError( - 'variable {!r} is a MultiIndex, which cannot yet be ' - 'serialized to netCDF files ' - '(https://github.com/pydata/xarray/issues/1077). Use ' - 'reset_index() to convert MultiIndex levels into coordinate ' - 'variables instead.'.format(name)) + "variable {!r} is a MultiIndex, which cannot yet be " + "serialized to netCDF files " + "(https://github.com/pydata/xarray/issues/1077). Use " + "reset_index() to convert MultiIndex levels into coordinate " + "variables instead.".format(name) + ) def _copy_with_dtype(data, dtype): @@ -161,17 +171,18 @@ def _copy_with_dtype(data, dtype): def ensure_dtype_not_object(var, name=None): # TODO: move this from conventions to backends? (it's not CF related) - if var.dtype.kind == 'O': + if var.dtype.kind == "O": dims, data, attrs, encoding = _var_as_tuple(var) if isinstance(data, dask_array_type): warnings.warn( - 'variable {} has data in the form of a dask array with ' - 'dtype=object, which means it is being loaded into memory ' - 'to determine a data type that can be safely stored on disk. ' - 'To avoid this, coerce this variable to a fixed-size dtype ' - 'with astype() before saving it.'.format(name), - SerializationWarning) + "variable {} has data in the form of a dask array with " + "dtype=object, which means it is being loaded into memory " + "to determine a data type that can be safely stored on disk. 
" + "To avoid this, coerce this variable to a fixed-size dtype " + "with astype() before saving it.".format(name), + SerializationWarning, + ) data = data.compute() missing = pd.isnull(data) @@ -184,9 +195,9 @@ def ensure_dtype_not_object(var, name=None): # formats, we so can't set a fill_value. Unfortunately, this means # we can't distinguish between missing values and empty strings. if strings.is_bytes_dtype(inferred_dtype): - fill_value = b'' + fill_value = b"" elif strings.is_unicode_dtype(inferred_dtype): - fill_value = '' + fill_value = "" else: # insist on using float for numeric values if not np.issubdtype(inferred_dtype, np.floating): @@ -198,7 +209,7 @@ def ensure_dtype_not_object(var, name=None): else: data = _copy_with_dtype(data, dtype=_infer_dtype(data, name)) - assert data.dtype.kind != 'O' or data.dtype.metadata + assert data.dtype.kind != "O" or data.dtype.metadata var = Variable(dims, data, attrs, encoding) return var @@ -225,11 +236,13 @@ def encode_cf_variable(var, needs_copy=True, name=None): """ ensure_not_multiindex(var, name=name) - for coder in [times.CFDatetimeCoder(), - times.CFTimedeltaCoder(), - variables.CFScaleOffsetCoder(), - variables.CFMaskCoder(), - variables.UnsignedIntegerCoder()]: + for coder in [ + times.CFDatetimeCoder(), + times.CFTimedeltaCoder(), + variables.CFScaleOffsetCoder(), + variables.CFMaskCoder(), + variables.UnsignedIntegerCoder(), + ]: var = coder.encode(var, name=name) # TODO(shoyer): convert all of these to use coders, too: @@ -240,9 +253,16 @@ def encode_cf_variable(var, needs_copy=True, name=None): return var -def decode_cf_variable(name, var, concat_characters=True, mask_and_scale=True, - decode_times=True, decode_endianness=True, - stack_char_dim=True, use_cftime=None): +def decode_cf_variable( + name, + var, + concat_characters=True, + mask_and_scale=True, + decode_times=True, + decode_endianness=True, + stack_char_dim=True, + use_cftime=None, +): """ Decodes a variable which may hold CF encoded information. 
@@ -297,18 +317,21 @@ def decode_cf_variable(name, var, concat_characters=True, mask_and_scale=True, var = strings.EncodedStringCoder().decode(var) if mask_and_scale: - for coder in [variables.UnsignedIntegerCoder(), - variables.CFMaskCoder(), - variables.CFScaleOffsetCoder()]: + for coder in [ + variables.UnsignedIntegerCoder(), + variables.CFMaskCoder(), + variables.CFScaleOffsetCoder(), + ]: var = coder.decode(var, name=name) if decode_times: - for coder in [times.CFTimedeltaCoder(), - times.CFDatetimeCoder(use_cftime=use_cftime)]: + for coder in [ + times.CFTimedeltaCoder(), + times.CFDatetimeCoder(use_cftime=use_cftime), + ]: var = coder.decode(var, name=name) - dimensions, data, attributes, encoding = ( - variables.unpack_for_decoding(var)) + dimensions, data, attributes, encoding = variables.unpack_for_decoding(var) # TODO(shoyer): convert everything below to use coders if decode_endianness and not data.dtype.isnative: @@ -316,10 +339,10 @@ def decode_cf_variable(name, var, concat_characters=True, mask_and_scale=True, data = NativeEndiannessArray(data) original_dtype = data.dtype - encoding.setdefault('dtype', original_dtype) + encoding.setdefault("dtype", original_dtype) - if 'dtype' in attributes and attributes['dtype'] == 'bool': - del attributes['dtype'] + if "dtype" in attributes and attributes["dtype"] == "bool": + del attributes["dtype"] data = BoolTypeArray(data) if not isinstance(data, dask_array_type): @@ -347,13 +370,13 @@ def _update_bounds_attributes(variables): # For all time variables with bounds for v in variables.values(): attrs = v.attrs - has_date_units = 'units' in attrs and 'since' in attrs['units'] - if has_date_units and 'bounds' in attrs: - if attrs['bounds'] in variables: - bounds_attrs = variables[attrs['bounds']].attrs - bounds_attrs.setdefault('units', attrs['units']) - if 'calendar' in attrs: - bounds_attrs.setdefault('calendar', attrs['calendar']) + has_date_units = "units" in attrs and "since" in attrs["units"] + if has_date_units and "bounds" in attrs: + if attrs["bounds"] in variables: + bounds_attrs = variables[attrs["bounds"]].attrs + bounds_attrs.setdefault("units", attrs["units"]) + if "calendar" in attrs: + bounds_attrs.setdefault("calendar", attrs["calendar"]) def _update_bounds_encoding(variables): @@ -376,35 +399,46 @@ def _update_bounds_encoding(variables): for v in variables.values(): attrs = v.attrs encoding = v.encoding - has_date_units = 'units' in encoding and 'since' in encoding['units'] - is_datetime_type = (np.issubdtype(v.dtype, np.datetime64) or - contains_cftime_datetimes(v)) - - if (is_datetime_type and not has_date_units and - 'bounds' in attrs and attrs['bounds'] in variables): - warnings.warn("Variable '{0}' has datetime type and a " - "bounds variable but {0}.encoding does not have " - "units specified. The units encodings for '{0}' " - "and '{1}' will be determined independently " - "and may not be equal, counter to CF-conventions. " - "If this is a concern, specify a units encoding for " - "'{0}' before writing to a file." 
- .format(v.name, attrs['bounds']), - UserWarning) - - if has_date_units and 'bounds' in attrs: - if attrs['bounds'] in variables: - bounds_encoding = variables[attrs['bounds']].encoding - bounds_encoding.setdefault('units', encoding['units']) - if 'calendar' in encoding: - bounds_encoding.setdefault('calendar', - encoding['calendar']) - - -def decode_cf_variables(variables, attributes, concat_characters=True, - mask_and_scale=True, decode_times=True, - decode_coords=True, drop_variables=None, - use_cftime=None): + has_date_units = "units" in encoding and "since" in encoding["units"] + is_datetime_type = np.issubdtype( + v.dtype, np.datetime64 + ) or contains_cftime_datetimes(v) + + if ( + is_datetime_type + and not has_date_units + and "bounds" in attrs + and attrs["bounds"] in variables + ): + warnings.warn( + "Variable '{0}' has datetime type and a " + "bounds variable but {0}.encoding does not have " + "units specified. The units encodings for '{0}' " + "and '{1}' will be determined independently " + "and may not be equal, counter to CF-conventions. " + "If this is a concern, specify a units encoding for " + "'{0}' before writing to a file.".format(v.name, attrs["bounds"]), + UserWarning, + ) + + if has_date_units and "bounds" in attrs: + if attrs["bounds"] in variables: + bounds_encoding = variables[attrs["bounds"]].encoding + bounds_encoding.setdefault("units", encoding["units"]) + if "calendar" in encoding: + bounds_encoding.setdefault("calendar", encoding["calendar"]) + + +def decode_cf_variables( + variables, + attributes, + concat_characters=True, + mask_and_scale=True, + decode_times=True, + decode_coords=True, + drop_variables=None, + use_cftime=None, +): """ Decode several CF encoded variables. @@ -420,7 +454,7 @@ def stackable(dim): if dim in variables: return False for v in dimensions_used_by[dim]: - if v.dtype.kind != 'S' or dim != v.dims[-1]: + if v.dtype.kind != "S" or dim != v.dims[-1]: return False return True @@ -440,32 +474,47 @@ def stackable(dim): for k, v in variables.items(): if k in drop_variables: continue - stack_char_dim = (concat_characters and v.dtype == 'S1' and - v.ndim > 0 and stackable(v.dims[-1])) + stack_char_dim = ( + concat_characters + and v.dtype == "S1" + and v.ndim > 0 + and stackable(v.dims[-1]) + ) new_vars[k] = decode_cf_variable( - k, v, concat_characters=concat_characters, - mask_and_scale=mask_and_scale, decode_times=decode_times, - stack_char_dim=stack_char_dim, use_cftime=use_cftime) + k, + v, + concat_characters=concat_characters, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + stack_char_dim=stack_char_dim, + use_cftime=use_cftime, + ) if decode_coords: var_attrs = new_vars[k].attrs - if 'coordinates' in var_attrs: - coord_str = var_attrs['coordinates'] + if "coordinates" in var_attrs: + coord_str = var_attrs["coordinates"] var_coord_names = coord_str.split() if all(k in variables for k in var_coord_names): - new_vars[k].encoding['coordinates'] = coord_str - del var_attrs['coordinates'] + new_vars[k].encoding["coordinates"] = coord_str + del var_attrs["coordinates"] coord_names.update(var_coord_names) - if decode_coords and 'coordinates' in attributes: + if decode_coords and "coordinates" in attributes: attributes = OrderedDict(attributes) - coord_names.update(attributes.pop('coordinates').split()) + coord_names.update(attributes.pop("coordinates").split()) return new_vars, attributes, coord_names -def decode_cf(obj, concat_characters=True, mask_and_scale=True, - decode_times=True, decode_coords=True, 
drop_variables=None, - use_cftime=None): +def decode_cf( + obj, + concat_characters=True, + mask_and_scale=True, + decode_times=True, + decode_coords=True, + drop_variables=None, + use_cftime=None, +): """Decode the given Dataset or Datastore according to CF conventions into a new Dataset. @@ -519,11 +568,18 @@ def decode_cf(obj, concat_characters=True, mask_and_scale=True, file_obj = obj encoding = obj.get_encoding() else: - raise TypeError('can only decode Dataset or DataStore objects') + raise TypeError("can only decode Dataset or DataStore objects") vars, attrs, coord_names = decode_cf_variables( - vars, attrs, concat_characters, mask_and_scale, decode_times, - decode_coords, drop_variables=drop_variables, use_cftime=use_cftime) + vars, + attrs, + concat_characters, + mask_and_scale, + decode_times, + decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + ) ds = Dataset(vars, attrs=attrs) ds = ds.set_coords(coord_names.union(extra_coords).intersection(vars)) ds._file_obj = file_obj @@ -532,9 +588,13 @@ def decode_cf(obj, concat_characters=True, mask_and_scale=True, return ds -def cf_decoder(variables, attributes, - concat_characters=True, mask_and_scale=True, - decode_times=True): +def cf_decoder( + variables, + attributes, + concat_characters=True, + mask_and_scale=True, + decode_times=True, +): """ Decode a set of CF encoded variables and attributes. @@ -565,7 +625,8 @@ def cf_decoder(variables, attributes, decode_cf_variable """ variables, attributes, _ = decode_cf_variables( - variables, attributes, concat_characters, mask_and_scale, decode_times) + variables, attributes, concat_characters, mask_and_scale, decode_times + ) return variables, attributes @@ -574,12 +635,14 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names): non_dim_coord_names = set(non_dim_coord_names) for name in list(non_dim_coord_names): - if isinstance(name, str) and ' ' in name: + if isinstance(name, str) and " " in name: warnings.warn( - 'coordinate {!r} has a space in its name, which means it ' - 'cannot be marked as a coordinate on disk and will be ' - 'saved as a data variable instead'.format(name), - SerializationWarning, stacklevel=6) + "coordinate {!r} has a space in its name, which means it " + "cannot be marked as a coordinate on disk and will be " + "saved as a data variable instead".format(name), + SerializationWarning, + stacklevel=6, + ) non_dim_coord_names.discard(name) global_coordinates = non_dim_coord_names.copy() @@ -587,22 +650,25 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names): for coord_name in non_dim_coord_names: target_dims = variables[coord_name].dims for k, v in variables.items(): - if (k not in non_dim_coord_names and k not in v.dims and - set(target_dims) <= set(v.dims)): + if ( + k not in non_dim_coord_names + and k not in v.dims + and set(target_dims) <= set(v.dims) + ): variable_coordinates[k].add(coord_name) global_coordinates.discard(coord_name) - variables = OrderedDict((k, v.copy(deep=False)) - for k, v in variables.items()) + variables = OrderedDict((k, v.copy(deep=False)) for k, v in variables.items()) # These coordinates are saved according to CF conventions for var_name, coord_names in variable_coordinates.items(): attrs = variables[var_name].attrs - if 'coordinates' in attrs: - raise ValueError('cannot serialize coordinates because variable ' - "%s already has an attribute 'coordinates'" - % var_name) - attrs['coordinates'] = ' '.join(map(str, coord_names)) + if "coordinates" in attrs: + raise ValueError( + 
"cannot serialize coordinates because variable " + "%s already has an attribute 'coordinates'" % var_name + ) + attrs["coordinates"] = " ".join(map(str, coord_names)) # These coordinates are not associated with any particular variables, so we # save them under a global 'coordinates' attribute so xarray can roundtrip @@ -612,10 +678,12 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names): # http://mailman.cgd.ucar.edu/pipermail/cf-metadata/2014/057771.html if global_coordinates: attributes = OrderedDict(attributes) - if 'coordinates' in attributes: - raise ValueError('cannot serialize coordinates because the global ' - "attribute 'coordinates' already exists") - attributes['coordinates'] = ' '.join(map(str, global_coordinates)) + if "coordinates" in attributes: + raise ValueError( + "cannot serialize coordinates because the global " + "attribute 'coordinates' already exists" + ) + attributes["coordinates"] = " ".join(map(str, global_coordinates)) return variables, attributes @@ -637,8 +705,9 @@ def encode_dataset_coordinates(dataset): attrs : dict """ non_dim_coord_names = set(dataset.coords) - set(dataset.dims) - return _encode_coordinates(dataset._variables, dataset.attrs, - non_dim_coord_names=non_dim_coord_names) + return _encode_coordinates( + dataset._variables, dataset.attrs, non_dim_coord_names=non_dim_coord_names + ) def cf_encoder(variables, attributes): @@ -672,17 +741,26 @@ def cf_encoder(variables, attributes): # add encoding for time bounds variables if present. _update_bounds_encoding(variables) - new_vars = OrderedDict((k, encode_cf_variable(v, name=k)) - for k, v in variables.items()) + new_vars = OrderedDict( + (k, encode_cf_variable(v, name=k)) for k, v in variables.items() + ) # Remove attrs from bounds variables (issue #2921) for var in new_vars.values(): - bounds = var.attrs['bounds'] if 'bounds' in var.attrs else None + bounds = var.attrs["bounds"] if "bounds" in var.attrs else None if bounds and bounds in new_vars: # see http://cfconventions.org/cf-conventions/cf-conventions.html#cell-boundaries # noqa - for attr in ['units', 'standard_name', 'axis', 'positive', - 'calendar', 'long_name', 'leap_month', 'leap_year', - 'month_lengths']: + for attr in [ + "units", + "standard_name", + "axis", + "positive", + "calendar", + "long_name", + "leap_month", + "leap_year", + "month_lengths", + ]: if attr in new_vars[bounds].attrs and attr in var.attrs: if new_vars[bounds].attrs[attr] == var.attrs[attr]: new_vars[bounds].attrs.pop(attr) diff --git a/xarray/convert.py b/xarray/convert.py index 83055631bb5..90d259d1af3 100644 --- a/xarray/convert.py +++ b/xarray/convert.py @@ -11,16 +11,43 @@ from .core.dataarray import DataArray from .core.dtypes import get_fill_value -cdms2_ignored_attrs = {'name', 'tileIndex'} -iris_forbidden_keys = {'standard_name', 'long_name', 'units', 'bounds', 'axis', - 'calendar', 'leap_month', 'leap_year', 'month_lengths', - 'coordinates', 'grid_mapping', 'climatology', - 'cell_methods', 'formula_terms', 'compress', - 'missing_value', 'add_offset', 'scale_factor', - 'valid_max', 'valid_min', 'valid_range', '_FillValue'} -cell_methods_strings = {'point', 'sum', 'maximum', 'median', 'mid_range', - 'minimum', 'mean', 'mode', 'standard_deviation', - 'variance'} +cdms2_ignored_attrs = {"name", "tileIndex"} +iris_forbidden_keys = { + "standard_name", + "long_name", + "units", + "bounds", + "axis", + "calendar", + "leap_month", + "leap_year", + "month_lengths", + "coordinates", + "grid_mapping", + "climatology", + "cell_methods", + 
"formula_terms", + "compress", + "missing_value", + "add_offset", + "scale_factor", + "valid_max", + "valid_min", + "valid_range", + "_FillValue", +} +cell_methods_strings = { + "point", + "sum", + "maximum", + "median", + "mid_range", + "minimum", + "mean", + "mode", + "standard_deviation", + "variance", +} def encode(var): @@ -42,20 +69,22 @@ def from_cdms2(variable): coords = {} for axis in variable.getAxisList(): coords[axis.id] = DataArray( - np.asarray(axis), dims=[axis.id], - attrs=_filter_attrs(axis.attributes, cdms2_ignored_attrs)) + np.asarray(axis), + dims=[axis.id], + attrs=_filter_attrs(axis.attributes, cdms2_ignored_attrs), + ) grid = variable.getGrid() if grid is not None: ids = [a.id for a in grid.getAxisList()] for axis in grid.getLongitude(), grid.getLatitude(): if axis.id not in variable.getAxisIds(): coords[axis.id] = DataArray( - np.asarray(axis[:]), dims=ids, - attrs=_filter_attrs(axis.attributes, - cdms2_ignored_attrs)) + np.asarray(axis[:]), + dims=ids, + attrs=_filter_attrs(axis.attributes, cdms2_ignored_attrs), + ) attrs = _filter_attrs(variable.attributes, cdms2_ignored_attrs) - dataarray = DataArray(values, dims=dims, coords=coords, name=name, - attrs=attrs) + dataarray = DataArray(values, dims=dims, coords=coords, name=name, attrs=attrs) return decode_cf(dataarray.to_dataset())[dataarray.name] @@ -79,8 +108,9 @@ def set_cdms2_attrs(var, attrs): # Data var = encode(dataarray) - cdms2_var = cdms2.createVariable(var.values, axes=axes, id=dataarray.name, - mask=pd.isnull(var.values), copy=copy) + cdms2_var = cdms2.createVariable( + var.values, axes=axes, id=dataarray.name, mask=pd.isnull(var.values), copy=copy + ) # Attributes set_cdms2_attrs(cdms2_var, var.attrs) @@ -93,22 +123,26 @@ def set_cdms2_attrs(var, attrs): coord_array = dataarray.coords[coord_name].to_cdms2() - cdms2_axis_cls = (cdms2.coord.TransientAxis2D - if coord_array.ndim else - cdms2.auxcoord.TransientAuxAxis1D) + cdms2_axis_cls = ( + cdms2.coord.TransientAxis2D + if coord_array.ndim + else cdms2.auxcoord.TransientAuxAxis1D + ) cdms2_axis = cdms2_axis_cls(coord_array) if cdms2_axis.isLongitude(): - cdms2_axes['lon'] = cdms2_axis + cdms2_axes["lon"] = cdms2_axis elif cdms2_axis.isLatitude(): - cdms2_axes['lat'] = cdms2_axis + cdms2_axes["lat"] = cdms2_axis - if 'lon' in cdms2_axes and 'lat' in cdms2_axes: - if len(cdms2_axes['lon'].shape) == 2: + if "lon" in cdms2_axes and "lat" in cdms2_axes: + if len(cdms2_axes["lon"].shape) == 2: cdms2_grid = cdms2.hgrid.TransientCurveGrid( - cdms2_axes['lat'], cdms2_axes['lon']) + cdms2_axes["lat"], cdms2_axes["lon"] + ) else: cdms2_grid = cdms2.gengrid.AbstractGenericGrid( - cdms2_axes['lat'], cdms2_axes['lon']) + cdms2_axes["lat"], cdms2_axes["lon"] + ) for axis in cdms2_grid.getAxisList(): cdms2_var.setAxis(cdms2_var.getAxisIds().index(axis.id), axis) cdms2_var.setGrid(cdms2_grid) @@ -127,11 +161,12 @@ def _get_iris_args(attrs): """ # iris.unit is deprecated in Iris v1.9 import cf_units - args = {'attributes': _filter_attrs(attrs, iris_forbidden_keys)} - args.update(_pick_attrs(attrs, ('standard_name', 'long_name',))) - unit_args = _pick_attrs(attrs, ('calendar',)) - if 'units' in attrs: - args['units'] = cf_units.Unit(attrs['units'], **unit_args) + + args = {"attributes": _filter_attrs(attrs, iris_forbidden_keys)} + args.update(_pick_attrs(attrs, ("standard_name", "long_name"))) + unit_args = _pick_attrs(attrs, ("calendar",)) + if "units" in attrs: + args["units"] = cf_units.Unit(attrs["units"], **unit_args) return args @@ -149,7 +184,7 @@ def 
to_iris(dataarray): for coord_name in dataarray.coords: coord = encode(dataarray.coords[coord_name]) coord_args = _get_iris_args(coord.attrs) - coord_args['var_name'] = coord_name + coord_args["var_name"] = coord_name axis = None if coord.dims: axis = dataarray.get_axis_num(coord.dims) @@ -165,12 +200,11 @@ def to_iris(dataarray): aux_coords.append((iris_coord, axis)) args = _get_iris_args(dataarray.attrs) - args['var_name'] = dataarray.name - args['dim_coords_and_dims'] = dim_coords - args['aux_coords_and_dims'] = aux_coords - if 'cell_methods' in dataarray.attrs: - args['cell_methods'] = \ - parse_cell_methods(dataarray.attrs['cell_methods']) + args["var_name"] = dataarray.name + args["dim_coords_and_dims"] = dim_coords + args["aux_coords_and_dims"] = aux_coords + if "cell_methods" in dataarray.attrs: + args["cell_methods"] = parse_cell_methods(dataarray.attrs["cell_methods"]) masked_data = duck_array_ops.masked_invalid(dataarray.data) cube = iris.cube.Cube(masked_data, **args) @@ -181,12 +215,11 @@ def to_iris(dataarray): def _iris_obj_to_attrs(obj): """ Return a dictionary of attrs when given a Iris object """ - attrs = {'standard_name': obj.standard_name, - 'long_name': obj.long_name} + attrs = {"standard_name": obj.standard_name, "long_name": obj.long_name} if obj.units.calendar: - attrs['calendar'] = obj.units.calendar - if obj.units.origin != '1' and not obj.units.is_unknown(): - attrs['units'] = obj.units.origin + attrs["calendar"] = obj.units.calendar + if obj.units.origin != "1" and not obj.units.is_unknown(): + attrs["units"] = obj.units.origin attrs.update(obj.attributes) return {k: v for k, v in attrs.items() if v is not None} @@ -196,26 +229,27 @@ def _iris_cell_methods_to_str(cell_methods_obj): """ cell_methods = [] for cell_method in cell_methods_obj: - names = ''.join(['{}: '.format(n) for n in cell_method.coord_names]) - intervals = ' '.join(['interval: {}'.format(interval) - for interval in cell_method.intervals]) - comments = ' '.join(['comment: {}'.format(comment) - for comment in cell_method.comments]) - extra = ' '.join([intervals, comments]).strip() + names = "".join(["{}: ".format(n) for n in cell_method.coord_names]) + intervals = " ".join( + ["interval: {}".format(interval) for interval in cell_method.intervals] + ) + comments = " ".join( + ["comment: {}".format(comment) for comment in cell_method.comments] + ) + extra = " ".join([intervals, comments]).strip() if extra: - extra = ' ({})'.format(extra) + extra = " ({})".format(extra) cell_methods.append(names + cell_method.method + extra) - return ' '.join(cell_methods) + return " ".join(cell_methods) -def _name(iris_obj, default='unknown'): +def _name(iris_obj, default="unknown"): """ Mimicks `iris_obj.name()` but with different name resolution order. Similar to iris_obj.name() method, but using iris_obj.var_name first to enable roundtripping. 
""" - return (iris_obj.var_name or iris_obj.standard_name or - iris_obj.long_name or default) + return iris_obj.var_name or iris_obj.standard_name or iris_obj.long_name or default def from_iris(cube): @@ -225,7 +259,7 @@ def from_iris(cube): from xarray.core.pycompat import dask_array_type name = _name(cube) - if name == 'unknown': + if name == "unknown": name = None dims = [] for i in range(cube.ndim): @@ -237,7 +271,7 @@ def from_iris(cube): if len(set(dims)) != len(dims): duplicates = [k for k, v in Counter(dims).items() if v > 1] - raise ValueError('Duplicate coordinate name {}.'.format(duplicates)) + raise ValueError("Duplicate coordinate name {}.".format(duplicates)) coords = OrderedDict() @@ -252,21 +286,23 @@ def from_iris(cube): array_attrs = _iris_obj_to_attrs(cube) cell_methods = _iris_cell_methods_to_str(cube.cell_methods) if cell_methods: - array_attrs['cell_methods'] = cell_methods + array_attrs["cell_methods"] = cell_methods # Deal with iris 1.* and 2.* - cube_data = cube.core_data() if hasattr(cube, 'core_data') else cube.data + cube_data = cube.core_data() if hasattr(cube, "core_data") else cube.data # Deal with dask and numpy masked arrays if isinstance(cube_data, dask_array_type): from dask.array import ma as dask_ma + filled_data = dask_ma.filled(cube_data, get_fill_value(cube.dtype)) elif isinstance(cube_data, np.ma.MaskedArray): filled_data = np.ma.filled(cube_data, get_fill_value(cube.dtype)) else: filled_data = cube_data - dataarray = DataArray(filled_data, coords=coords, name=name, - attrs=array_attrs, dims=dims) + dataarray = DataArray( + filled_data, coords=coords, name=name, attrs=array_attrs, dims=dims + ) decoded_ds = decode_cf(dataarray._to_temp_dataset()) return dataarray._from_temp_dataset(decoded_ds) diff --git a/xarray/core/accessor_dt.py b/xarray/core/accessor_dt.py index 01cddae188f..832eb88c5fa 100644 --- a/xarray/core/accessor_dt.py +++ b/xarray/core/accessor_dt.py @@ -9,7 +9,7 @@ def _season_from_months(months): """Compute season (DJF, MAM, JJA, SON) from month ordinal """ # TODO: Move "season" accessor upstream into pandas - seasons = np.array(['DJF', 'MAM', 'JJA', 'SON']) + seasons = np.array(["DJF", "MAM", "JJA", "SON"]) months = np.asarray(months) return seasons[(months // 3) % 4] @@ -19,8 +19,9 @@ def _access_through_cftimeindex(values, name): and access requested datetime component """ from ..coding.cftimeindex import CFTimeIndex + values_as_cftimeindex = CFTimeIndex(values.ravel()) - if name == 'season': + if name == "season": months = values_as_cftimeindex.month field_values = _season_from_months(months) else: @@ -67,8 +68,8 @@ def _get_date_field(values, name, dtype): if isinstance(values, dask_array_type): from dask.array import map_blocks - return map_blocks(access_method, - values, name, dtype=dtype) + + return map_blocks(access_method, values, name, dtype=dtype) else: return access_method(values, name) @@ -104,8 +105,8 @@ def _round_field(values, name, freq): """ if isinstance(values, dask_array_type): from dask.array import map_blocks - return map_blocks(_round_series, - values, name, freq=freq, dtype=np.datetime64) + + return map_blocks(_round_series, values, name, freq=freq, dtype=np.datetime64) else: return _round_series(values, name, freq) @@ -115,6 +116,7 @@ def _strftime_through_cftimeindex(values, date_format): and access requested datetime component """ from ..coding.cftimeindex import CFTimeIndex + values_as_cftimeindex = CFTimeIndex(values.ravel()) field_values = values_as_cftimeindex.strftime(date_format) @@ -137,6 +139,7 @@ 
def _strftime(values, date_format): access_method = _strftime_through_cftimeindex if isinstance(values, dask_array_type): from dask.array import map_blocks + return map_blocks(access_method, values, date_format) else: return access_method(values, date_format) @@ -167,10 +170,12 @@ class DatetimeAccessor: def __init__(self, obj): if not _contains_datetime_like_objects(obj): - raise TypeError("'dt' accessor only available for " - "DataArray with datetime64 timedelta64 dtype or " - "for arrays containing cftime datetime " - "objects.") + raise TypeError( + "'dt' accessor only available for " + "DataArray with datetime64 timedelta64 dtype or " + "for arrays containing cftime datetime " + "objects." + ) self._obj = obj def _tslib_field_accessor(name, docstring=None, dtype=None): @@ -179,56 +184,51 @@ def f(self, dtype=dtype): dtype = self._obj.dtype obj_type = type(self._obj) result = _get_date_field(self._obj.data, name, dtype) - return obj_type(result, name=name, - coords=self._obj.coords, dims=self._obj.dims) + return obj_type( + result, name=name, coords=self._obj.coords, dims=self._obj.dims + ) f.__name__ = name f.__doc__ = docstring return property(f) - year = _tslib_field_accessor('year', "The year of the datetime", np.int64) + year = _tslib_field_accessor("year", "The year of the datetime", np.int64) month = _tslib_field_accessor( - 'month', "The month as January=1, December=12", np.int64 - ) - day = _tslib_field_accessor('day', "The days of the datetime", np.int64) - hour = _tslib_field_accessor('hour', "The hours of the datetime", np.int64) - minute = _tslib_field_accessor( - 'minute', "The minutes of the datetime", np.int64 - ) - second = _tslib_field_accessor( - 'second', "The seconds of the datetime", np.int64 + "month", "The month as January=1, December=12", np.int64 ) + day = _tslib_field_accessor("day", "The days of the datetime", np.int64) + hour = _tslib_field_accessor("hour", "The hours of the datetime", np.int64) + minute = _tslib_field_accessor("minute", "The minutes of the datetime", np.int64) + second = _tslib_field_accessor("second", "The seconds of the datetime", np.int64) microsecond = _tslib_field_accessor( - 'microsecond', "The microseconds of the datetime", np.int64 + "microsecond", "The microseconds of the datetime", np.int64 ) nanosecond = _tslib_field_accessor( - 'nanosecond', "The nanoseconds of the datetime", np.int64 + "nanosecond", "The nanoseconds of the datetime", np.int64 ) weekofyear = _tslib_field_accessor( - 'weekofyear', "The week ordinal of the year", np.int64 + "weekofyear", "The week ordinal of the year", np.int64 ) week = weekofyear dayofweek = _tslib_field_accessor( - 'dayofweek', "The day of the week with Monday=0, Sunday=6", np.int64 + "dayofweek", "The day of the week with Monday=0, Sunday=6", np.int64 ) weekday = dayofweek weekday_name = _tslib_field_accessor( - 'weekday_name', "The name of day in a week (ex: Friday)", object + "weekday_name", "The name of day in a week (ex: Friday)", object ) dayofyear = _tslib_field_accessor( - 'dayofyear', "The ordinal day of the year", np.int64 + "dayofyear", "The ordinal day of the year", np.int64 ) - quarter = _tslib_field_accessor('quarter', "The quarter of the date") + quarter = _tslib_field_accessor("quarter", "The quarter of the date") days_in_month = _tslib_field_accessor( - 'days_in_month', "The number of days in the month", np.int64 + "days_in_month", "The number of days in the month", np.int64 ) daysinmonth = days_in_month - season = _tslib_field_accessor( - "season", "Season of the year (ex: 
DJF)", object - ) + season = _tslib_field_accessor("season", "Season of the year (ex: DJF)", object) time = _tslib_field_accessor( "time", "Timestamps corresponding to datetimes", object @@ -237,11 +237,10 @@ def f(self, dtype=dtype): def _tslib_round_accessor(self, name, freq): obj_type = type(self._obj) result = _round_field(self._obj.data, name, freq) - return obj_type(result, name=name, - coords=self._obj.coords, dims=self._obj.dims) + return obj_type(result, name=name, coords=self._obj.coords, dims=self._obj.dims) def floor(self, freq): - ''' + """ Round timestamps downward to specified frequency resolution. Parameters @@ -253,12 +252,12 @@ def floor(self, freq): ------- floor-ed timestamps : same type as values Array-like of datetime fields accessed for each element in values - ''' + """ return self._tslib_round_accessor("floor", freq) def ceil(self, freq): - ''' + """ Round timestamps upward to specified frequency resolution. Parameters @@ -270,11 +269,11 @@ def ceil(self, freq): ------- ceil-ed timestamps : same type as values Array-like of datetime fields accessed for each element in values - ''' + """ return self._tslib_round_accessor("ceil", freq) def round(self, freq): - ''' + """ Round timestamps to specified frequency resolution. Parameters @@ -286,7 +285,7 @@ def round(self, freq): ------- rounded timestamps : same type as values Array-like of datetime fields accessed for each element in values - ''' + """ return self._tslib_round_accessor("round", freq) def strftime(self, date_format): @@ -320,7 +319,5 @@ def strftime(self, date_format): result = _strftime(self._obj.data, date_format) return obj_type( - result, - name="strftime", - coords=self._obj.coords, - dims=self._obj.dims) + result, name="strftime", coords=self._obj.coords, dims=self._obj.dims + ) diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index 4a1983517eb..03a6d37b01e 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -46,11 +46,15 @@ from .computation import apply_ufunc _cpython_optimized_encoders = ( - "utf-8", "utf8", "latin-1", "latin1", "iso-8859-1", "mbcs", "ascii" -) -_cpython_optimized_decoders = _cpython_optimized_encoders + ( - "utf-16", "utf-32" + "utf-8", + "utf8", + "latin-1", + "latin1", + "iso-8859-1", + "mbcs", + "ascii", ) +_cpython_optimized_decoders = _cpython_optimized_encoders + ("utf-16", "utf-32") def _is_str_like(x): @@ -80,17 +84,16 @@ def _apply(self, f, dtype=None): dtype = self._obj.dtype g = np.vectorize(f, otypes=[dtype]) - return apply_ufunc( - g, self._obj, dask='parallelized', output_dtypes=[dtype]) + return apply_ufunc(g, self._obj, dask="parallelized", output_dtypes=[dtype]) def len(self): - ''' + """ Compute the length of each element in the array. Returns ------- lengths array : array of int - ''' + """ return self._apply(len, dtype=int) def __getitem__(self, key): @@ -100,7 +103,7 @@ def __getitem__(self, key): return self.get(key) def get(self, i): - ''' + """ Extract element from indexable in each element in the array. Parameters @@ -114,12 +117,12 @@ def get(self, i): Returns ------- items : array of objects - ''' + """ obj = slice(-1, None) if i == -1 else slice(i, i + 1) return self._apply(lambda x: x[obj]) def slice(self, start=None, stop=None, step=None): - ''' + """ Slice substrings from each element in the array. 
Parameters @@ -134,13 +137,13 @@ def slice(self, start=None, stop=None, step=None): Returns ------- sliced strings : same type as values - ''' + """ s = slice(start, stop, step) f = lambda x: x[s] return self._apply(f) - def slice_replace(self, start=None, stop=None, repl=''): - ''' + def slice_replace(self, start=None, stop=None, repl=""): + """ Replace a positional slice of a string with another value. Parameters @@ -160,7 +163,7 @@ def slice_replace(self, start=None, stop=None, repl=''): Returns ------- replaced : same type as values - ''' + """ repl = self._obj.dtype.type(repl) def f(x): @@ -168,7 +171,7 @@ def f(x): local_stop = start else: local_stop = stop - y = self._obj.dtype.type('') + y = self._obj.dtype.type("") if start is not None: y += x[:start] y += repl @@ -179,156 +182,156 @@ def f(x): return self._apply(f) def capitalize(self): - ''' + """ Convert strings in the array to be capitalized. Returns ------- capitalized : same type as values - ''' + """ return self._apply(lambda x: x.capitalize()) def lower(self): - ''' + """ Convert strings in the array to lowercase. Returns ------- lowerd : same type as values - ''' + """ return self._apply(lambda x: x.lower()) def swapcase(self): - ''' + """ Convert strings in the array to be swapcased. Returns ------- swapcased : same type as values - ''' + """ return self._apply(lambda x: x.swapcase()) def title(self): - ''' + """ Convert strings in the array to titlecase. Returns ------- titled : same type as values - ''' + """ return self._apply(lambda x: x.title()) def upper(self): - ''' + """ Convert strings in the array to uppercase. Returns ------- uppered : same type as values - ''' + """ return self._apply(lambda x: x.upper()) def isalnum(self): - ''' + """ Check whether all characters in each string are alphanumeric. Returns ------- isalnum : array of bool Array of boolean values with the same shape as the original array. - ''' + """ return self._apply(lambda x: x.isalnum(), dtype=bool) def isalpha(self): - ''' + """ Check whether all characters in each string are alphabetic. Returns ------- isalpha : array of bool Array of boolean values with the same shape as the original array. - ''' + """ return self._apply(lambda x: x.isalpha(), dtype=bool) def isdecimal(self): - ''' + """ Check whether all characters in each string are decimal. Returns ------- isdecimal : array of bool Array of boolean values with the same shape as the original array. - ''' + """ return self._apply(lambda x: x.isdecimal(), dtype=bool) def isdigit(self): - ''' + """ Check whether all characters in each string are digits. Returns ------- isdigit : array of bool Array of boolean values with the same shape as the original array. - ''' + """ return self._apply(lambda x: x.isdigit(), dtype=bool) def islower(self): - ''' + """ Check whether all characters in each string are lowercase. Returns ------- islower : array of bool Array of boolean values with the same shape as the original array. - ''' + """ return self._apply(lambda x: x.islower(), dtype=bool) def isnumeric(self): - ''' + """ Check whether all characters in each string are numeric. Returns ------- isnumeric : array of bool Array of boolean values with the same shape as the original array. - ''' + """ return self._apply(lambda x: x.isnumeric(), dtype=bool) def isspace(self): - ''' + """ Check whether all characters in each string are spaces. Returns ------- isspace : array of bool Array of boolean values with the same shape as the original array. 
- ''' + """ return self._apply(lambda x: x.isspace(), dtype=bool) def istitle(self): - ''' + """ Check whether all characters in each string are titlecase. Returns ------- istitle : array of bool Array of boolean values with the same shape as the original array. - ''' + """ return self._apply(lambda x: x.istitle(), dtype=bool) def isupper(self): - ''' + """ Check whether all characters in each string are uppercase. Returns ------- isupper : array of bool Array of boolean values with the same shape as the original array. - ''' + """ return self._apply(lambda x: x.isupper(), dtype=bool) def count(self, pat, flags=0): - ''' + """ Count occurrences of pattern in each string of the array. This function is used to count the number of times a particular regex @@ -346,14 +349,14 @@ def count(self, pat, flags=0): Returns ------- counts : array of int - ''' + """ pat = self._obj.dtype.type(pat) regex = re.compile(pat, flags=flags) f = lambda x: len(regex.findall(x)) return self._apply(f, dtype=int) def startswith(self, pat): - ''' + """ Test if the start of each string element matches a pattern. Parameters @@ -366,13 +369,13 @@ def startswith(self, pat): startswith : array of bool An array of booleans indicating whether the given pattern matches the start of each string element. - ''' + """ pat = self._obj.dtype.type(pat) f = lambda x: x.startswith(pat) return self._apply(f, dtype=bool) def endswith(self, pat): - ''' + """ Test if the end of each string element matches a pattern. Parameters @@ -385,13 +388,13 @@ def endswith(self, pat): endswith : array of bool A Series of booleans indicating whether the given pattern matches the end of each string element. - ''' + """ pat = self._obj.dtype.type(pat) f = lambda x: x.endswith(pat) return self._apply(f, dtype=bool) - def pad(self, width, side='left', fillchar=' '): - ''' + def pad(self, width, side="left", fillchar=" "): + """ Pad strings in the array up to width. Parameters @@ -408,25 +411,25 @@ def pad(self, width, side='left', fillchar=' '): ------- filled : same type as values Array with a minimum number of char in each element. - ''' + """ width = int(width) fillchar = self._obj.dtype.type(fillchar) if len(fillchar) != 1: - raise TypeError('fillchar must be a character, not str') + raise TypeError("fillchar must be a character, not str") - if side == 'left': + if side == "left": f = lambda s: s.rjust(width, fillchar) - elif side == 'right': + elif side == "right": f = lambda s: s.ljust(width, fillchar) - elif side == 'both': + elif side == "both": f = lambda s: s.center(width, fillchar) else: # pragma: no cover - raise ValueError('Invalid side') + raise ValueError("Invalid side") return self._apply(f) - def center(self, width, fillchar=' '): - ''' + def center(self, width, fillchar=" "): + """ Filling left and right side of strings in the array with an additional character. @@ -441,11 +444,11 @@ def center(self, width, fillchar=' '): Returns ------- filled : same type as values - ''' - return self.pad(width, side='both', fillchar=fillchar) + """ + return self.pad(width, side="both", fillchar=fillchar) - def ljust(self, width, fillchar=' '): - ''' + def ljust(self, width, fillchar=" "): + """ Filling right side of strings in the array with an additional character. 
@@ -460,11 +463,11 @@ def ljust(self, width, fillchar=' '): Returns ------- filled : same type as values - ''' - return self.pad(width, side='right', fillchar=fillchar) + """ + return self.pad(width, side="right", fillchar=fillchar) - def rjust(self, width, fillchar=' '): - ''' + def rjust(self, width, fillchar=" "): + """ Filling left side of strings in the array with an additional character. Parameters @@ -478,11 +481,11 @@ def rjust(self, width, fillchar=' '): Returns ------- filled : same type as values - ''' - return self.pad(width, side='left', fillchar=fillchar) + """ + return self.pad(width, side="left", fillchar=fillchar) def zfill(self, width): - ''' + """ Pad strings in the array by prepending '0' characters. Strings in the array are padded with '0' characters on the @@ -498,11 +501,11 @@ def zfill(self, width): Returns ------- filled : same type as values - ''' - return self.pad(width, side='left', fillchar='0') + """ + return self.pad(width, side="left", fillchar="0") def contains(self, pat, case=True, flags=0, regex=True): - ''' + """ Test if pattern or regex is contained within a string of the array. Return boolean array based on whether a given pattern or regex is @@ -526,7 +529,7 @@ def contains(self, pat, case=True, flags=0, regex=True): An array of boolean values indicating whether the given pattern is contained within the string of each element of the array. - ''' + """ pat = self._obj.dtype.type(pat) if regex: if not case: @@ -548,7 +551,7 @@ def contains(self, pat, case=True, flags=0, regex=True): return self._apply(f, dtype=bool) def match(self, pat, case=True, flags=0): - ''' + """ Determine if each string matches a regular expression. Parameters @@ -563,7 +566,7 @@ def match(self, pat, case=True, flags=0): Returns ------- matched : array of bool - ''' + """ if not case: flags |= re.IGNORECASE @@ -572,8 +575,8 @@ def match(self, pat, case=True, flags=0): f = lambda x: bool(regex.match(x)) return self._apply(f, dtype=bool) - def strip(self, to_strip=None, side='both'): - ''' + def strip(self, to_strip=None, side="both"): + """ Remove leading and trailing characters. Strip whitespaces (including newlines) or a set of specified characters @@ -591,23 +594,23 @@ def strip(self, to_strip=None, side='both'): Returns ------- stripped : same type as values - ''' + """ if to_strip is not None: to_strip = self._obj.dtype.type(to_strip) - if side == 'both': + if side == "both": f = lambda x: x.strip(to_strip) - elif side == 'left': + elif side == "left": f = lambda x: x.lstrip(to_strip) - elif side == 'right': + elif side == "right": f = lambda x: x.rstrip(to_strip) else: # pragma: no cover - raise ValueError('Invalid side') + raise ValueError("Invalid side") return self._apply(f) def lstrip(self, to_strip=None): - ''' + """ Remove leading and trailing characters. Strip whitespaces (including newlines) or a set of specified characters @@ -623,11 +626,11 @@ def lstrip(self, to_strip=None): Returns ------- stripped : same type as values - ''' - return self.strip(to_strip, side='left') + """ + return self.strip(to_strip, side="left") def rstrip(self, to_strip=None): - ''' + """ Remove leading and trailing characters. 
Strip whitespaces (including newlines) or a set of specified characters @@ -643,11 +646,11 @@ def rstrip(self, to_strip=None): Returns ------- stripped : same type as values - ''' - return self.strip(to_strip, side='right') + """ + return self.strip(to_strip, side="right") def wrap(self, width, **kwargs): - ''' + """ Wrap long strings in the array to be formatted in paragraphs with length less than a given width. @@ -682,13 +685,13 @@ def wrap(self, width, **kwargs): Returns ------- wrapped : same type as values - ''' + """ tw = textwrap.TextWrapper(width=width) - f = lambda x: '\n'.join(tw.wrap(x)) + f = lambda x: "\n".join(tw.wrap(x)) return self._apply(f) def translate(self, table): - ''' + """ Map all characters in the string through the given mapping table. Parameters @@ -702,12 +705,12 @@ def translate(self, table): Returns ------- translated : same type as values - ''' + """ f = lambda x: x.translate(table) return self._apply(f) def repeat(self, repeats): - ''' + """ Duplicate each string in the array. Parameters @@ -719,12 +722,12 @@ def repeat(self, repeats): ------- repeated : same type as values Array of repeated string objects. - ''' + """ f = lambda x: repeats * x return self._apply(f) - def find(self, sub, start=0, end=None, side='left'): - ''' + def find(self, sub, start=0, end=None, side="left"): + """ Return lowest or highest indexes in each strings in the array where the substring is fully contained between [start:end]. Return -1 on failure. @@ -743,15 +746,15 @@ def find(self, sub, start=0, end=None, side='left'): Returns ------- found : array of integer values - ''' + """ sub = self._obj.dtype.type(sub) - if side == 'left': - method = 'find' - elif side == 'right': - method = 'rfind' + if side == "left": + method = "find" + elif side == "right": + method = "rfind" else: # pragma: no cover - raise ValueError('Invalid side') + raise ValueError("Invalid side") if end is None: f = lambda x: getattr(x, method)(sub, start) @@ -761,7 +764,7 @@ def find(self, sub, start=0, end=None, side='left'): return self._apply(f, dtype=int) def rfind(self, sub, start=0, end=None): - ''' + """ Return highest indexes in each strings in the array where the substring is fully contained between [start:end]. Return -1 on failure. @@ -778,11 +781,11 @@ def rfind(self, sub, start=0, end=None): Returns ------- found : array of integer values - ''' - return self.find(sub, start=start, end=end, side='right') + """ + return self.find(sub, start=start, end=end, side="right") - def index(self, sub, start=0, end=None, side='left'): - ''' + def index(self, sub, start=0, end=None, side="left"): + """ Return lowest or highest indexes in each strings where the substring is fully contained between [start:end]. 
This is the same as ``str.find`` except instead of returning -1, it raises a ValueError @@ -802,15 +805,15 @@ def index(self, sub, start=0, end=None, side='left'): Returns ------- found : array of integer values - ''' + """ sub = self._obj.dtype.type(sub) - if side == 'left': - method = 'index' - elif side == 'right': - method = 'rindex' + if side == "left": + method = "index" + elif side == "right": + method = "rindex" else: # pragma: no cover - raise ValueError('Invalid side') + raise ValueError("Invalid side") if end is None: f = lambda x: getattr(x, method)(sub, start) @@ -820,7 +823,7 @@ def index(self, sub, start=0, end=None, side='left'): return self._apply(f, dtype=int) def rindex(self, sub, start=0, end=None): - ''' + """ Return highest indexes in each strings where the substring is fully contained between [start:end]. This is the same as ``str.rfind`` except instead of returning -1, it raises a ValueError @@ -838,11 +841,11 @@ def rindex(self, sub, start=0, end=None): Returns ------- found : array of integer values - ''' - return self.index(sub, start=start, end=end, side='right') + """ + return self.index(sub, start=start, end=end, side="right") def replace(self, pat, repl, n=-1, case=None, flags=0, regex=True): - ''' + """ Replace occurrences of pattern/regex in the array with some string. Parameters @@ -875,7 +878,7 @@ def replace(self, pat, repl, n=-1, case=None, flags=0, regex=True): replaced : same type as values A copy of the object with all matching occurrences of `pat` replaced by `repl`. - ''' + """ if not (_is_str_like(repl) or callable(repl)): # pragma: no cover raise TypeError("repl must be a string or callable") @@ -885,12 +888,13 @@ def replace(self, pat, repl, n=-1, case=None, flags=0, regex=True): if _is_str_like(repl): repl = self._obj.dtype.type(repl) - is_compiled_re = isinstance(pat, type(re.compile(''))) + is_compiled_re = isinstance(pat, type(re.compile(""))) if regex: if is_compiled_re: if (case is not None) or (flags != 0): - raise ValueError("case and flags cannot be set" - " when pat is a compiled regex") + raise ValueError( + "case and flags cannot be set" " when pat is a compiled regex" + ) else: # not a compiled regex # set default case @@ -908,16 +912,19 @@ def replace(self, pat, repl, n=-1, case=None, flags=0, regex=True): f = lambda x: x.replace(pat, repl, n) else: if is_compiled_re: - raise ValueError("Cannot use a compiled regex as replacement " - "pattern with regex=False") + raise ValueError( + "Cannot use a compiled regex as replacement " + "pattern with regex=False" + ) if callable(repl): - raise ValueError("Cannot use a callable replacement when " - "regex=False") + raise ValueError( + "Cannot use a callable replacement when " "regex=False" + ) f = lambda x: x.replace(pat, repl, n) return self._apply(f) - def decode(self, encoding, errors='strict'): - ''' + def decode(self, encoding, errors="strict"): + """ Decode character string in the array using indicated encoding. Parameters @@ -928,7 +935,7 @@ def decode(self, encoding, errors='strict'): Returns ------- decoded : same type as values - ''' + """ if encoding in _cpython_optimized_decoders: f = lambda x: x.decode(encoding, errors) else: @@ -936,8 +943,8 @@ def decode(self, encoding, errors='strict'): f = lambda x: decoder(x, errors)[0] return self._apply(f, dtype=np.str_) - def encode(self, encoding, errors='strict'): - ''' + def encode(self, encoding, errors="strict"): + """ Encode character string in the array using indicated encoding. 
Parameters @@ -948,7 +955,7 @@ def encode(self, encoding, errors='strict'): Returns ------- encoded : same type as values - ''' + """ if encoding in _cpython_optimized_encoders: f = lambda x: x.encode(encoding, errors) else: diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 1db9157850a..56f060fd713 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -3,16 +3,7 @@ import warnings from collections import OrderedDict, defaultdict from contextlib import suppress -from typing import ( - Any, - Dict, - Hashable, - Mapping, - Optional, - Tuple, - Union, - TYPE_CHECKING, -) +from typing import Any, Dict, Hashable, Mapping, Optional, Tuple, Union, TYPE_CHECKING import numpy as np import pandas as pd @@ -28,24 +19,30 @@ def _get_joiner(join): - if join == 'outer': + if join == "outer": return functools.partial(functools.reduce, operator.or_) - elif join == 'inner': + elif join == "inner": return functools.partial(functools.reduce, operator.and_) - elif join == 'left': + elif join == "left": return operator.itemgetter(0) - elif join == 'right': + elif join == "right": return operator.itemgetter(-1) - elif join == 'exact': + elif join == "exact": # We cannot return a function to "align" in this case, because it needs # access to the dimension name to give a good error message. return None else: - raise ValueError('invalid value for join: %s' % join) + raise ValueError("invalid value for join: %s" % join) -def align(*objects, join='inner', copy=True, indexes=None, exclude=frozenset(), - fill_value=dtypes.NA): +def align( + *objects, + join="inner", + copy=True, + indexes=None, + exclude=frozenset(), + fill_value=dtypes.NA +): """ Given any number of Dataset and/or DataArray objects, returns new objects with aligned indexes and dimension sizes. 
@@ -124,17 +121,23 @@ def align(*objects, join='inner', copy=True, indexes=None, exclude=frozenset(), for dim, matching_indexes in all_indexes.items(): if dim in indexes: index = utils.safe_cast_to_index(indexes[dim]) - if (any(not index.equals(other) for other in matching_indexes) or - dim in unlabeled_dim_sizes): + if ( + any(not index.equals(other) for other in matching_indexes) + or dim in unlabeled_dim_sizes + ): joined_indexes[dim] = index else: - if (any(not matching_indexes[0].equals(other) - for other in matching_indexes[1:]) or - dim in unlabeled_dim_sizes): - if join == 'exact': + if ( + any( + not matching_indexes[0].equals(other) + for other in matching_indexes[1:] + ) + or dim in unlabeled_dim_sizes + ): + if join == "exact": raise ValueError( - 'indexes along dimension {!r} are not equal' - .format(dim)) + "indexes along dimension {!r} are not equal".format(dim) + ) index = joiner(matching_indexes) joined_indexes[dim] = index else: @@ -145,39 +148,45 @@ def align(*objects, join='inner', copy=True, indexes=None, exclude=frozenset(), labeled_size = index.size if len(unlabeled_sizes | {labeled_size}) > 1: raise ValueError( - 'arguments without labels along dimension %r cannot be ' - 'aligned because they have different dimension size(s) %r ' - 'than the size of the aligned dimension labels: %r' - % (dim, unlabeled_sizes, labeled_size)) + "arguments without labels along dimension %r cannot be " + "aligned because they have different dimension size(s) %r " + "than the size of the aligned dimension labels: %r" + % (dim, unlabeled_sizes, labeled_size) + ) for dim in unlabeled_dim_sizes: if dim not in all_indexes: sizes = unlabeled_dim_sizes[dim] if len(sizes) > 1: raise ValueError( - 'arguments without labels along dimension %r cannot be ' - 'aligned because they have different dimension sizes: %r' - % (dim, sizes)) + "arguments without labels along dimension %r cannot be " + "aligned because they have different dimension sizes: %r" + % (dim, sizes) + ) result = [] for obj in objects: - valid_indexers = {k: v for k, v in joined_indexes.items() - if k in obj.dims} + valid_indexers = {k: v for k, v in joined_indexes.items() if k in obj.dims} if not valid_indexers: # fast path for no reindexing necessary new_obj = obj.copy(deep=copy) else: - new_obj = obj.reindex(copy=copy, fill_value=fill_value, - **valid_indexers) + new_obj = obj.reindex(copy=copy, fill_value=fill_value, **valid_indexers) new_obj.encoding = obj.encoding result.append(new_obj) return tuple(result) -def deep_align(objects, join='inner', copy=True, indexes=None, - exclude=frozenset(), raise_on_invalid=True, - fill_value=dtypes.NA): +def deep_align( + objects, + join="inner", + copy=True, + indexes=None, + exclude=frozenset(), + raise_on_invalid=True, + fill_value=dtypes.NA, +): """Align objects for merging, recursing into dictionary values. This function is not public API. 
@@ -214,14 +223,21 @@ def is_alignable(obj): targets.append(v) out.append(OrderedDict(variables)) elif raise_on_invalid: - raise ValueError('object to align is neither an xarray.Dataset, ' - 'an xarray.DataArray nor a dictionary: %r' - % variables) + raise ValueError( + "object to align is neither an xarray.Dataset, " + "an xarray.DataArray nor a dictionary: %r" % variables + ) else: out.append(variables) - aligned = align(*targets, join=join, copy=copy, indexes=indexes, - exclude=exclude, fill_value=fill_value) + aligned = align( + *targets, + join=join, + copy=copy, + indexes=indexes, + exclude=exclude, + fill_value=fill_value + ) for position, key, aligned_obj in zip(positions, keys, aligned): if key is no_key: @@ -236,8 +252,7 @@ def is_alignable(obj): def reindex_like_indexers( - target: Union['DataArray', 'Dataset'], - other: Union['DataArray', 'Dataset'], + target: Union["DataArray", "Dataset"], other: Union["DataArray", "Dataset"] ) -> Dict[Hashable, pd.Index]: """Extract indexers to align target with other. @@ -267,9 +282,11 @@ def reindex_like_indexers( other_size = other.sizes[dim] target_size = target.sizes[dim] if other_size != target_size: - raise ValueError('different size for unlabeled ' - 'dimension on argument %r: %r vs %r' - % (dim, other_size, target_size)) + raise ValueError( + "different size for unlabeled " + "dimension on argument %r: %r vs %r" + % (dim, other_size, target_size) + ) return indexers @@ -282,7 +299,7 @@ def reindex_variables( tolerance: Any = None, copy: bool = True, fill_value: Optional[Any] = dtypes.NA, -) -> 'Tuple[OrderedDict[Any, Variable], OrderedDict[Any, pd.Index]]': +) -> "Tuple[OrderedDict[Any, Variable], OrderedDict[Any, pd.Index]]": """Conform a dictionary of aligned variables onto a new set of variables, filling in missing values with NaN. @@ -343,9 +360,12 @@ def reindex_variables( warnings.warn( "Indexer has dimensions {:s} that are different " "from that to be indexed along {:s}. " - "This will behave differently in the future." 
- .format(str(indexer.dims), dim), - FutureWarning, stacklevel=3) + "This will behave differently in the future.".format( + str(indexer.dims), dim + ), + FutureWarning, + stacklevel=3, + ) target = new_indexes[dim] = utils.safe_cast_to_index(indexers[dim]) @@ -354,8 +374,9 @@ def reindex_variables( if not index.is_unique: raise ValueError( - 'cannot reindex or align along dimension %r because the ' - 'index has duplicate values' % dim) + "cannot reindex or align along dimension %r because the " + "index has duplicate values" % dim + ) int_indexer = get_indexer_nd(index, target, method, tolerance) @@ -381,16 +402,17 @@ def reindex_variables( new_size = indexers[dim].size if existing_size != new_size: raise ValueError( - 'cannot reindex or align along dimension %r without an ' - 'index because its size %r is different from the size of ' - 'the new index %r' % (dim, existing_size, new_size)) + "cannot reindex or align along dimension %r without an " + "index because its size %r is different from the size of " + "the new index %r" % (dim, existing_size, new_size) + ) for name, var in variables.items(): if name not in indexers: - key = tuple(slice(None) - if d in unchanged_dims - else int_indexers.get(d, slice(None)) - for d in var.dims) + key = tuple( + slice(None) if d in unchanged_dims else int_indexers.get(d, slice(None)) + for d in var.dims + ) needs_masking = any(d in masked_dims for d in var.dims) if needs_masking: @@ -441,13 +463,10 @@ def _broadcast_array(array): data = _set_dims(array.variable) coords = OrderedDict(array.coords) coords.update(common_coords) - return DataArray(data, coords, data.dims, name=array.name, - attrs=array.attrs) + return DataArray(data, coords, data.dims, name=array.name, attrs=array.attrs) def _broadcast_dataset(ds): - data_vars = OrderedDict( - (k, _set_dims(ds.variables[k])) - for k in ds.data_vars) + data_vars = OrderedDict((k, _set_dims(ds.variables[k])) for k in ds.data_vars) coords = OrderedDict(ds.coords) coords.update(common_coords) return Dataset(data_vars, coords, ds.attrs) @@ -457,7 +476,7 @@ def _broadcast_dataset(ds): elif isinstance(arg, Dataset): return _broadcast_dataset(arg) else: - raise ValueError('all input must be Dataset or DataArray objects') + raise ValueError("all input must be Dataset or DataArray objects") def broadcast(*args, exclude=None): @@ -535,10 +554,9 @@ def broadcast(*args, exclude=None): if exclude is None: exclude = set() - args = align(*args, join='outer', copy=False, exclude=exclude) + args = align(*args, join="outer", copy=False, exclude=exclude) - dims_map, common_coords = _get_broadcast_dims_map_common_coords( - args, exclude) + dims_map, common_coords = _get_broadcast_dims_map_common_coords(args, exclude) result = [] for arg in args: result.append(_broadcast_helper(arg, exclude, dims_map, common_coords)) @@ -548,6 +566,10 @@ def broadcast(*args, exclude=None): def broadcast_arrays(*args): import warnings - warnings.warn('xarray.broadcast_arrays is deprecated: use ' - 'xarray.broadcast instead', DeprecationWarning, stacklevel=2) + + warnings.warn( + "xarray.broadcast_arrays is deprecated: use " "xarray.broadcast instead", + DeprecationWarning, + stacklevel=2, + ) return broadcast(*args) diff --git a/xarray/core/arithmetic.py b/xarray/core/arithmetic.py index 9da4c84697e..5e8c8758ef5 100644 --- a/xarray/core/arithmetic.py +++ b/xarray/core/arithmetic.py @@ -19,57 +19,84 @@ class SupportsArithmetic: # numpy.lib.mixins.NDArrayOperatorsMixin. 
# TODO: allow extending this with some sort of registration system - _HANDLED_TYPES = (np.ndarray, np.generic, numbers.Number, bytes, - str) + dask_array_type + _HANDLED_TYPES = ( + np.ndarray, + np.generic, + numbers.Number, + bytes, + str, + ) + dask_array_type def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): from .computation import apply_ufunc # See the docstring example for numpy.lib.mixins.NDArrayOperatorsMixin. - out = kwargs.get('out', ()) + out = kwargs.get("out", ()) for x in inputs + out: if not isinstance(x, self._HANDLED_TYPES + (SupportsArithmetic,)): return NotImplemented if ufunc.signature is not None: raise NotImplementedError( - '{} not supported: xarray objects do not directly implement ' - 'generalized ufuncs. Instead, use xarray.apply_ufunc or ' - 'explicitly convert to xarray objects to NumPy arrays ' - '(e.g., with `.values`).' - .format(ufunc)) + "{} not supported: xarray objects do not directly implement " + "generalized ufuncs. Instead, use xarray.apply_ufunc or " + "explicitly convert to xarray objects to NumPy arrays " + "(e.g., with `.values`).".format(ufunc) + ) - if method != '__call__': + if method != "__call__": # TODO: support other methods, e.g., reduce and accumulate. raise NotImplementedError( - '{} method for ufunc {} is not implemented on xarray objects, ' - 'which currently only support the __call__ method. As an ' - 'alternative, consider explicitly converting xarray objects ' - 'to NumPy arrays (e.g., with `.values`).' - .format(method, ufunc)) + "{} method for ufunc {} is not implemented on xarray objects, " + "which currently only support the __call__ method. As an " + "alternative, consider explicitly converting xarray objects " + "to NumPy arrays (e.g., with `.values`).".format(method, ufunc) + ) if any(isinstance(o, SupportsArithmetic) for o in out): # TODO: implement this with logic like _inplace_binary_op. This # will be necessary to use NDArrayOperatorsMixin. raise NotImplementedError( - 'xarray objects are not yet supported in the `out` argument ' - 'for ufuncs. As an alternative, consider explicitly ' - 'converting xarray objects to NumPy arrays (e.g., with ' - '`.values`).') + "xarray objects are not yet supported in the `out` argument " + "for ufuncs. As an alternative, consider explicitly " + "converting xarray objects to NumPy arrays (e.g., with " + "`.values`)." 
+ ) - join = dataset_join = OPTIONS['arithmetic_join'] + join = dataset_join = OPTIONS["arithmetic_join"] - return apply_ufunc(ufunc, *inputs, - input_core_dims=((),) * ufunc.nin, - output_core_dims=((),) * ufunc.nout, - join=join, - dataset_join=dataset_join, - dataset_fill_value=np.nan, - kwargs=kwargs, - dask='allowed') + return apply_ufunc( + ufunc, + *inputs, + input_core_dims=((),) * ufunc.nin, + output_core_dims=((),) * ufunc.nout, + join=join, + dataset_join=dataset_join, + dataset_fill_value=np.nan, + kwargs=kwargs, + dask="allowed" + ) # this has no runtime function - these are listed so IDEs know these # methods are defined and don't warn on these operations - __lt__ = __le__ = __ge__ = __gt__ = __add__ = __sub__ = __mul__ = \ - __truediv__ = __floordiv__ = __mod__ = __pow__ = __and__ = __xor__ = \ - __or__ = __div__ = __eq__ = __ne__ = not_implemented + __lt__ = ( + __le__ + ) = ( + __ge__ + ) = ( + __gt__ + ) = ( + __add__ + ) = ( + __sub__ + ) = ( + __mul__ + ) = ( + __truediv__ + ) = ( + __floordiv__ + ) = ( + __mod__ + ) = ( + __pow__ + ) = __and__ = __xor__ = __or__ = __div__ = __eq__ = __ne__ = not_implemented diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 6a61cb2addc..740cb68c862 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -41,8 +41,7 @@ def _infer_tile_ids_from_nested_list(entry, current_pos): if isinstance(entry, list): for i, item in enumerate(entry): - yield from _infer_tile_ids_from_nested_list( - item, current_pos + (i,)) + yield from _infer_tile_ids_from_nested_list(item, current_pos + (i,)) else: yield current_pos, entry @@ -62,8 +61,10 @@ def _infer_concat_order_from_coords(datasets): # Need to read coordinate values to do ordering indexes = [ds.indexes.get(dim) for ds in datasets] if any(index is None for index in indexes): - raise ValueError("Every dimension needs a coordinate for " - "inferring concatenation order") + raise ValueError( + "Every dimension needs a coordinate for " + "inferring concatenation order" + ) # If dimension coordinate values are same on every dataset then # should be leaving this dimension alone (it's just a "bystander") @@ -77,34 +78,37 @@ def _infer_concat_order_from_coords(datasets): elif all(index.is_monotonic_decreasing for index in indexes): ascending = False else: - raise ValueError("Coordinate variable {} is neither " - "monotonically increasing nor " - "monotonically decreasing on all datasets" - .format(dim)) + raise ValueError( + "Coordinate variable {} is neither " + "monotonically increasing nor " + "monotonically decreasing on all datasets".format(dim) + ) # Assume that any two datasets whose coord along dim starts # with the same value have the same coord values throughout. 
if any(index.size == 0 for index in indexes): - raise ValueError('Cannot handle size zero dimensions') - first_items = pd.Index([index.take([0]) - for index in indexes]) + raise ValueError("Cannot handle size zero dimensions") + first_items = pd.Index([index.take([0]) for index in indexes]) # Sort datasets along dim # We want rank but with identical elements given identical # position indices - they should be concatenated along another # dimension, not along this one series = first_items.to_series() - rank = series.rank(method='dense', ascending=ascending) + rank = series.rank(method="dense", ascending=ascending) order = rank.astype(int).values - 1 # Append positions along extra dimension to structure which # encodes the multi-dimensional concatentation order - tile_ids = [tile_id + (position,) for tile_id, position - in zip(tile_ids, order)] + tile_ids = [ + tile_id + (position,) for tile_id, position in zip(tile_ids, order) + ] if len(datasets) > 1 and not concat_dims: - raise ValueError("Could not find any dimension coordinates to use to " - "order the datasets for concatenation") + raise ValueError( + "Could not find any dimension coordinates to use to " + "order the datasets for concatenation" + ) combined_ids = OrderedDict(zip(tile_ids, datasets)) @@ -120,22 +124,32 @@ def _check_shape_tile_ids(combined_tile_ids): if not nesting_depths: nesting_depths = [0] if not set(nesting_depths) == {nesting_depths[0]}: - raise ValueError("The supplied objects do not form a hypercube because" - " sub-lists do not have consistent depths") + raise ValueError( + "The supplied objects do not form a hypercube because" + " sub-lists do not have consistent depths" + ) # Check all lists along one dimension are same length for dim in range(nesting_depths[0]): indices_along_dim = [tile_id[dim] for tile_id in tile_ids] occurrences = Counter(indices_along_dim) if len(set(occurrences.values())) != 1: - raise ValueError("The supplied objects do not form a hypercube " - "because sub-lists do not have consistent " - "lengths along dimension" + str(dim)) - - -def _combine_nd(combined_ids, concat_dims, data_vars='all', - coords='different', compat='no_conflicts', - fill_value=dtypes.NA, join='outer'): + raise ValueError( + "The supplied objects do not form a hypercube " + "because sub-lists do not have consistent " + "lengths along dimension" + str(dim) + ) + + +def _combine_nd( + combined_ids, + concat_dims, + data_vars="all", + coords="different", + compat="no_conflicts", + fill_value=dtypes.NA, + join="outer", +): """ Combines an N-dimensional structure of datasets into one by applying a series of either concat and merge operations along each dimension. @@ -163,27 +177,33 @@ def _combine_nd(combined_ids, concat_dims, data_vars='all', n_dims = len(example_tile_id) if len(concat_dims) != n_dims: - raise ValueError("concat_dims has length {} but the datasets " - "passed are nested in a {}-dimensional structure" - .format(len(concat_dims), n_dims)) + raise ValueError( + "concat_dims has length {} but the datasets " + "passed are nested in a {}-dimensional structure".format( + len(concat_dims), n_dims + ) + ) # Each iteration of this loop reduces the length of the tile_ids tuples # by one. 
It always combines along the first dimension, removing the first # element of the tuple for concat_dim in concat_dims: - combined_ids = _combine_all_along_first_dim(combined_ids, - dim=concat_dim, - data_vars=data_vars, - coords=coords, - compat=compat, - fill_value=fill_value, - join=join) + combined_ids = _combine_all_along_first_dim( + combined_ids, + dim=concat_dim, + data_vars=data_vars, + coords=coords, + compat=compat, + fill_value=fill_value, + join=join, + ) (combined_ds,) = combined_ids.values() return combined_ds -def _combine_all_along_first_dim(combined_ids, dim, data_vars, coords, compat, - fill_value=dtypes.NA, join='outer'): +def _combine_all_along_first_dim( + combined_ids, dim, data_vars, coords, compat, fill_value=dtypes.NA, join="outer" +): # Group into lines of datasets which must be combined along dim # need to sort by _new_tile_id first for groupby to work @@ -196,14 +216,21 @@ def _combine_all_along_first_dim(combined_ids, dim, data_vars, coords, compat, for new_id, group in grouped: combined_ids = OrderedDict(sorted(group)) datasets = combined_ids.values() - new_combined_ids[new_id] = _combine_1d(datasets, dim, compat, - data_vars, coords, fill_value, - join) + new_combined_ids[new_id] = _combine_1d( + datasets, dim, compat, data_vars, coords, fill_value, join + ) return new_combined_ids -def _combine_1d(datasets, concat_dim, compat='no_conflicts', data_vars='all', - coords='different', fill_value=dtypes.NA, join='outer'): +def _combine_1d( + datasets, + concat_dim, + compat="no_conflicts", + data_vars="all", + coords="different", + fill_value=dtypes.NA, + join="outer", +): """ Applies either concat or merge to 1D list of datasets depending on value of concat_dim @@ -211,20 +238,27 @@ def _combine_1d(datasets, concat_dim, compat='no_conflicts', data_vars='all', if concat_dim is not None: try: - combined = concat(datasets, dim=concat_dim, data_vars=data_vars, - coords=coords, fill_value=fill_value, join=join) + combined = concat( + datasets, + dim=concat_dim, + data_vars=data_vars, + coords=coords, + fill_value=fill_value, + join=join, + ) except ValueError as err: if "encountered unexpected variable" in str(err): - raise ValueError("These objects cannot be combined using only " - "xarray.combine_nested, instead either use " - "xarray.combine_by_coords, or do it manually " - "with xarray.concat, xarray.merge and " - "xarray.align") + raise ValueError( + "These objects cannot be combined using only " + "xarray.combine_nested, instead either use " + "xarray.combine_by_coords, or do it manually " + "with xarray.concat, xarray.merge and " + "xarray.align" + ) else: raise else: - combined = merge(datasets, compat=compat, fill_value=fill_value, - join=join) + combined = merge(datasets, compat=compat, fill_value=fill_value, join=join) return combined @@ -234,8 +268,16 @@ def _new_tile_id(single_id_ds_pair): return tile_id[1:] -def _nested_combine(datasets, concat_dims, compat, data_vars, coords, ids, - fill_value=dtypes.NA, join='outer'): +def _nested_combine( + datasets, + concat_dims, + compat, + data_vars, + coords, + ids, + fill_value=dtypes.NA, + join="outer", +): if len(datasets) == 0: return Dataset() @@ -254,15 +296,27 @@ def _nested_combine(datasets, concat_dims, compat, data_vars, coords, ids, _check_shape_tile_ids(combined_ids) # Apply series of concatenate or merge operations along each dimension - combined = _combine_nd(combined_ids, concat_dims, compat=compat, - data_vars=data_vars, coords=coords, - fill_value=fill_value, join=join) + combined = _combine_nd( 
+ combined_ids, + concat_dims, + compat=compat, + data_vars=data_vars, + coords=coords, + fill_value=fill_value, + join=join, + ) return combined -def combine_nested(datasets, concat_dim, compat='no_conflicts', - data_vars='all', coords='different', fill_value=dtypes.NA, - join='outer'): +def combine_nested( + datasets, + concat_dim, + compat="no_conflicts", + data_vars="all", + coords="different", + fill_value=dtypes.NA, + join="outer", +): """ Explicitly combine an N-dimensional grid of datasets into one by using a succession of concat and merge operations along each dimension of the grid. @@ -394,17 +448,30 @@ def combine_nested(datasets, concat_dim, compat='no_conflicts', concat_dim = [concat_dim] # The IDs argument tells _manual_combine that datasets aren't yet sorted - return _nested_combine(datasets, concat_dims=concat_dim, compat=compat, - data_vars=data_vars, coords=coords, ids=False, - fill_value=fill_value, join=join) + return _nested_combine( + datasets, + concat_dims=concat_dim, + compat=compat, + data_vars=data_vars, + coords=coords, + ids=False, + fill_value=fill_value, + join=join, + ) def vars_as_keys(ds): return tuple(sorted(ds)) -def combine_by_coords(datasets, compat='no_conflicts', data_vars='all', - coords='different', fill_value=dtypes.NA, join='outer'): +def combine_by_coords( + datasets, + compat="no_conflicts", + data_vars="all", + coords="different", + fill_value=dtypes.NA, + join="outer", +): """ Attempt to auto-magically combine the given datasets into one by using dimension coordinates. @@ -514,39 +581,56 @@ def combine_by_coords(datasets, compat='no_conflicts', data_vars='all', concatenated_grouped_by_data_vars = [] for vars, datasets_with_same_vars in grouped_by_vars: combined_ids, concat_dims = _infer_concat_order_from_coords( - list(datasets_with_same_vars)) + list(datasets_with_same_vars) + ) _check_shape_tile_ids(combined_ids) # Concatenate along all of concat_dims one by one to create single ds - concatenated = _combine_nd(combined_ids, concat_dims=concat_dims, - data_vars=data_vars, coords=coords, - fill_value=fill_value, join=join) + concatenated = _combine_nd( + combined_ids, + concat_dims=concat_dims, + data_vars=data_vars, + coords=coords, + fill_value=fill_value, + join=join, + ) # Check the overall coordinates are monotonically increasing for dim in concat_dims: indexes = concatenated.indexes.get(dim) - if not (indexes.is_monotonic_increasing - or indexes.is_monotonic_decreasing): - raise ValueError("Resulting object does not have monotonic" - " global indexes along dimension {}" - .format(dim)) + if not (indexes.is_monotonic_increasing or indexes.is_monotonic_decreasing): + raise ValueError( + "Resulting object does not have monotonic" + " global indexes along dimension {}".format(dim) + ) concatenated_grouped_by_data_vars.append(concatenated) - return merge(concatenated_grouped_by_data_vars, compat=compat, - fill_value=fill_value, join=join) + return merge( + concatenated_grouped_by_data_vars, + compat=compat, + fill_value=fill_value, + join=join, + ) # Everything beyond here is only needed until the deprecation cycle in #2616 # is completed -_CONCAT_DIM_DEFAULT = '__infer_concat_dim__' +_CONCAT_DIM_DEFAULT = "__infer_concat_dim__" -def auto_combine(datasets, concat_dim='_not_supplied', compat='no_conflicts', - data_vars='all', coords='different', fill_value=dtypes.NA, - join='outer', from_openmfds=False): +def auto_combine( + datasets, + concat_dim="_not_supplied", + compat="no_conflicts", + data_vars="all", + coords="different", + 
fill_value=dtypes.NA, + join="outer", + from_openmfds=False, +): """ Attempt to auto-magically combine the given datasets into one. @@ -616,54 +700,71 @@ def auto_combine(datasets, concat_dim='_not_supplied', compat='no_conflicts', """ if not from_openmfds: - basic_msg = dedent("""\ + basic_msg = dedent( + """\ In xarray version 0.13 `auto_combine` will be deprecated. See - http://xarray.pydata.org/en/stable/combining.html#combining-multi""") + http://xarray.pydata.org/en/stable/combining.html#combining-multi""" + ) warnings.warn(basic_msg, FutureWarning, stacklevel=2) - if concat_dim == '_not_supplied': + if concat_dim == "_not_supplied": concat_dim = _CONCAT_DIM_DEFAULT - message = '' + message = "" else: - message = dedent("""\ + message = dedent( + """\ Also `open_mfdataset` will no longer accept a `concat_dim` argument. To get equivalent behaviour from now on please use the new `combine_nested` function instead (or the `combine='nested'` option to - `open_mfdataset`).""") + `open_mfdataset`).""" + ) if _dimension_coords_exist(datasets): - message += dedent("""\ + message += dedent( + """\ The datasets supplied have global dimension coordinates. You may want to use the new `combine_by_coords` function (or the `combine='by_coords'` option to `open_mfdataset`) to order the datasets before concatenation. Alternatively, to continue concatenating based on the order the datasets are supplied in future, please use the new `combine_nested` function (or the `combine='nested'` option to - open_mfdataset).""") + open_mfdataset).""" + ) else: - message += dedent("""\ + message += dedent( + """\ The datasets supplied do not have global dimension coordinates. In future, to continue concatenating without supplying dimension coordinates, please use the new `combine_nested` function (or the - `combine='nested'` option to open_mfdataset.""") + `combine='nested'` option to open_mfdataset.""" + ) if _requires_concat_and_merge(datasets): manual_dims = [concat_dim].append(None) - message += dedent("""\ + message += dedent( + """\ The datasets supplied require both concatenation and merging. From xarray version 0.13 this will operation will require either using the new `combine_nested` function (or the `combine='nested'` option to open_mfdataset), with a nested list structure such that you can combine along the dimensions {}. 
Alternatively if your datasets have global dimension coordinates then you can use the new `combine_by_coords` - function.""".format(manual_dims)) + function.""".format( + manual_dims + ) + ) warnings.warn(message, FutureWarning, stacklevel=2) - return _old_auto_combine(datasets, concat_dim=concat_dim, - compat=compat, data_vars=data_vars, - coords=coords, fill_value=fill_value, - join=join) + return _old_auto_combine( + datasets, + concat_dim=concat_dim, + compat=compat, + data_vars=data_vars, + coords=coords, + fill_value=fill_value, + join=join, + ) def _dimension_coords_exist(datasets): @@ -701,29 +802,46 @@ def _requires_concat_and_merge(datasets): return len(list(grouped_by_vars)) > 1 -def _old_auto_combine(datasets, concat_dim=_CONCAT_DIM_DEFAULT, - compat='no_conflicts', - data_vars='all', coords='different', - fill_value=dtypes.NA, join='outer'): +def _old_auto_combine( + datasets, + concat_dim=_CONCAT_DIM_DEFAULT, + compat="no_conflicts", + data_vars="all", + coords="different", + fill_value=dtypes.NA, + join="outer", +): if concat_dim is not None: dim = None if concat_dim is _CONCAT_DIM_DEFAULT else concat_dim sorted_datasets = sorted(datasets, key=vars_as_keys) grouped = itertools.groupby(sorted_datasets, key=vars_as_keys) - concatenated = [_auto_concat(list(datasets), dim=dim, - data_vars=data_vars, coords=coords, - fill_value=fill_value, join=join) - for vars, datasets in grouped] + concatenated = [ + _auto_concat( + list(datasets), + dim=dim, + data_vars=data_vars, + coords=coords, + fill_value=fill_value, + join=join, + ) + for vars, datasets in grouped + ] else: concatenated = datasets - merged = merge(concatenated, compat=compat, fill_value=fill_value, - join=join) + merged = merge(concatenated, compat=compat, fill_value=fill_value, join=join) return merged -def _auto_concat(datasets, dim=None, data_vars='all', coords='different', - fill_value=dtypes.NA, join='outer'): +def _auto_concat( + datasets, + dim=None, + data_vars="all", + coords="different", + fill_value=dtypes.NA, + join="outer", +): if len(datasets) == 1 and dim is None: # There is nothing more to combine, so kick out early. 
return datasets[0] @@ -736,17 +854,18 @@ def _auto_concat(datasets, dim=None, data_vars='all', coords='different', dim_tuples = set(ds0.dims.items()) - set(ds1.dims.items()) concat_dims = {i for i, _ in dim_tuples} if len(concat_dims) > 1: - concat_dims = { - d for d in concat_dims - if not ds0[d].equals(ds1[d]) - } + concat_dims = {d for d in concat_dims if not ds0[d].equals(ds1[d])} if len(concat_dims) > 1: - raise ValueError('too many different dimensions to ' - 'concatenate: %s' % concat_dims) + raise ValueError( + "too many different dimensions to " "concatenate: %s" % concat_dims + ) elif len(concat_dims) == 0: - raise ValueError('cannot infer dimension to concatenate: ' - 'supply the ``concat_dim`` argument ' - 'explicitly') + raise ValueError( + "cannot infer dimension to concatenate: " + "supply the ``concat_dim`` argument " + "explicitly" + ) dim, = concat_dims - return concat(datasets, dim=dim, data_vars=data_vars, coords=coords, - fill_value=fill_value) + return concat( + datasets, dim=dim, data_vars=data_vars, coords=coords, fill_value=fill_value + ) diff --git a/xarray/core/common.py b/xarray/core/common.py index 93a5bb71b07..15ce8ca9f04 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -2,8 +2,18 @@ from contextlib import suppress from textwrap import dedent from typing import ( - Any, Callable, Hashable, Iterable, Iterator, List, Mapping, MutableMapping, - Tuple, TypeVar, Union) + Any, + Callable, + Hashable, + Iterable, + Iterator, + List, + Mapping, + MutableMapping, + Tuple, + TypeVar, + Union, +) import numpy as np import pandas as pd @@ -17,69 +27,84 @@ from .utils import Frozen, ReprObject, SortedKeysDict, either_dict_or_kwargs # Used as a sentinel value to indicate a all dimensions -ALL_DIMS = ReprObject('') +ALL_DIMS = ReprObject("") -C = TypeVar('C') -T = TypeVar('T') +C = TypeVar("C") +T = TypeVar("T") class ImplementsArrayReduce: @classmethod - def _reduce_method(cls, func: Callable, include_skipna: bool, - numeric_only: bool): + def _reduce_method(cls, func: Callable, include_skipna: bool, numeric_only: bool): if include_skipna: - def wrapped_func(self, dim=None, axis=None, skipna=None, - **kwargs): - return self.reduce(func, dim, axis, - skipna=skipna, allow_lazy=True, **kwargs) + + def wrapped_func(self, dim=None, axis=None, skipna=None, **kwargs): + return self.reduce( + func, dim, axis, skipna=skipna, allow_lazy=True, **kwargs + ) + else: - def wrapped_func(self, dim=None, axis=None, # type: ignore - **kwargs): - return self.reduce(func, dim, axis, - allow_lazy=True, **kwargs) + + def wrapped_func( + self, + dim=None, + axis=None, # type: ignore + **kwargs + ): + return self.reduce(func, dim, axis, allow_lazy=True, **kwargs) + return wrapped_func - _reduce_extra_args_docstring = dedent("""\ + _reduce_extra_args_docstring = dedent( + """\ dim : str or sequence of str, optional Dimension(s) over which to apply `{name}`. axis : int or sequence of int, optional Axis(es) over which to apply `{name}`. Only one of the 'dim' and 'axis' arguments can be supplied. If neither are supplied, then - `{name}` is calculated over axes.""") + `{name}` is calculated over axes.""" + ) - _cum_extra_args_docstring = dedent("""\ + _cum_extra_args_docstring = dedent( + """\ dim : str or sequence of str, optional Dimension over which to apply `{name}`. axis : int or sequence of int, optional Axis over which to apply `{name}`. 
Only one of the 'dim' - and 'axis' arguments can be supplied.""") + and 'axis' arguments can be supplied.""" + ) class ImplementsDatasetReduce: @classmethod - def _reduce_method(cls, func: Callable, include_skipna: bool, - numeric_only: bool): + def _reduce_method(cls, func: Callable, include_skipna: bool, numeric_only: bool): if include_skipna: - def wrapped_func(self, dim=None, skipna=None, - **kwargs): - return self.reduce(func, dim, skipna=skipna, - numeric_only=numeric_only, allow_lazy=True, - **kwargs) + + def wrapped_func(self, dim=None, skipna=None, **kwargs): + return self.reduce( + func, + dim, + skipna=skipna, + numeric_only=numeric_only, + allow_lazy=True, + **kwargs + ) + else: + def wrapped_func(self, dim=None, **kwargs): # type: ignore - return self.reduce(func, dim, - numeric_only=numeric_only, allow_lazy=True, - **kwargs) + return self.reduce( + func, dim, numeric_only=numeric_only, allow_lazy=True, **kwargs + ) + return wrapped_func - _reduce_extra_args_docstring = \ - """dim : str or sequence of str, optional + _reduce_extra_args_docstring = """dim : str or sequence of str, optional Dimension(s) over which to apply `{name}`. By default `{name}` is applied over all dimensions.""" - _cum_extra_args_docstring = \ - """dim : str or sequence of str, optional + _cum_extra_args_docstring = """dim : str or sequence of str, optional Dimension over which to apply `{name}`. axis : int or sequence of int, optional Axis over which to apply `{name}`. Only one of the 'dim' @@ -114,11 +139,12 @@ def _iter(self: Any) -> Iterator[Any]: def __iter__(self: Any) -> Iterator[Any]: if self.ndim == 0: - raise TypeError('iteration over a 0-d array') + raise TypeError("iteration over a 0-d array") return self._iter() - def get_axis_num(self, dim: Union[Hashable, Iterable[Hashable]] - ) -> Union[int, Tuple[int, ...]]: + def get_axis_num( + self, dim: Union[Hashable, Iterable[Hashable]] + ) -> Union[int, Tuple[int, ...]]: """Return axis number(s) corresponding to dimension(s) in this array. Parameters @@ -140,8 +166,7 @@ def _get_axis_num(self: Any, dim: Hashable) -> int: try: return self.dims.index(dim) except ValueError: - raise ValueError("%r not found in array dimensions %r" % - (dim, self.dims)) + raise ValueError("%r not found in array dimensions %r" % (dim, self.dims)) @property def sizes(self: Any) -> Mapping[Hashable, int]: @@ -159,6 +184,7 @@ def sizes(self: Any) -> Mapping[Hashable, int]: class AttrAccessMixin: """Mixin class that allows getting keys with attribute access """ + _initialized = False @property @@ -174,14 +200,15 @@ def _item_sources(self) -> List[Mapping[Hashable, Any]]: return [] def __getattr__(self, name: str) -> Any: - if name != '__setstate__': + if name != "__setstate__": # this avoids an infinite loop when pickle looks for the # __setstate__ attribute before the xarray object is initialized for source in self._attr_sources: with suppress(KeyError): return source[name] - raise AttributeError("%r object has no attribute %r" % - (type(self).__name__, name)) + raise AttributeError( + "%r object has no attribute %r" % (type(self).__name__, name) + ) def __setattr__(self, name: str, value: Any) -> None: if self._initialized: @@ -194,17 +221,20 @@ def __setattr__(self, name: str, value: Any) -> None: raise AttributeError( "cannot set attribute %r on a %r object. Use __setitem__ " "style assignment (e.g., `ds['name'] = ...`) instead to " - "assign variables." % (name, type(self).__name__)) + "assign variables." 
% (name, type(self).__name__) + ) object.__setattr__(self, name, value) def __dir__(self) -> List[str]: """Provide method name lookup and completion. Only provide 'public' methods. """ - extra_attrs = [item - for sublist in self._attr_sources - for item in sublist - if isinstance(item, str)] + extra_attrs = [ + item + for sublist in self._attr_sources + for item in sublist + if isinstance(item, str) + ] return sorted(set(dir(type(self)) + extra_attrs)) def _ipython_key_completions_(self) -> List[str]: @@ -212,21 +242,24 @@ def _ipython_key_completions_(self) -> List[str]: See http://ipython.readthedocs.io/en/stable/config/integrating.html#tab-completion For the details. """ # noqa - item_lists = [item - for sublist in self._item_sources - for item in sublist - if isinstance(item, str)] + item_lists = [ + item + for sublist in self._item_sources + for item in sublist + if isinstance(item, str) + ] return list(set(item_lists)) -def get_squeeze_dims(xarray_obj, - dim: Union[Hashable, Iterable[Hashable], None] = None, - axis: Union[int, Iterable[int], None] = None - ) -> List[Hashable]: +def get_squeeze_dims( + xarray_obj, + dim: Union[Hashable, Iterable[Hashable], None] = None, + axis: Union[int, Iterable[int], None] = None, +) -> List[Hashable]: """Get a list of dimensions to squeeze out. """ if dim is not None and axis is not None: - raise ValueError('cannot use both parameters `axis` and `dim`') + raise ValueError("cannot use both parameters `axis` and `dim`") if dim is None and axis is None: return [d for d, s in xarray_obj.sizes.items() if s == 1] @@ -240,14 +273,15 @@ def get_squeeze_dims(xarray_obj, axis = [axis] axis = list(axis) if any(not isinstance(a, int) for a in axis): - raise TypeError( - 'parameter `axis` must be int or iterable of int.') + raise TypeError("parameter `axis` must be int or iterable of int.") alldims = list(xarray_obj.sizes.keys()) dim = [alldims[a] for a in axis] if any(xarray_obj.sizes[k] > 1 for k in dim): - raise ValueError('cannot select a dimension to squeeze out ' - 'which has length greater than one') + raise ValueError( + "cannot select a dimension to squeeze out " + "which has length greater than one" + ) return dim @@ -256,9 +290,12 @@ class DataWithCoords(SupportsArithmetic, AttrAccessMixin): _rolling_exp_cls = RollingExp - def squeeze(self, dim: Union[Hashable, Iterable[Hashable], None] = None, - drop: bool = False, - axis: Union[int, Iterable[int], None] = None): + def squeeze( + self, + dim: Union[Hashable, Iterable[Hashable], None] = None, + drop: bool = False, + axis: Union[int, Iterable[int], None] = None, + ): """Return a new object with squeezed data. 
Parameters @@ -299,8 +336,7 @@ def get_index(self, key: Hashable) -> pd.Index: return pd.Index(range(self.sizes[key]), name=key, dtype=np.int64) def _calc_assign_results( - self: C, - kwargs: Mapping[Hashable, Union[T, Callable[[C], T]]] + self: C, kwargs: Mapping[Hashable, Union[T, Callable[[C], T]]] ) -> MutableMapping[Hashable, T]: results = SortedKeysDict() # type: SortedKeysDict[Hashable, T] for k, v in kwargs.items(): @@ -390,8 +426,12 @@ def assign_attrs(self, *args, **kwargs): out.attrs.update(*args, **kwargs) return out - def pipe(self, func: Union[Callable[..., T], Tuple[Callable[..., T], str]], - *args, **kwargs) -> T: + def pipe( + self, + func: Union[Callable[..., T], Tuple[Callable[..., T], str]], + *args, + **kwargs + ) -> T: """ Apply func(self, *args, **kwargs) @@ -443,15 +483,15 @@ def pipe(self, func: Union[Callable[..., T], Tuple[Callable[..., T], str]], if isinstance(func, tuple): func, target = func if target in kwargs: - raise ValueError('%s is both the pipe target and a keyword ' - 'argument' % target) + raise ValueError( + "%s is both the pipe target and a keyword " "argument" % target + ) kwargs[target] = self return func(*args, **kwargs) else: return func(self, *args, **kwargs) - def groupby(self, group, squeeze: bool = True, - restore_coord_dims: bool = None): + def groupby(self, group, squeeze: bool = True, restore_coord_dims: bool = None): """Returns a GroupBy object for performing grouped operations. Parameters @@ -498,13 +538,21 @@ def groupby(self, group, squeeze: bool = True, core.groupby.DataArrayGroupBy core.groupby.DatasetGroupBy """ # noqa - return self._groupby_cls(self, group, squeeze=squeeze, - restore_coord_dims=restore_coord_dims) + return self._groupby_cls( + self, group, squeeze=squeeze, restore_coord_dims=restore_coord_dims + ) - def groupby_bins(self, group, bins, right: bool = True, labels=None, - precision: int = 3, include_lowest: bool = False, - squeeze: bool = True, - restore_coord_dims: bool = None): + def groupby_bins( + self, + group, + bins, + right: bool = True, + labels=None, + precision: int = 3, + include_lowest: bool = False, + squeeze: bool = True, + restore_coord_dims: bool = None, + ): """Returns a GroupBy object for performing grouped operations. Rather than using all unique values of `group`, the values are discretized @@ -553,16 +601,27 @@ def groupby_bins(self, group, bins, right: bool = True, labels=None, ---------- .. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html """ # noqa - return self._groupby_cls(self, group, squeeze=squeeze, bins=bins, - restore_coord_dims=restore_coord_dims, - cut_kwargs={'right': right, 'labels': labels, - 'precision': precision, - 'include_lowest': - include_lowest}) - - def rolling(self, dim: Mapping[Hashable, int] = None, - min_periods: int = None, center: bool = False, - **window_kwargs: int): + return self._groupby_cls( + self, + group, + squeeze=squeeze, + bins=bins, + restore_coord_dims=restore_coord_dims, + cut_kwargs={ + "right": right, + "labels": labels, + "precision": precision, + "include_lowest": include_lowest, + }, + ) + + def rolling( + self, + dim: Mapping[Hashable, int] = None, + min_periods: int = None, + center: bool = False, + **window_kwargs: int + ): """ Rolling window object. 
@@ -618,14 +677,13 @@ def rolling(self, dim: Mapping[Hashable, int] = None, core.rolling.DataArrayRolling core.rolling.DatasetRolling """ # noqa - dim = either_dict_or_kwargs(dim, window_kwargs, 'rolling') - return self._rolling_cls(self, dim, min_periods=min_periods, - center=center) + dim = either_dict_or_kwargs(dim, window_kwargs, "rolling") + return self._rolling_cls(self, dim, min_periods=min_periods, center=center) def rolling_exp( self, window: Mapping[Hashable, int] = None, - window_type: str = 'span', + window_type: str = "span", **window_kwargs ): """ @@ -657,15 +715,18 @@ def rolling_exp( -------- core.rolling_exp.RollingExp """ - window = either_dict_or_kwargs(window, window_kwargs, 'rolling_exp') + window = either_dict_or_kwargs(window, window_kwargs, "rolling_exp") return self._rolling_exp_cls(self, window, window_type) - def coarsen(self, dim: Mapping[Hashable, int] = None, - boundary: str = 'exact', - side: Union[str, Mapping[Hashable, str]] = 'left', - coord_func: str = 'mean', - **window_kwargs: int): + def coarsen( + self, + dim: Mapping[Hashable, int] = None, + boundary: str = "exact", + side: Union[str, Mapping[Hashable, str]] = "left", + coord_func: str = "mean", + **window_kwargs: int + ): """ Coarsen object. @@ -719,17 +780,23 @@ def coarsen(self, dim: Mapping[Hashable, int] = None, core.rolling.DataArrayCoarsen core.rolling.DatasetCoarsen """ - dim = either_dict_or_kwargs(dim, window_kwargs, 'coarsen') + dim = either_dict_or_kwargs(dim, window_kwargs, "coarsen") return self._coarsen_cls( - self, dim, boundary=boundary, side=side, - coord_func=coord_func) - - def resample(self, indexer: Mapping[Hashable, str] = None, - skipna=None, closed: str = None, - label: str = None, - base: int = 0, keep_attrs: bool = None, - loffset=None, restore_coord_dims: bool = None, - **indexer_kwargs: str): + self, dim, boundary=boundary, side=side, coord_func=coord_func + ) + + def resample( + self, + indexer: Mapping[Hashable, str] = None, + skipna=None, + closed: str = None, + label: str = None, + base: int = 0, + keep_attrs: bool = None, + loffset=None, + restore_coord_dims: bool = None, + **indexer_kwargs: str + ): """Returns a Resample object for performing resampling operations. Handles both downsampling and upsampling. If any intervals contain no @@ -817,19 +884,20 @@ def resample(self, indexer: Mapping[Hashable, str] = None, keep_attrs = _get_keep_attrs(default=False) # note: the second argument (now 'skipna') use to be 'dim' - if ((skipna is not None and not isinstance(skipna, bool)) - or ('how' in indexer_kwargs and 'how' not in self.dims) - or ('dim' in indexer_kwargs and 'dim' not in self.dims)): + if ( + (skipna is not None and not isinstance(skipna, bool)) + or ("how" in indexer_kwargs and "how" not in self.dims) + or ("dim" in indexer_kwargs and "dim" not in self.dims) + ): raise TypeError( - 'resample() no longer supports the `how` or ' - '`dim` arguments. Instead call methods on resample ' - "objects, e.g., data.resample(time='1D').mean()") + "resample() no longer supports the `how` or " + "`dim` arguments. Instead call methods on resample " + "objects, e.g., data.resample(time='1D').mean()" + ) - indexer = either_dict_or_kwargs(indexer, indexer_kwargs, 'resample') + indexer = either_dict_or_kwargs(indexer, indexer_kwargs, "resample") if len(indexer) != 1: - raise ValueError( - "Resampling only supported along single dimensions." 
- ) + raise ValueError("Resampling only supported along single dimensions.") dim, freq = next(iter(indexer.items())) dim_name = dim @@ -837,19 +905,28 @@ def resample(self, indexer: Mapping[Hashable, str] = None, if isinstance(self.indexes[dim_name], CFTimeIndex): from .resample_cftime import CFTimeGrouper + grouper = CFTimeGrouper(freq, closed, label, base, loffset) else: # TODO: to_offset() call required for pandas==0.19.2 - grouper = pd.Grouper(freq=freq, closed=closed, label=label, - base=base, - loffset=pd.tseries.frequencies.to_offset( - loffset)) - group = DataArray(dim_coord, coords=dim_coord.coords, - dims=dim_coord.dims, name=RESAMPLE_DIM) - resampler = self._resample_cls(self, group=group, dim=dim_name, - grouper=grouper, - resample_dim=RESAMPLE_DIM, - restore_coord_dims=restore_coord_dims) + grouper = pd.Grouper( + freq=freq, + closed=closed, + label=label, + base=base, + loffset=pd.tseries.frequencies.to_offset(loffset), + ) + group = DataArray( + dim_coord, coords=dim_coord.coords, dims=dim_coord.dims, name=RESAMPLE_DIM + ) + resampler = self._resample_cls( + self, + group=group, + dim=dim_name, + grouper=grouper, + resample_dim=RESAMPLE_DIM, + restore_coord_dims=restore_coord_dims, + ) return resampler @@ -915,18 +992,20 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): if drop: if other is not dtypes.NA: - raise ValueError('cannot set `other` if drop=True') + raise ValueError("cannot set `other` if drop=True") if not isinstance(cond, (Dataset, DataArray)): - raise TypeError("cond argument is %r but must be a %r or %r" % - (cond, Dataset, DataArray)) + raise TypeError( + "cond argument is %r but must be a %r or %r" + % (cond, Dataset, DataArray) + ) # align so we can use integer indexing self, cond = align(self, cond) # get cond with the minimal size needed for the Dataset if isinstance(cond, Dataset): - clipcond = cond.to_array().any('variable') + clipcond = cond.to_array().any("variable") else: clipcond = cond @@ -981,8 +1060,10 @@ def isin(self, test_elements): if isinstance(test_elements, Dataset): raise TypeError( - 'isin() argument must be convertible to an array: {}' - .format(test_elements)) + "isin() argument must be convertible to an array: {}".format( + test_elements + ) + ) elif isinstance(test_elements, (Variable, DataArray)): # need to explicitly pull out data to support dask arrays as the # second argument @@ -992,7 +1073,7 @@ def isin(self, test_elements): duck_array_ops.isin, self, kwargs=dict(test_elements=test_elements), - dask='allowed', + dask="allowed", ) def __enter__(self: T) -> T: @@ -1033,31 +1114,36 @@ def full_like(other, fill_value, dtype: DTypeLike = None): if isinstance(other, Dataset): data_vars = OrderedDict( (k, _full_like_variable(v, fill_value, dtype)) - for k, v in other.data_vars.items()) + for k, v in other.data_vars.items() + ) return Dataset(data_vars, coords=other.coords, attrs=other.attrs) elif isinstance(other, DataArray): return DataArray( _full_like_variable(other.variable, fill_value, dtype), - dims=other.dims, coords=other.coords, attrs=other.attrs, - name=other.name) + dims=other.dims, + coords=other.coords, + attrs=other.attrs, + name=other.name, + ) elif isinstance(other, Variable): return _full_like_variable(other, fill_value, dtype) else: raise TypeError("Expected DataArray, Dataset, or Variable") -def _full_like_variable(other, fill_value, - dtype: DTypeLike = None): +def _full_like_variable(other, fill_value, dtype: DTypeLike = None): """Inner function of full_like, where other must be a variable """ from 
.variable import Variable if isinstance(other.data, dask_array_type): import dask.array + if dtype is None: dtype = other.dtype - data = dask.array.full(other.shape, fill_value, dtype=dtype, - chunks=other.data.chunks) + data = dask.array.full( + other.shape, fill_value, dtype=dtype, chunks=other.data.chunks + ) else: data = np.full_like(other, fill_value, dtype=dtype) @@ -1079,8 +1165,7 @@ def ones_like(other, dtype: DTypeLike = None): def is_np_datetime_like(dtype: DTypeLike) -> bool: """Check if a dtype is a subclass of the numpy datetime types """ - return (np.issubdtype(dtype, np.datetime64) or - np.issubdtype(dtype, np.timedelta64)) + return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) def _contains_cftime_datetimes(array) -> bool: @@ -1091,7 +1176,7 @@ def _contains_cftime_datetimes(array) -> bool: except ImportError: return False else: - if array.dtype == np.dtype('O') and array.size > 0: + if array.dtype == np.dtype("O") and array.size > 0: sample = array.ravel()[0] if isinstance(sample, dask_array_type): sample = sample.compute() diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 7ccfeae2219..ee47f3593c4 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -7,8 +7,18 @@ from collections import Counter, OrderedDict from distutils.version import LooseVersion from typing import ( - AbstractSet, Any, Callable, Iterable, List, Mapping, Optional, Sequence, - Tuple, Union, TYPE_CHECKING) + AbstractSet, + Any, + Callable, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, + TYPE_CHECKING, +) import numpy as np @@ -23,9 +33,9 @@ from .dataset import Dataset _DEFAULT_FROZEN_SET = frozenset() # type: frozenset -_NO_FILL_VALUE = utils.ReprObject('') -_DEFAULT_NAME = utils.ReprObject('') -_JOINS_WITHOUT_FILL_VALUES = frozenset({'inner', 'exact'}) +_NO_FILL_VALUE = utils.ReprObject("") +_DEFAULT_NAME = utils.ReprObject("") +_JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"}) class _UFuncSignature: @@ -52,21 +62,22 @@ def __init__(self, input_core_dims, output_core_dims=((),)): def all_input_core_dims(self): if self._all_input_core_dims is None: self._all_input_core_dims = frozenset( - dim for dims in self.input_core_dims for dim in dims) + dim for dims in self.input_core_dims for dim in dims + ) return self._all_input_core_dims @property def all_output_core_dims(self): if self._all_output_core_dims is None: self._all_output_core_dims = frozenset( - dim for dims in self.output_core_dims for dim in dims) + dim for dims in self.output_core_dims for dim in dims + ) return self._all_output_core_dims @property def all_core_dims(self): if self._all_core_dims is None: - self._all_core_dims = (self.all_input_core_dims | - self.all_output_core_dims) + self._all_core_dims = self.all_input_core_dims | self.all_output_core_dims return self._all_core_dims @property @@ -79,8 +90,10 @@ def num_outputs(self): def __eq__(self, other): try: - return (self.input_core_dims == other.input_core_dims and - self.output_core_dims == other.output_core_dims) + return ( + self.input_core_dims == other.input_core_dims + and self.output_core_dims == other.output_core_dims + ) except AttributeError: return False @@ -88,17 +101,16 @@ def __ne__(self, other): return not self == other def __repr__(self): - return ('%s(%r, %r)' - % (type(self).__name__, - list(self.input_core_dims), - list(self.output_core_dims))) + return "%s(%r, %r)" % ( + type(self).__name__, + list(self.input_core_dims), + list(self.output_core_dims), + 
) def __str__(self): - lhs = ','.join('({})'.format(','.join(dims)) - for dims in self.input_core_dims) - rhs = ','.join('({})'.format(','.join(dims)) - for dims in self.output_core_dims) - return '{}->{}'.format(lhs, rhs) + lhs = ",".join("({})".format(",".join(dims)) for dims in self.input_core_dims) + rhs = ",".join("({})".format(",".join(dims)) for dims in self.output_core_dims) + return "{}->{}".format(lhs, rhs) def to_gufunc_string(self): """Create an equivalent signature string for a NumPy gufunc. @@ -108,10 +120,14 @@ def to_gufunc_string(self): """ all_dims = self.all_core_dims dims_map = dict(zip(sorted(all_dims), range(len(all_dims)))) - input_core_dims = [['dim%d' % dims_map[dim] for dim in core_dims] - for core_dims in self.input_core_dims] - output_core_dims = [['dim%d' % dims_map[dim] for dim in core_dims] - for core_dims in self.output_core_dims] + input_core_dims = [ + ["dim%d" % dims_map[dim] for dim in core_dims] + for core_dims in self.input_core_dims + ] + output_core_dims = [ + ["dim%d" % dims_map[dim] for dim in core_dims] + for core_dims in self.output_core_dims + ] alt_signature = type(self)(input_core_dims, output_core_dims) return str(alt_signature) @@ -119,7 +135,7 @@ def to_gufunc_string(self): def result_name(objects: list) -> Any: # use the same naming heuristics as pandas: # https://github.com/blaze/blaze/issues/458#issuecomment-51936356 - names = {getattr(obj, 'name', _DEFAULT_NAME) for obj in objects} + names = {getattr(obj, "name", _DEFAULT_NAME) for obj in objects} names.discard(_DEFAULT_NAME) if len(names) == 1: name, = names @@ -136,16 +152,14 @@ def _get_coord_variables(args): except AttributeError: pass # skip this argument else: - coord_vars = getattr(coords, 'variables', coords) + coord_vars = getattr(coords, "variables", coords) input_coords.append(coord_vars) return input_coords def build_output_coords( - args: list, - signature: _UFuncSignature, - exclude_dims: AbstractSet = frozenset(), -) -> 'List[OrderedDict[Any, Variable]]': + args: list, signature: _UFuncSignature, exclude_dims: AbstractSet = frozenset() +) -> "List[OrderedDict[Any, Variable]]": """Build output coordinates for an operation. Parameters @@ -166,9 +180,12 @@ def build_output_coords( input_coords = _get_coord_variables(args) if exclude_dims: - input_coords = [OrderedDict((k, v) for k, v in coord_vars.items() - if exclude_dims.isdisjoint(v.dims)) - for coord_vars in input_coords] + input_coords = [ + OrderedDict( + (k, v) for k, v in coord_vars.items() if exclude_dims.isdisjoint(v.dims) + ) + for coord_vars in input_coords + ] if len(input_coords) == 1: # we can skip the expensive merge @@ -181,8 +198,9 @@ def build_output_coords( for output_dims in signature.output_core_dims: dropped_dims = signature.all_input_core_dims - set(output_dims) if dropped_dims: - filtered = OrderedDict((k, v) for k, v in merged.items() - if dropped_dims.isdisjoint(v.dims)) + filtered = OrderedDict( + (k, v) for k, v in merged.items() if dropped_dims.isdisjoint(v.dims) + ) else: filtered = merged output_coords.append(filtered) @@ -191,12 +209,7 @@ def build_output_coords( def apply_dataarray_vfunc( - func, - *args, - signature, - join='inner', - exclude_dims=frozenset(), - keep_attrs=False + func, *args, signature, join="inner", exclude_dims=frozenset(), keep_attrs=False ): """Apply a variable level function over DataArray, Variable and/or ndarray objects. 
@@ -204,21 +217,24 @@ def apply_dataarray_vfunc( from .dataarray import DataArray if len(args) > 1: - args = deep_align(args, join=join, copy=False, exclude=exclude_dims, - raise_on_invalid=False) + args = deep_align( + args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False + ) - if keep_attrs and hasattr(args[0], 'name'): + if keep_attrs and hasattr(args[0], "name"): name = args[0].name else: name = result_name(args) result_coords = build_output_coords(args, signature, exclude_dims) - data_vars = [getattr(a, 'variable', a) for a in args] + data_vars = [getattr(a, "variable", a) for a in args] result_var = func(*data_vars) if signature.num_outputs > 1: - out = tuple(DataArray(variable, coords, name=name, fastpath=True) - for variable, coords in zip(result_var, result_coords)) + out = tuple( + DataArray(variable, coords, name=name, fastpath=True) + for variable, coords in zip(result_var, result_coords) + ) else: coords, = result_coords out = DataArray(result_var, coords, name=name, fastpath=True) @@ -246,38 +262,36 @@ def assert_and_return_exact_match(all_keys): for keys in all_keys[1:]: if keys != first_keys: raise ValueError( - 'exact match required for all data variable names, ' - 'but %r != %r' % (keys, first_keys)) + "exact match required for all data variable names, " + "but %r != %r" % (keys, first_keys) + ) return first_keys _JOINERS = { - 'inner': ordered_set_intersection, - 'outer': ordered_set_union, - 'left': operator.itemgetter(0), - 'right': operator.itemgetter(-1), - 'exact': assert_and_return_exact_match, + "inner": ordered_set_intersection, + "outer": ordered_set_union, + "left": operator.itemgetter(0), + "right": operator.itemgetter(-1), + "exact": assert_and_return_exact_match, } def join_dict_keys( - objects: Iterable[Union[Mapping, Any]], how: str = 'inner', + objects: Iterable[Union[Mapping, Any]], how: str = "inner" ) -> Iterable: joiner = _JOINERS[how] - all_keys = [obj.keys() for obj in objects if hasattr(obj, 'keys')] + all_keys = [obj.keys() for obj in objects if hasattr(obj, "keys")] return joiner(all_keys) def collect_dict_values( - objects: Iterable[Union[Mapping, Any]], - keys: Iterable, - fill_value: object = None, + objects: Iterable[Union[Mapping, Any]], keys: Iterable, fill_value: object = None ) -> List[list]: - return [[obj.get(key, fill_value) - if is_dict_like(obj) - else obj - for obj in objects] - for key in keys] + return [ + [obj.get(key, fill_value) if is_dict_like(obj) else obj for obj in objects] + for key in keys + ] def _as_variables_or_variable(arg): @@ -291,9 +305,8 @@ def _as_variables_or_variable(arg): def _unpack_dict_tuples( - result_vars: Mapping[Any, Tuple[Variable]], - num_outputs: int, -) -> 'Tuple[OrderedDict[Any, Variable], ...]': + result_vars: Mapping[Any, Tuple[Variable]], num_outputs: int +) -> "Tuple[OrderedDict[Any, Variable], ...]": out = tuple(OrderedDict() for _ in range(num_outputs)) # type: ignore for name, values in result_vars.items(): for value, results_dict in zip(values, out): @@ -302,7 +315,7 @@ def _unpack_dict_tuples( def apply_dict_of_variables_vfunc( - func, *args, signature, join='inner', fill_value=None + func, *args, signature, join="inner", fill_value=None ): """Apply a variable level function over dicts of DataArray, DataArray, Variable and ndarray objects. 
@@ -322,14 +335,14 @@ def apply_dict_of_variables_vfunc( def _fast_dataset( - variables: 'OrderedDict[Any, Variable]', - coord_variables: Mapping[Any, Variable], -) -> 'Dataset': + variables: "OrderedDict[Any, Variable]", coord_variables: Mapping[Any, Variable] +) -> "Dataset": """Create a dataset as quickly as possible. Beware: the `variables` OrderedDict is modified INPLACE. """ from .dataset import Dataset + variables.update(coord_variables) coord_names = set(coord_variables) return Dataset._from_vars_and_coord_names(variables, coord_names) @@ -339,8 +352,8 @@ def apply_dataset_vfunc( func, *args, signature, - join='inner', - dataset_join='exact', + join="inner", + dataset_join="exact", fill_value=_NO_FILL_VALUE, exclude_dims=frozenset(), keep_attrs=False @@ -349,28 +362,30 @@ def apply_dataset_vfunc( DataArray, Variable and/or ndarray objects. """ from .dataset import Dataset + first_obj = args[0] # we'll copy attrs from this in case keep_attrs=True - if (dataset_join not in _JOINS_WITHOUT_FILL_VALUES and - fill_value is _NO_FILL_VALUE): - raise TypeError('to apply an operation to datasets with different ' - 'data variables with apply_ufunc, you must supply the ' - 'dataset_fill_value argument.') + if dataset_join not in _JOINS_WITHOUT_FILL_VALUES and fill_value is _NO_FILL_VALUE: + raise TypeError( + "to apply an operation to datasets with different " + "data variables with apply_ufunc, you must supply the " + "dataset_fill_value argument." + ) if len(args) > 1: - args = deep_align(args, join=join, copy=False, exclude=exclude_dims, - raise_on_invalid=False) + args = deep_align( + args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False + ) list_of_coords = build_output_coords(args, signature, exclude_dims) - args = [getattr(arg, 'data_vars', arg) for arg in args] + args = [getattr(arg, "data_vars", arg) for arg in args] result_vars = apply_dict_of_variables_vfunc( - func, *args, signature=signature, join=dataset_join, - fill_value=fill_value) + func, *args, signature=signature, join=dataset_join, fill_value=fill_value + ) if signature.num_outputs > 1: - out = tuple(_fast_dataset(*args) - for args in zip(result_vars, list_of_coords)) + out = tuple(_fast_dataset(*args) for args in zip(result_vars, list_of_coords)) else: coord_vars, = list_of_coords out = _fast_dataset(result_vars, coord_vars) @@ -406,12 +421,14 @@ def apply_groupby_func(func, *args): from .variable import Variable groupbys = [arg for arg in args if isinstance(arg, GroupBy)] - assert groupbys, 'must have at least one groupby to iterate over' + assert groupbys, "must have at least one groupby to iterate over" first_groupby = groupbys[0] if any(not first_groupby._group.equals(gb._group) for gb in groupbys[1:]): - raise ValueError('apply_ufunc can only perform operations over ' - 'multiple GroupBy objets at once if they are all ' - 'grouped the same way') + raise ValueError( + "apply_ufunc can only perform operations over " + "multiple GroupBy objets at once if they are all " + "grouped the same way" + ) grouped_dim = first_groupby._group.name unique_values = first_groupby._unique_coord.values @@ -420,12 +437,13 @@ def apply_groupby_func(func, *args): for arg in args: if isinstance(arg, GroupBy): iterator = (value for _, value in arg) - elif hasattr(arg, 'dims') and grouped_dim in arg.dims: + elif hasattr(arg, "dims") and grouped_dim in arg.dims: if isinstance(arg, Variable): raise ValueError( - 'groupby operations cannot be performed with ' - 'xarray.Variable objects that share a dimension with ' - 'the 
grouped dimension') + "groupby operations cannot be performed with " + "xarray.Variable objects that share a dimension with " + "the grouped dimension" + ) iterator = _iter_over_selections(arg, grouped_dim, unique_values) else: iterator = itertools.repeat(arg) @@ -442,25 +460,27 @@ def apply_groupby_func(func, *args): def unified_dim_sizes( - variables: Iterable[Variable], - exclude_dims: AbstractSet = frozenset() -) -> 'OrderedDict[Any, int]': + variables: Iterable[Variable], exclude_dims: AbstractSet = frozenset() +) -> "OrderedDict[Any, int]": dim_sizes = OrderedDict() # type: OrderedDict[Any, int] for var in variables: if len(set(var.dims)) < len(var.dims): - raise ValueError('broadcasting cannot handle duplicate ' - 'dimensions on a variable: %r' % list(var.dims)) + raise ValueError( + "broadcasting cannot handle duplicate " + "dimensions on a variable: %r" % list(var.dims) + ) for dim, size in zip(var.dims, var.shape): if dim not in exclude_dims: if dim not in dim_sizes: dim_sizes[dim] = size elif dim_sizes[dim] != size: - raise ValueError('operands cannot be broadcast together ' - 'with mismatched lengths for dimension ' - '%r: %s vs %s' - % (dim, dim_sizes[dim], size)) + raise ValueError( + "operands cannot be broadcast together " + "with mismatched lengths for dimension " + "%r: %s vs %s" % (dim, dim_sizes[dim], size) + ) return dim_sizes @@ -482,17 +502,19 @@ def broadcast_compat_data(variable, broadcast_dims, core_dims): missing_core_dims = [d for d in core_dims if d not in set_old_dims] if missing_core_dims: raise ValueError( - 'operand to apply_ufunc has required core dimensions %r, but ' - 'some of these are missing on the input variable: %r' - % (list(core_dims), missing_core_dims)) + "operand to apply_ufunc has required core dimensions %r, but " + "some of these are missing on the input variable: %r" + % (list(core_dims), missing_core_dims) + ) set_new_dims = set(new_dims) unexpected_dims = [d for d in old_dims if d not in set_new_dims] if unexpected_dims: - raise ValueError('operand to apply_ufunc encountered unexpected ' - 'dimensions %r on an input variable: these are core ' - 'dimensions on other input or output variables' - % unexpected_dims) + raise ValueError( + "operand to apply_ufunc encountered unexpected " + "dimensions %r on an input variable: these are core " + "dimensions on other input or output variables" % unexpected_dims + ) # for consistency with numpy, keep broadcast dimensions to the left old_broadcast_dims = tuple(d for d in broadcast_dims if d in set_old_dims) @@ -520,7 +542,7 @@ def apply_variable_ufunc( *args, signature, exclude_dims=frozenset(), - dask='forbidden', + dask="forbidden", output_dtypes=None, output_sizes=None, keep_attrs=False @@ -529,67 +551,89 @@ def apply_variable_ufunc( """ from .variable import Variable, as_compatible_data - dim_sizes = unified_dim_sizes((a for a in args if hasattr(a, 'dims')), - exclude_dims=exclude_dims) - broadcast_dims = tuple(dim for dim in dim_sizes - if dim not in signature.all_core_dims) + dim_sizes = unified_dim_sizes( + (a for a in args if hasattr(a, "dims")), exclude_dims=exclude_dims + ) + broadcast_dims = tuple( + dim for dim in dim_sizes if dim not in signature.all_core_dims + ) output_dims = [broadcast_dims + out for out in signature.output_core_dims] - input_data = [broadcast_compat_data(arg, broadcast_dims, core_dims) - if isinstance(arg, Variable) - else arg - for arg, core_dims in zip(args, signature.input_core_dims)] + input_data = [ + broadcast_compat_data(arg, broadcast_dims, core_dims) + if 
isinstance(arg, Variable) + else arg + for arg, core_dims in zip(args, signature.input_core_dims) + ] if any(isinstance(array, dask_array_type) for array in input_data): - if dask == 'forbidden': - raise ValueError('apply_ufunc encountered a dask array on an ' - 'argument, but handling for dask arrays has not ' - 'been enabled. Either set the ``dask`` argument ' - 'or load your data into memory first with ' - '``.load()`` or ``.compute()``') - elif dask == 'parallelized': - input_dims = [broadcast_dims + dims - for dims in signature.input_core_dims] + if dask == "forbidden": + raise ValueError( + "apply_ufunc encountered a dask array on an " + "argument, but handling for dask arrays has not " + "been enabled. Either set the ``dask`` argument " + "or load your data into memory first with " + "``.load()`` or ``.compute()``" + ) + elif dask == "parallelized": + input_dims = [broadcast_dims + dims for dims in signature.input_core_dims] numpy_func = func def func(*arrays): return _apply_blockwise( - numpy_func, arrays, input_dims, output_dims, - signature, output_dtypes, output_sizes) - elif dask == 'allowed': + numpy_func, + arrays, + input_dims, + output_dims, + signature, + output_dtypes, + output_sizes, + ) + + elif dask == "allowed": pass else: - raise ValueError('unknown setting for dask array handling in ' - 'apply_ufunc: {}'.format(dask)) + raise ValueError( + "unknown setting for dask array handling in " + "apply_ufunc: {}".format(dask) + ) result_data = func(*input_data) if signature.num_outputs == 1: result_data = (result_data,) - elif (not isinstance(result_data, tuple) or - len(result_data) != signature.num_outputs): - raise ValueError('applied function does not have the number of ' - 'outputs specified in the ufunc signature. ' - 'Result is not a tuple of {} elements: {!r}' - .format(signature.num_outputs, result_data)) + elif ( + not isinstance(result_data, tuple) or len(result_data) != signature.num_outputs + ): + raise ValueError( + "applied function does not have the number of " + "outputs specified in the ufunc signature. " + "Result is not a tuple of {} elements: {!r}".format( + signature.num_outputs, result_data + ) + ) output = [] for dims, data in zip(output_dims, result_data): data = as_compatible_data(data) if data.ndim != len(dims): raise ValueError( - 'applied function returned data with unexpected ' - 'number of dimensions: {} vs {}, for dimensions {}' - .format(data.ndim, len(dims), dims)) + "applied function returned data with unexpected " + "number of dimensions: {} vs {}, for dimensions {}".format( + data.ndim, len(dims), dims + ) + ) var = Variable(dims, data, fastpath=True) for dim, new_size in var.sizes.items(): if dim in dim_sizes and new_size != dim_sizes[dim]: raise ValueError( - 'size of dimension {!r} on inputs was unexpectedly ' - 'changed by applied function from {} to {}. Only ' - 'dimensions specified in ``exclude_dims`` with ' - 'xarray.apply_ufunc are allowed to change size.' - .format(dim, dim_sizes[dim], new_size)) + "size of dimension {!r} on inputs was unexpectedly " + "changed by applied function from {} to {}. 
Only " + "dimensions specified in ``exclude_dims`` with " + "xarray.apply_ufunc are allowed to change size.".format( + dim, dim_sizes[dim], new_size + ) + ) if keep_attrs and isinstance(args[0], Variable): var.attrs.update(args[0].attrs) @@ -601,25 +645,35 @@ def func(*arrays): return tuple(output) -def _apply_blockwise(func, args, input_dims, output_dims, signature, - output_dtypes, output_sizes=None): +def _apply_blockwise( + func, args, input_dims, output_dims, signature, output_dtypes, output_sizes=None +): import dask.array as da from .dask_array_compat import blockwise if signature.num_outputs > 1: - raise NotImplementedError('multiple outputs from apply_ufunc not yet ' - "supported with dask='parallelized'") + raise NotImplementedError( + "multiple outputs from apply_ufunc not yet " + "supported with dask='parallelized'" + ) if output_dtypes is None: - raise ValueError('output dtypes (output_dtypes) must be supplied to ' - "apply_func when using dask='parallelized'") + raise ValueError( + "output dtypes (output_dtypes) must be supplied to " + "apply_func when using dask='parallelized'" + ) if not isinstance(output_dtypes, list): - raise TypeError('output_dtypes must be a list of objects coercible to ' - 'numpy dtypes, got {}'.format(output_dtypes)) + raise TypeError( + "output_dtypes must be a list of objects coercible to " + "numpy dtypes, got {}".format(output_dtypes) + ) if len(output_dtypes) != signature.num_outputs: - raise ValueError('apply_ufunc arguments output_dtypes and ' - 'output_core_dims must have the same length: {} vs {}' - .format(len(output_dtypes), signature.num_outputs)) + raise ValueError( + "apply_ufunc arguments output_dtypes and " + "output_core_dims must have the same length: {} vs {}".format( + len(output_dtypes), signature.num_outputs + ) + ) (dtype,) = output_dtypes if output_sizes is None: @@ -627,56 +681,67 @@ def _apply_blockwise(func, args, input_dims, output_dims, signature, new_dims = signature.all_output_core_dims - signature.all_input_core_dims if any(dim not in output_sizes for dim in new_dims): - raise ValueError("when using dask='parallelized' with apply_ufunc, " - 'output core dimensions not found on inputs must ' - 'have explicitly set sizes with ``output_sizes``: {}' - .format(new_dims)) + raise ValueError( + "when using dask='parallelized' with apply_ufunc, " + "output core dimensions not found on inputs must " + "have explicitly set sizes with ``output_sizes``: {}".format(new_dims) + ) - for n, (data, core_dims) in enumerate( - zip(args, signature.input_core_dims)): + for n, (data, core_dims) in enumerate(zip(args, signature.input_core_dims)): if isinstance(data, dask_array_type): # core dimensions cannot span multiple chunks for axis, dim in enumerate(core_dims, start=-len(core_dims)): if len(data.chunks[axis]) != 1: raise ValueError( - 'dimension {!r} on {}th function argument to ' + "dimension {!r} on {}th function argument to " "apply_ufunc with dask='parallelized' consists of " - 'multiple chunks, but is also a core dimension. To ' - 'fix, rechunk into a single dask array chunk along ' - 'this dimension, i.e., ``.chunk({})``, but beware ' - 'that this may significantly increase memory usage.' - .format(dim, n, {dim: -1})) + "multiple chunks, but is also a core dimension. 
To " + "fix, rechunk into a single dask array chunk along " + "this dimension, i.e., ``.chunk({})``, but beware " + "that this may significantly increase memory usage.".format( + dim, n, {dim: -1} + ) + ) (out_ind,) = output_dims blockwise_args = [] for arg, dims in zip(args, input_dims): # skip leading dimensions that are implicitly added by broadcasting - ndim = getattr(arg, 'ndim', 0) + ndim = getattr(arg, "ndim", 0) trimmed_dims = dims[-ndim:] if ndim else () blockwise_args.extend([arg, trimmed_dims]) - return blockwise(func, out_ind, *blockwise_args, dtype=dtype, - concatenate=True, new_axes=output_sizes) + return blockwise( + func, + out_ind, + *blockwise_args, + dtype=dtype, + concatenate=True, + new_axes=output_sizes + ) -def apply_array_ufunc(func, *args, dask='forbidden'): +def apply_array_ufunc(func, *args, dask="forbidden"): """Apply a ndarray level function over ndarray objects.""" if any(isinstance(arg, dask_array_type) for arg in args): - if dask == 'forbidden': - raise ValueError('apply_ufunc encountered a dask array on an ' - 'argument, but handling for dask arrays has not ' - 'been enabled. Either set the ``dask`` argument ' - 'or load your data into memory first with ' - '``.load()`` or ``.compute()``') - elif dask == 'parallelized': - raise ValueError("cannot use dask='parallelized' for apply_ufunc " - 'unless at least one input is an xarray object') - elif dask == 'allowed': + if dask == "forbidden": + raise ValueError( + "apply_ufunc encountered a dask array on an " + "argument, but handling for dask arrays has not " + "been enabled. Either set the ``dask`` argument " + "or load your data into memory first with " + "``.load()`` or ``.compute()``" + ) + elif dask == "parallelized": + raise ValueError( + "cannot use dask='parallelized' for apply_ufunc " + "unless at least one input is an xarray object" + ) + elif dask == "allowed": pass else: - raise ValueError('unknown setting for dask array handling: {}' - .format(dask)) + raise ValueError("unknown setting for dask array handling: {}".format(dask)) return func(*args) @@ -687,12 +752,12 @@ def apply_ufunc( output_core_dims: Optional[Sequence[Sequence]] = ((),), exclude_dims: AbstractSet = frozenset(), vectorize: bool = False, - join: str = 'exact', - dataset_join: str = 'exact', + join: str = "exact", + dataset_join: str = "exact", dataset_fill_value: object = _NO_FILL_VALUE, keep_attrs: bool = False, kwargs: Mapping = None, - dask: str = 'forbidden', + dask: str = "forbidden", output_dtypes: Sequence = None, output_sizes: Mapping[Any, int] = None ) -> Any: @@ -904,9 +969,10 @@ def earth_mover_distance(first_samples, input_core_dims = ((),) * (len(args)) elif len(input_core_dims) != len(args): raise ValueError( - 'input_core_dims must be None or a tuple with the length same to ' - 'the number of arguments. Given input_core_dims: {}, ' - 'number of args: {}.'.format(input_core_dims, len(args))) + "input_core_dims must be None or a tuple with the length same to " + "the number of arguments. 
Given input_core_dims: {}, " + "number of args: {}.".format(input_core_dims, len(args)) + ) if kwargs is None: kwargs = {} @@ -914,8 +980,10 @@ def earth_mover_distance(first_samples, signature = _UFuncSignature(input_core_dims, output_core_dims) if exclude_dims and not exclude_dims <= signature.all_core_dims: - raise ValueError('each dimension in `exclude_dims` must also be a ' - 'core dimension in the function signature') + raise ValueError( + "each dimension in `exclude_dims` must also be a " + "core dimension in the function signature" + ) if kwargs: func = functools.partial(func, **kwargs) @@ -923,50 +991,63 @@ def earth_mover_distance(first_samples, if vectorize: if signature.all_core_dims: # we need the signature argument - if LooseVersion(np.__version__) < '1.12': # pragma: no cover + if LooseVersion(np.__version__) < "1.12": # pragma: no cover raise NotImplementedError( - 'numpy 1.12 or newer required when using vectorize=True ' - 'in xarray.apply_ufunc with non-scalar output core ' - 'dimensions.') - func = np.vectorize(func, - otypes=output_dtypes, - signature=signature.to_gufunc_string()) + "numpy 1.12 or newer required when using vectorize=True " + "in xarray.apply_ufunc with non-scalar output core " + "dimensions." + ) + func = np.vectorize( + func, otypes=output_dtypes, signature=signature.to_gufunc_string() + ) else: func = np.vectorize(func, otypes=output_dtypes) - variables_vfunc = functools.partial(apply_variable_ufunc, func, - signature=signature, - exclude_dims=exclude_dims, - keep_attrs=keep_attrs, - dask=dask, - output_dtypes=output_dtypes, - output_sizes=output_sizes) + variables_vfunc = functools.partial( + apply_variable_ufunc, + func, + signature=signature, + exclude_dims=exclude_dims, + keep_attrs=keep_attrs, + dask=dask, + output_dtypes=output_dtypes, + output_sizes=output_sizes, + ) if any(isinstance(a, GroupBy) for a in args): - this_apply = functools.partial(apply_ufunc, func, - input_core_dims=input_core_dims, - output_core_dims=output_core_dims, - exclude_dims=exclude_dims, - join=join, - dataset_join=dataset_join, - dataset_fill_value=dataset_fill_value, - keep_attrs=keep_attrs, - dask=dask) + this_apply = functools.partial( + apply_ufunc, + func, + input_core_dims=input_core_dims, + output_core_dims=output_core_dims, + exclude_dims=exclude_dims, + join=join, + dataset_join=dataset_join, + dataset_fill_value=dataset_fill_value, + keep_attrs=keep_attrs, + dask=dask, + ) return apply_groupby_func(this_apply, *args) elif any(is_dict_like(a) for a in args): - return apply_dataset_vfunc(variables_vfunc, *args, - signature=signature, - join=join, - exclude_dims=exclude_dims, - dataset_join=dataset_join, - fill_value=dataset_fill_value, - keep_attrs=keep_attrs) + return apply_dataset_vfunc( + variables_vfunc, + *args, + signature=signature, + join=join, + exclude_dims=exclude_dims, + dataset_join=dataset_join, + fill_value=dataset_fill_value, + keep_attrs=keep_attrs + ) elif any(isinstance(a, DataArray) for a in args): - return apply_dataarray_vfunc(variables_vfunc, *args, - signature=signature, - join=join, - exclude_dims=exclude_dims, - keep_attrs=keep_attrs) + return apply_dataarray_vfunc( + variables_vfunc, + *args, + signature=signature, + join=join, + exclude_dims=exclude_dims, + keep_attrs=keep_attrs + ) elif any(isinstance(a, Variable) for a in args): return variables_vfunc(*args) else: @@ -1011,21 +1092,23 @@ def dot(*arrays, dims=None, **kwargs): from .variable import Variable if any(not isinstance(arr, (Variable, DataArray)) for arr in arrays): - 
raise TypeError('Only xr.DataArray and xr.Variable are supported.' - 'Given {}.'.format([type(arr) for arr in arrays])) + raise TypeError( + "Only xr.DataArray and xr.Variable are supported." + "Given {}.".format([type(arr) for arr in arrays]) + ) if len(arrays) == 0: - raise TypeError('At least one array should be given.') + raise TypeError("At least one array should be given.") if isinstance(dims, str): - dims = (dims, ) + dims = (dims,) common_dims = set.intersection(*[set(arr.dims) for arr in arrays]) all_dims = [] for arr in arrays: all_dims += [d for d in arr.dims if d not in all_dims] - einsum_axes = 'abcdefghijklmnopqrstuvwxyz' + einsum_axes = "abcdefghijklmnopqrstuvwxyz" dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} if dims is None: @@ -1038,40 +1121,49 @@ def dot(*arrays, dims=None, **kwargs): dims = tuple(dims) # make dims a tuple # dimensions to be parallelized - broadcast_dims = tuple(d for d in all_dims - if d in common_dims and d not in dims) - input_core_dims = [[d for d in arr.dims if d not in broadcast_dims] - for arr in arrays] - output_core_dims = [tuple(d for d in all_dims if d not in - dims + broadcast_dims)] + broadcast_dims = tuple(d for d in all_dims if d in common_dims and d not in dims) + input_core_dims = [ + [d for d in arr.dims if d not in broadcast_dims] for arr in arrays + ] + output_core_dims = [tuple(d for d in all_dims if d not in dims + broadcast_dims)] # older dask than 0.17.4, we use tensordot if possible. if isinstance(arr.data, dask_array_type): import dask - if LooseVersion(dask.__version__) < LooseVersion('0.17.4'): + + if LooseVersion(dask.__version__) < LooseVersion("0.17.4"): if len(broadcast_dims) == 0 and len(arrays) == 2: - axes = [[arr.get_axis_num(d) for d in arr.dims if d in dims] - for arr in arrays] - return apply_ufunc(duck_array_ops.tensordot, *arrays, - dask='allowed', - input_core_dims=input_core_dims, - output_core_dims=output_core_dims, - kwargs={'axes': axes}) + axes = [ + [arr.get_axis_num(d) for d in arr.dims if d in dims] + for arr in arrays + ] + return apply_ufunc( + duck_array_ops.tensordot, + *arrays, + dask="allowed", + input_core_dims=input_core_dims, + output_core_dims=output_core_dims, + kwargs={"axes": axes} + ) # construct einsum subscripts, such as '...abc,...ab->...c' # Note: input_core_dims are always moved to the last position - subscripts_list = ['...' + ''.join([dim_map[d] for d in ds]) for ds - in input_core_dims] - subscripts = ','.join(subscripts_list) - subscripts += '->...' + ''.join([dim_map[d] for d in output_core_dims[0]]) + subscripts_list = [ + "..." + "".join([dim_map[d] for d in ds]) for ds in input_core_dims + ] + subscripts = ",".join(subscripts_list) + subscripts += "->..." + "".join([dim_map[d] for d in output_core_dims[0]]) # subscripts should be passed to np.einsum as arg, not as kwargs. We need # to construct a partial function for apply_ufunc to work. 
func = functools.partial(duck_array_ops.einsum, subscripts, **kwargs) - result = apply_ufunc(func, *arrays, - input_core_dims=input_core_dims, - output_core_dims=output_core_dims, - dask='allowed') + result = apply_ufunc( + func, + *arrays, + input_core_dims=input_core_dims, + output_core_dims=output_core_dims, + dask="allowed" + ) return result.transpose(*[d for d in all_dims if d in result.dims]) @@ -1110,8 +1202,12 @@ def where(cond, x, y): Dataset.where, DataArray.where : equivalent methods """ # alignment for three arguments is complicated, so don't support it yet - return apply_ufunc(duck_array_ops.where, - cond, x, y, - join='exact', - dataset_join='exact', - dask='allowed') + return apply_ufunc( + duck_array_ops.where, + cond, + x, + y, + join="exact", + dataset_join="exact", + dask="allowed", + ) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index a6570525cc5..19609308e78 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -9,9 +9,19 @@ from .variable import concat as concat_vars -def concat(objs, dim=None, data_vars='all', coords='different', - compat='equals', positions=None, indexers=None, mode=None, - concat_over=None, fill_value=dtypes.NA, join='outer'): +def concat( + objs, + dim=None, + data_vars="all", + coords="different", + compat="equals", + positions=None, + indexers=None, + mode=None, + concat_over=None, + fill_value=dtypes.NA, + join="outer", +): """Concatenate xarray objects along a new or existing dimension. Parameters @@ -96,37 +106,49 @@ def concat(objs, dim=None, data_vars='all', coords='different', try: first_obj, objs = utils.peek_at(objs) except StopIteration: - raise ValueError('must supply at least one object to concatenate') + raise ValueError("must supply at least one object to concatenate") if dim is None: - warnings.warn('the `dim` argument to `concat` will be required ' - 'in a future version of xarray; for now, setting it to ' - "the old default of 'concat_dim'", - FutureWarning, stacklevel=2) - dim = 'concat_dims' + warnings.warn( + "the `dim` argument to `concat` will be required " + "in a future version of xarray; for now, setting it to " + "the old default of 'concat_dim'", + FutureWarning, + stacklevel=2, + ) + dim = "concat_dims" if indexers is not None: # pragma: no cover - warnings.warn('indexers has been renamed to positions; the alias ' - 'will be removed in a future version of xarray', - FutureWarning, stacklevel=2) + warnings.warn( + "indexers has been renamed to positions; the alias " + "will be removed in a future version of xarray", + FutureWarning, + stacklevel=2, + ) positions = indexers if mode is not None: - raise ValueError('`mode` is no longer a valid argument to ' - 'xarray.concat; it has been split into the ' - '`data_vars` and `coords` arguments') + raise ValueError( + "`mode` is no longer a valid argument to " + "xarray.concat; it has been split into the " + "`data_vars` and `coords` arguments" + ) if concat_over is not None: - raise ValueError('`concat_over` is no longer a valid argument to ' - 'xarray.concat; it has been split into the ' - '`data_vars` and `coords` arguments') + raise ValueError( + "`concat_over` is no longer a valid argument to " + "xarray.concat; it has been split into the " + "`data_vars` and `coords` arguments" + ) if isinstance(first_obj, DataArray): f = _dataarray_concat elif isinstance(first_obj, Dataset): f = _dataset_concat else: - raise TypeError('can only concatenate xarray Dataset and DataArray ' - 'objects, got %s' % type(first_obj)) + raise TypeError( + "can only 
concatenate xarray Dataset and DataArray " + "objects, got %s" % type(first_obj) + ) return f(objs, dim, data_vars, coords, compat, positions, fill_value, join) @@ -140,9 +162,9 @@ def _calc_concat_dim_coord(dim): if isinstance(dim, str): coord = None elif not isinstance(dim, (DataArray, Variable)): - dim_name = getattr(dim, 'name', None) + dim_name = getattr(dim, "name", None) if dim_name is None: - dim_name = 'concat_dim' + dim_name = "concat_dim" coord = IndexVariable(dim_name, dim) dim = dim_name elif not isinstance(dim, DataArray): @@ -166,12 +188,11 @@ def _calc_concat_over(datasets, dim, data_vars, coords): if dim in datasets[0]: concat_over.add(dim) for ds in datasets: - concat_over.update(k for k, v in ds.variables.items() - if dim in v.dims) + concat_over.update(k for k, v in ds.variables.items() if dim in v.dims) def process_subset_opt(opt, subset): if isinstance(opt, str): - if opt == 'different': + if opt == "different": # all nonindexes that are not the same in each dataset for k in getattr(datasets[0], subset): if k not in concat_over: @@ -196,48 +217,60 @@ def process_subset_opt(opt, subset): else: equals[k] = True - elif opt == 'all': - concat_over.update(set(getattr(datasets[0], subset)) - - set(datasets[0].dims)) - elif opt == 'minimal': + elif opt == "all": + concat_over.update( + set(getattr(datasets[0], subset)) - set(datasets[0].dims) + ) + elif opt == "minimal": pass else: raise ValueError("unexpected value for %s: %s" % (subset, opt)) else: - invalid_vars = [k for k in opt - if k not in getattr(datasets[0], subset)] + invalid_vars = [k for k in opt if k not in getattr(datasets[0], subset)] if invalid_vars: - if subset == 'coords': + if subset == "coords": raise ValueError( - 'some variables in coords are not coordinates on ' - 'the first dataset: %s' % (invalid_vars,)) + "some variables in coords are not coordinates on " + "the first dataset: %s" % (invalid_vars,) + ) else: raise ValueError( - 'some variables in data_vars are not data variables ' - 'on the first dataset: %s' % (invalid_vars,)) + "some variables in data_vars are not data variables " + "on the first dataset: %s" % (invalid_vars,) + ) concat_over.update(opt) - process_subset_opt(data_vars, 'data_vars') - process_subset_opt(coords, 'coords') + process_subset_opt(data_vars, "data_vars") + process_subset_opt(coords, "coords") return concat_over, equals -def _dataset_concat(datasets, dim, data_vars, coords, compat, positions, - fill_value=dtypes.NA, join='outer'): +def _dataset_concat( + datasets, + dim, + data_vars, + coords, + compat, + positions, + fill_value=dtypes.NA, + join="outer", +): """ Concatenate a sequence of datasets along a new or existing dimension """ from .dataset import Dataset - if compat not in ['equals', 'identical']: - raise ValueError("compat=%r invalid: must be 'equals' " - "or 'identical'" % compat) + if compat not in ["equals", "identical"]: + raise ValueError( + "compat=%r invalid: must be 'equals' " "or 'identical'" % compat + ) dim, coord = _calc_concat_dim_coord(dim) # Make sure we're working on a copy (we'll be loading variables) datasets = [ds.copy() for ds in datasets] - datasets = align(*datasets, join=join, copy=False, exclude=[dim], - fill_value=fill_value) + datasets = align( + *datasets, join=join, copy=False, exclude=[dim], fill_value=fill_value + ) concat_over, equals = _calc_concat_over(datasets, dim, data_vars, coords) @@ -260,22 +293,22 @@ def insert_result_variable(k, v): # check that global attributes and non-concatenated variables are fixed # across all 
datasets for ds in datasets[1:]: - if (compat == 'identical' and - not utils.dict_equiv(ds.attrs, result_attrs)): - raise ValueError('dataset global attributes not equal') + if compat == "identical" and not utils.dict_equiv(ds.attrs, result_attrs): + raise ValueError("dataset global attributes not equal") for k, v in ds.variables.items(): if k not in result_vars and k not in concat_over: - raise ValueError('encountered unexpected variable %r' % k) + raise ValueError("encountered unexpected variable %r" % k) elif (k in result_coord_names) != (k in ds.coords): - raise ValueError('%r is a coordinate in some datasets but not ' - 'others' % k) + raise ValueError( + "%r is a coordinate in some datasets but not " "others" % k + ) elif k in result_vars and k != dim: # Don't use Variable.identical as it internally invokes # Variable.equals, and we may already know the answer - if compat == 'identical' and not utils.dict_equiv( - v.attrs, result_vars[k].attrs): - raise ValueError( - 'variable %s not identical across datasets' % k) + if compat == "identical" and not utils.dict_equiv( + v.attrs, result_vars[k].attrs + ): + raise ValueError("variable %s not identical across datasets" % k) # Proceed with equals() try: @@ -285,8 +318,7 @@ def insert_result_variable(k, v): result_vars[k].load() is_equal = v.equals(result_vars[k]) if not is_equal: - raise ValueError( - 'variable %s not equal across datasets' % k) + raise ValueError("variable %s not equal across datasets" % k) # we've already verified everything is consistent; now, calculate # shared dimension sizes so we can expand the necessary variables @@ -305,8 +337,9 @@ def ensure_common_dims(vars): common_dims = (dim,) + common_dims for var, dim_len in zip(vars, dim_lengths): if var.dims != common_dims: - common_shape = tuple(non_concat_dims.get(d, dim_len) - for d in common_dims) + common_shape = tuple( + non_concat_dims.get(d, dim_len) for d in common_dims + ) var = var.set_dims(common_dims, common_shape) yield var @@ -328,25 +361,42 @@ def ensure_common_dims(vars): return result -def _dataarray_concat(arrays, dim, data_vars, coords, compat, - positions, fill_value=dtypes.NA, join='outer'): +def _dataarray_concat( + arrays, + dim, + data_vars, + coords, + compat, + positions, + fill_value=dtypes.NA, + join="outer", +): arrays = list(arrays) - if data_vars != 'all': - raise ValueError('data_vars is not a valid argument when ' - 'concatenating DataArray objects') + if data_vars != "all": + raise ValueError( + "data_vars is not a valid argument when " "concatenating DataArray objects" + ) datasets = [] for n, arr in enumerate(arrays): if n == 0: name = arr.name elif name != arr.name: - if compat == 'identical': - raise ValueError('array names not identical') + if compat == "identical": + raise ValueError("array names not identical") else: arr = arr.rename(name) datasets.append(arr._to_temp_dataset()) - ds = _dataset_concat(datasets, dim, data_vars, coords, compat, - positions, fill_value=fill_value, join=join) + ds = _dataset_concat( + datasets, + dim, + data_vars, + coords, + compat, + positions, + fill_value=fill_value, + join=join, + ) return arrays[0]._from_temp_dataset(ds, name) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 83455b4f776..38b4540cc4e 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -7,7 +7,10 @@ from . 
import formatting, indexing from .merge import ( - expand_and_merge_variables, merge_coords, merge_coords_for_inplace_math) + expand_and_merge_variables, + merge_coords, + merge_coords_for_inplace_math, +) from .utils import Frozen, ReprObject, either_dict_or_kwargs from .variable import Variable @@ -17,7 +20,7 @@ # Used as the key corresponding to a DataArray's variable when converting # arbitrary DataArray objects to datasets -_THIS_ARRAY = ReprObject('') +_THIS_ARRAY = ReprObject("") class AbstractCoordinates(collections.abc.Mapping): @@ -76,11 +79,13 @@ def to_index(self, ordered_dims=None): if ordered_dims is None: ordered_dims = self.dims elif set(ordered_dims) != set(self.dims): - raise ValueError('ordered_dims must match dims, but does not: ' - '{} vs {}'.format(ordered_dims, self.dims)) + raise ValueError( + "ordered_dims must match dims, but does not: " + "{} vs {}".format(ordered_dims, self.dims) + ) if len(ordered_dims) == 0: - raise ValueError('no valid index for a 0-dimensional object') + raise ValueError("no valid index for a 0-dimensional object") elif len(ordered_dims) == 1: (dim,) = ordered_dims return self._data.get_index(dim) @@ -90,9 +95,10 @@ def to_index(self, ordered_dims=None): return pd.MultiIndex.from_product(indexes, names=names) def update(self, other): - other_vars = getattr(other, 'variables', other) - coords = merge_coords([self.variables, other_vars], - priority_arg=1, indexes=self.indexes) + other_vars = getattr(other, "variables", other) + coords = merge_coords( + [self.variables, other_vars], priority_arg=1, indexes=self.indexes + ) self._update_coords(coords) def _merge_raw(self, other): @@ -101,8 +107,7 @@ def _merge_raw(self, other): variables = OrderedDict(self.variables) else: # don't align because we already called xarray.align - variables = expand_and_merge_variables( - [self.variables, other.variables]) + variables = expand_and_merge_variables([self.variables, other.variables]) return variables @contextmanager @@ -114,9 +119,11 @@ def _merge_inplace(self, other): # don't include indexes in priority_vars, because we didn't align # first priority_vars = OrderedDict( - kv for kv in self.variables.items() if kv[0] not in self.dims) + kv for kv in self.variables.items() if kv[0] not in self.dims + ) variables = merge_coords_for_inplace_math( - [self.variables, other.variables], priority_vars=priority_vars) + [self.variables, other.variables], priority_vars=priority_vars + ) yield self._update_coords(variables) @@ -147,7 +154,7 @@ def merge(self, other): if other is None: return self.to_dataset() else: - other_vars = getattr(other, 'variables', other) + other_vars = getattr(other, "variables", other) coords = expand_and_merge_variables([self.variables, other_vars]) return Dataset._from_vars_and_coord_names(coords, set(coords)) @@ -169,9 +176,11 @@ def _names(self): @property def variables(self): - return Frozen(OrderedDict((k, v) - for k, v in self._data.variables.items() - if k in self._names)) + return Frozen( + OrderedDict( + (k, v) for k, v in self._data.variables.items() if k in self._names + ) + ) def __getitem__(self, key): if key in self._data.data_vars: @@ -209,8 +218,11 @@ def __delitem__(self, key): def _ipython_key_completions_(self): """Provide method for the key-autocompletions in IPython. 
""" - return [key for key in self._data._ipython_key_completions_() - if key not in self._data.data_vars] + return [ + key + for key in self._data._ipython_key_completions_() + if key not in self._data.data_vars + ] class DataArrayCoordinates(AbstractCoordinates): @@ -237,8 +249,9 @@ def _update_coords(self, coords): coords_plus_data[_THIS_ARRAY] = self._data.variable dims = calculate_dimensions(coords_plus_data) if not set(dims) <= set(self.dims): - raise ValueError('cannot add coordinates with new dimensions to ' - 'a DataArray') + raise ValueError( + "cannot add coordinates with new dimensions to " "a DataArray" + ) self._data._coords = coords self._data._indexes = None @@ -248,8 +261,11 @@ def variables(self): def _to_dataset(self, shallow_copy=True): from .dataset import Dataset - coords = OrderedDict((k, v.copy(deep=False) if shallow_copy else v) - for k, v in self._data._coords.items()) + + coords = OrderedDict( + (k, v.copy(deep=False) if shallow_copy else v) + for k, v in self._data._coords.items() + ) return Dataset._from_vars_and_coord_names(coords, set(coords)) def to_dataset(self): @@ -269,7 +285,8 @@ class LevelCoordinatesSource(Mapping[Hashable, Any]): Used for attribute style lookup with AttrAccessMixin. Not returned directly by any public methods. """ - def __init__(self, data_object: 'Union[DataArray, Dataset]'): + + def __init__(self, data_object: "Union[DataArray, Dataset]"): self._data = data_object def __getitem__(self, key): @@ -295,13 +312,16 @@ def assert_coordinate_consistent(obj, coords): if k in coords and k in obj.coords: if not coords[k].equals(obj[k].variable): raise IndexError( - 'dimension coordinate {!r} conflicts between ' - 'indexed and indexing objects:\n{}\nvs.\n{}' - .format(k, obj[k], coords[k])) + "dimension coordinate {!r} conflicts between " + "indexed and indexing objects:\n{}\nvs.\n{}".format( + k, obj[k], coords[k] + ) + ) -def remap_label_indexers(obj, indexers=None, method=None, tolerance=None, - **indexers_kwargs): +def remap_label_indexers( + obj, indexers=None, method=None, tolerance=None, **indexers_kwargs +): """ Remap **indexers from obj.coords. If indexer is an instance of DataArray and it has coordinate, then this @@ -314,11 +334,13 @@ def remap_label_indexers(obj, indexers=None, method=None, tolerance=None, new_indexes: mapping of new dimensional-coordinate. 
""" from .dataarray import DataArray - indexers = either_dict_or_kwargs( - indexers, indexers_kwargs, 'remap_label_indexers') - v_indexers = {k: v.variable.data if isinstance(v, DataArray) else v - for k, v in indexers.items()} + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "remap_label_indexers") + + v_indexers = { + k: v.variable.data if isinstance(v, DataArray) else v + for k, v in indexers.items() + } pos_indexers, new_indexes = indexing.remap_label_indexers( obj, v_indexers, method=method, tolerance=tolerance @@ -330,8 +352,8 @@ def remap_label_indexers(obj, indexers=None, method=None, tolerance=None, elif isinstance(v, DataArray): # drop coordinates found in indexers since .sel() already # ensures alignments - coords = OrderedDict((k, v) for k, v in v._coords.items() - if k not in indexers) - pos_indexers[k] = DataArray(pos_indexers[k], - coords=coords, dims=v.dims) + coords = OrderedDict( + (k, v) for k, v in v._coords.items() if k not in indexers + ) + pos_indexers[k] = DataArray(pos_indexers[k], coords=coords, dims=v.dims) return pos_indexers, new_indexes diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py index 838222139c7..5d4ff849b57 100644 --- a/xarray/core/dask_array_compat.py +++ b/xarray/core/dask_array_compat.py @@ -18,8 +18,7 @@ # Used under the terms of Dask's license, see licenses/DASK_LICENSE. def _isin_kernel(element, test_elements, assume_unique=False): - values = np.in1d(element.ravel(), test_elements, - assume_unique=assume_unique) + values = np.in1d(element.ravel(), test_elements, assume_unique=assume_unique) return values.reshape(element.shape + (1,) * test_elements.ndim) def isin(element, test_elements, assume_unique=False, invert=False): @@ -27,20 +26,24 @@ def isin(element, test_elements, assume_unique=False, invert=False): test_elements = da.asarray(test_elements) element_axes = tuple(range(element.ndim)) test_axes = tuple(i + element.ndim for i in range(test_elements.ndim)) - mapped = blockwise(_isin_kernel, element_axes + test_axes, - element, element_axes, - test_elements, test_axes, - adjust_chunks={axis: lambda _: 1 - for axis in test_axes}, - dtype=bool, - assume_unique=assume_unique) + mapped = blockwise( + _isin_kernel, + element_axes + test_axes, + element, + element_axes, + test_elements, + test_axes, + adjust_chunks={axis: lambda _: 1 for axis in test_axes}, + dtype=bool, + assume_unique=assume_unique, + ) result = mapped.any(axis=test_axes) if invert: result = ~result return result -if LooseVersion(dask_version) > LooseVersion('0.19.2'): +if LooseVersion(dask_version) > LooseVersion("0.19.2"): gradient = da.gradient else: # pragma: no cover @@ -64,8 +67,9 @@ def validate_axis(axis, ndim): if not isinstance(axis, Integral): raise TypeError("Axis value must be an integer, got %s" % axis) if axis < -ndim or axis >= ndim: - raise AxisError("Axis %d is out of bounds for array of dimension " - "%d" % (axis, ndim)) + raise AxisError( + "Axis %d is out of bounds for array of dimension " "%d" % (axis, ndim) + ) if axis < 0: axis += ndim return axis @@ -85,7 +89,7 @@ def _gradient_kernel(x, block_id, coord, axis, array_locs, grad_kwargs): """ block_loc = block_id[axis] if array_locs is not None: - coord = coord[array_locs[0][block_loc]:array_locs[1][block_loc]] + coord = coord[array_locs[0][block_loc] : array_locs[1][block_loc]] grad = np.gradient(x, coord, axis=axis, **grad_kwargs) return grad @@ -130,16 +134,18 @@ def gradient(f, *varargs, axis=None, **kwargs): for c in f.chunks[ax]: if np.min(c) < 
kwargs["edge_order"] + 1: raise ValueError( - 'Chunk size must be larger than edge_order + 1. ' - 'Minimum chunk for aixs {} is {}. Rechunk to ' - 'proceed.'.format(np.min(c), ax)) + "Chunk size must be larger than edge_order + 1. " + "Minimum chunk for aixs {} is {}. Rechunk to " + "proceed.".format(np.min(c), ax) + ) if np.isscalar(varargs[i]): array_locs = None else: if isinstance(varargs[i], da.Array): raise NotImplementedError( - 'dask array coordinated is not supported.') + "dask array coordinated is not supported." + ) # coordinate position for each block taking overlap into # account chunk = np.array(f.chunks[ax]) @@ -149,16 +155,18 @@ def gradient(f, *varargs, axis=None, **kwargs): array_loc_start[0] = 0 array_locs = (array_loc_start, array_loc_stop) - results.append(f.map_overlap( - _gradient_kernel, - dtype=f.dtype, - depth={j: 1 if j == ax else 0 for j in range(f.ndim)}, - boundary="none", - coord=varargs[i], - axis=ax, - array_locs=array_locs, - grad_kwargs=kwargs, - )) + results.append( + f.map_overlap( + _gradient_kernel, + dtype=f.dtype, + depth={j: 1 if j == ax else 0 for j in range(f.ndim)}, + boundary="none", + coord=varargs[i], + axis=ax, + array_locs=array_locs, + grad_kwargs=kwargs, + ) + ) if drop_result_list: results = results[0] diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 7e72c93da27..11fdb86e9b0 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -7,8 +7,9 @@ try: import dask import dask.array as da + # Note: dask has used `ghost` before 0.18.2 - if LooseVersion(dask.__version__) <= LooseVersion('0.18.2'): + if LooseVersion(dask.__version__) <= LooseVersion("0.18.2"): overlap = da.ghost.ghost trim_internal = da.ghost.trim_internal else: @@ -19,7 +20,7 @@ def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1): - '''wrapper to apply bottleneck moving window funcs on dask arrays''' + """wrapper to apply bottleneck moving window funcs on dask arrays""" dtype, fill_value = dtypes.maybe_promote(a.dtype) a = a.astype(dtype) # inputs for overlap @@ -31,8 +32,9 @@ def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1): # Create overlap array. ag = overlap(a, depth=depth, boundary=boundary) # apply rolling func - out = ag.map_blocks(moving_func, window, min_count=min_count, - axis=axis, dtype=a.dtype) + out = ag.map_blocks( + moving_func, window, min_count=min_count, axis=axis, dtype=a.dtype + ) # trim array result = trim_internal(out, depth) return result @@ -53,8 +55,9 @@ def rolling_window(a, axis, window, center, fill_value): "For window size %d, every chunk should be larger than %d, " "but the smallest chunk size is %d. Rechunk your array\n" "with a larger chunk size or a chunk size that\n" - "more evenly divides the shape of your array." % - (window, depth[axis], min(a.chunks[axis]))) + "more evenly divides the shape of your array." 
+ % (window, depth[axis], min(a.chunks[axis])) + ) # Although dask.overlap pads values to boundaries of the array, # the size of the generated array is smaller than what we want @@ -78,7 +81,7 @@ def rolling_window(a, axis, window, center, fill_value): shape = list(a.shape) shape[axis] = pad_size chunks = list(a.chunks) - chunks[axis] = (pad_size, ) + chunks[axis] = (pad_size,) fill_array = da.full(shape, fill_value, dtype=a.dtype, chunks=chunks) a = da.concatenate([fill_array, a], axis=axis) @@ -91,14 +94,14 @@ def rolling_window(a, axis, window, center, fill_value): def func(x, window, axis=-1): x = np.asarray(x) rolling = nputils._rolling_window(x, window, axis) - return rolling[(slice(None), ) * axis + (slice(offset, None), )] + return rolling[(slice(None),) * axis + (slice(offset, None),)] chunks = list(a.chunks) chunks.append(window) - out = ag.map_blocks(func, dtype=a.dtype, new_axis=a.ndim, chunks=chunks, - window=window, axis=axis) + out = ag.map_blocks( + func, dtype=a.dtype, new_axis=a.ndim, chunks=chunks, window=window, axis=axis + ) # crop boundary. - index = (slice(None),) * axis + (slice(drop_size, - drop_size + orig_shape[axis]), ) + index = (slice(None),) * axis + (slice(drop_size, drop_size + orig_shape[axis]),) return out[index] diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 70d11fe18ca..33be8d96e91 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3,34 +3,65 @@ import warnings from collections import OrderedDict from numbers import Number -from typing import (Any, Callable, Dict, Hashable, Iterable, List, Mapping, - Optional, Sequence, Tuple, Union, cast, overload, - TYPE_CHECKING) +from typing import ( + Any, + Callable, + Dict, + Hashable, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, + cast, + overload, + TYPE_CHECKING, +) import numpy as np import pandas as pd from ..plot.plot import _PlotMethods from . 
import ( - computation, dtypes, groupby, indexing, ops, pdcompat, resample, rolling, - utils) + computation, + dtypes, + groupby, + indexing, + ops, + pdcompat, + resample, + rolling, + utils, +) from .accessor_dt import DatetimeAccessor from .accessor_str import StringAccessor -from .alignment import (align, _broadcast_helper, - _get_broadcast_dims_map_common_coords, - reindex_like_indexers) +from .alignment import ( + align, + _broadcast_helper, + _get_broadcast_dims_map_common_coords, + reindex_like_indexers, +) from .common import AbstractArray, DataWithCoords from .coordinates import ( - DataArrayCoordinates, LevelCoordinatesSource, assert_coordinate_consistent, - remap_label_indexers) + DataArrayCoordinates, + LevelCoordinatesSource, + assert_coordinate_consistent, + remap_label_indexers, +) from .dataset import Dataset, merge_indexes, split_indexes from .formatting import format_item from .indexes import Indexes, default_indexes from .options import OPTIONS from .utils import _check_inplace, either_dict_or_kwargs, ReprObject from .variable import ( - IndexVariable, Variable, as_compatible_data, as_variable, - assert_unique_multiindex_level_names) + IndexVariable, + Variable, + as_compatible_data, + as_variable, + assert_unique_multiindex_level_names, +) if TYPE_CHECKING: try: @@ -48,38 +79,44 @@ def _infer_coords_and_dims( - shape, coords, dims -) -> 'Tuple[OrderedDict[Any, Variable], Tuple[Hashable, ...]]': + shape, coords, dims +) -> "Tuple[OrderedDict[Any, Variable], Tuple[Hashable, ...]]": """All the logic for creating a new DataArray""" - if (coords is not None and not utils.is_dict_like(coords) and - len(coords) != len(shape)): - raise ValueError('coords is not dict-like, but it has %s items, ' - 'which does not match the %s dimensions of the ' - 'data' % (len(coords), len(shape))) + if ( + coords is not None + and not utils.is_dict_like(coords) + and len(coords) != len(shape) + ): + raise ValueError( + "coords is not dict-like, but it has %s items, " + "which does not match the %s dimensions of the " + "data" % (len(coords), len(shape)) + ) if isinstance(dims, str): dims = (dims,) if dims is None: - dims = ['dim_%s' % n for n in range(len(shape))] + dims = ["dim_%s" % n for n in range(len(shape))] if coords is not None and len(coords) == len(shape): # try to infer dimensions from coords if utils.is_dict_like(coords): # deprecated in GH993, removed in GH1539 - raise ValueError('inferring DataArray dimensions from ' - 'dictionary like ``coords`` is no longer ' - 'supported. Use an explicit list of ' - '``dims`` instead.') + raise ValueError( + "inferring DataArray dimensions from " + "dictionary like ``coords`` is no longer " + "supported. Use an explicit list of " + "``dims`` instead." 
+ ) for n, (dim, coord) in enumerate(zip(dims, coords)): - coord = as_variable(coord, - name=dims[n]).to_index_variable() + coord = as_variable(coord, name=dims[n]).to_index_variable() dims[n] = coord.name dims = tuple(dims) else: for d in dims: if not isinstance(d, str): - raise TypeError('dimension %s is not a string' % d) + raise TypeError("dimension %s is not a string" % d) new_coords = OrderedDict() # type: OrderedDict[Any, Variable] @@ -95,21 +132,26 @@ def _infer_coords_and_dims( sizes = dict(zip(dims, shape)) for k, v in new_coords.items(): if any(d not in dims for d in v.dims): - raise ValueError('coordinate %s has dimensions %s, but these ' - 'are not a subset of the DataArray ' - 'dimensions %s' % (k, v.dims, dims)) + raise ValueError( + "coordinate %s has dimensions %s, but these " + "are not a subset of the DataArray " + "dimensions %s" % (k, v.dims, dims) + ) for d, s in zip(v.dims, v.shape): if s != sizes[d]: - raise ValueError('conflicting sizes for dimension %r: ' - 'length %s on the data but length %s on ' - 'coordinate %r' % (d, sizes[d], s, k)) + raise ValueError( + "conflicting sizes for dimension %r: " + "length %s on the data but length %s on " + "coordinate %r" % (d, sizes[d], s, k) + ) if k in sizes and v.shape != (sizes[k],): - raise ValueError('coordinate %r is a DataArray dimension, but ' - 'it has shape %r rather than expected shape %r ' - 'matching the dimension size' - % (k, v.shape, (sizes[k],))) + raise ValueError( + "coordinate %r is a DataArray dimension, but " + "it has shape %r rather than expected shape %r " + "matching the dimension size" % (k, v.shape, (sizes[k],)) + ) assert_unique_multiindex_level_names(new_coords) @@ -117,10 +159,10 @@ def _infer_coords_and_dims( class _LocIndexer: - def __init__(self, data_array: 'DataArray'): + def __init__(self, data_array: "DataArray"): self.data_array = data_array - def __getitem__(self, key) -> 'DataArray': + def __getitem__(self, key) -> "DataArray": if not utils.is_dict_like(key): # expand the indexer so we can handle Ellipsis labels = indexing.expanded_indexer(key, self.data_array.ndim) @@ -139,7 +181,7 @@ def __setitem__(self, key, value) -> None: # Used as the key corresponding to a DataArray's variable when converting # arbitrary DataArray objects to datasets -_THIS_ARRAY = ReprObject('') +_THIS_ARRAY = ReprObject("") class DataArray(AbstractArray, DataWithCoords): @@ -180,23 +222,20 @@ class DataArray(AbstractArray, DataWithCoords): attrs : OrderedDict Dictionary for holding arbitrary metadata. """ + _groupby_cls = groupby.DataArrayGroupBy _rolling_cls = rolling.DataArrayRolling _coarsen_cls = rolling.DataArrayCoarsen _resample_cls = resample.DataArrayResample - __default = ReprObject('') + __default = ReprObject("") dt = property(DatetimeAccessor) def __init__( self, data: Any, - coords: Union[ - Sequence[Tuple], - Mapping[Hashable, Any], - None, - ] = None, + coords: Union[Sequence[Tuple], Mapping[Hashable, Any], None] = None, dims: Union[Hashable, Sequence[Hashable], None] = None, name: Hashable = None, attrs: Mapping = None, @@ -248,11 +287,13 @@ def __init__( """ if encoding is not None: warnings.warn( - 'The `encoding` argument to `DataArray` is deprecated, and . ' - 'will be removed in 0.13. ' - 'Instead, specify the encoding when writing to disk or ' - 'set the `encoding` attribute directly.', - FutureWarning, stacklevel=2) + "The `encoding` argument to `DataArray` is deprecated, and . " + "will be removed in 0.13. 
" + "Instead, specify the encoding when writing to disk or " + "set the `encoding` attribute directly.", + FutureWarning, + stacklevel=2, + ) if fastpath: variable = data assert dims is None @@ -274,13 +315,13 @@ def __init__( coords = [data.items, data.major_axis, data.minor_axis] if dims is None: - dims = getattr(data, 'dims', getattr(coords, 'dims', None)) + dims = getattr(data, "dims", getattr(coords, "dims", None)) if name is None: - name = getattr(data, 'name', None) + name = getattr(data, "name", None) if attrs is None: - attrs = getattr(data, 'attrs', None) + attrs = getattr(data, "attrs", None) if encoding is None: - encoding = getattr(data, 'encoding', None) + encoding = getattr(data, "encoding", None) data = as_compatible_data(data) coords, dims = _infer_coords_and_dims(data.shape, coords, dims) @@ -305,7 +346,7 @@ def _replace( variable: Variable = None, coords=None, name: Optional[Hashable] = __default, - ) -> 'DataArray': + ) -> "DataArray": if variable is None: variable = self.variable if coords is None: @@ -315,28 +356,26 @@ def _replace( return type(self)(variable, coords, name=name, fastpath=True) def _replace_maybe_drop_dims( - self, - variable: Variable, - name: Optional[Hashable] = __default - ) -> 'DataArray': + self, variable: Variable, name: Optional[Hashable] = __default + ) -> "DataArray": if variable.dims == self.dims and variable.shape == self.shape: coords = self._coords.copy() elif variable.dims == self.dims: # Shape has changed (e.g. from reduce(..., keepdims=True) new_sizes = dict(zip(self.dims, variable.shape)) - coords = OrderedDict((k, v) for k, v in self._coords.items() - if v.shape == tuple(new_sizes[d] - for d in v.dims)) + coords = OrderedDict( + (k, v) + for k, v in self._coords.items() + if v.shape == tuple(new_sizes[d] for d in v.dims) + ) else: allowed_dims = set(variable.dims) - coords = OrderedDict((k, v) for k, v in self._coords.items() - if set(v.dims) <= allowed_dims) + coords = OrderedDict( + (k, v) for k, v in self._coords.items() if set(v.dims) <= allowed_dims + ) return self._replace(variable, coords, name) - def _replace_indexes( - self, - indexes: Mapping[Hashable, Any] - ) -> 'DataArray': + def _replace_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray": if not len(indexes): return self coords = self._coords.copy() @@ -354,14 +393,11 @@ def _replace_indexes( return obj def _to_temp_dataset(self) -> Dataset: - return self._to_dataset_whole(name=_THIS_ARRAY, - shallow_copy=False) + return self._to_dataset_whole(name=_THIS_ARRAY, shallow_copy=False) def _from_temp_dataset( - self, - dataset: Dataset, - name: Hashable = __default - ) -> 'DataArray': + self, dataset: Dataset, name: Hashable = __default + ) -> "DataArray": variable = dataset._variables.pop(_THIS_ARRAY) coords = dataset._variables return self._replace(variable, coords, name) @@ -374,26 +410,29 @@ def subset(dim, label): array.attrs = {} return array - variables = OrderedDict([(label, subset(dim, label)) - for label in self.get_index(dim)]) + variables = OrderedDict( + [(label, subset(dim, label)) for label in self.get_index(dim)] + ) coords = self.coords.to_dataset() if dim in coords: del coords[dim] return Dataset(variables, coords, self.attrs) def _to_dataset_whole( - self, - name: Hashable = None, - shallow_copy: bool = True + self, name: Hashable = None, shallow_copy: bool = True ) -> Dataset: if name is None: name = self.name if name is None: - raise ValueError('unable to convert unnamed DataArray to a ' - 'Dataset without providing an explicit name') + raise 
ValueError( + "unable to convert unnamed DataArray to a " + "Dataset without providing an explicit name" + ) if name in self.coords: - raise ValueError('cannot create a Dataset from a DataArray with ' - 'the same name as one of its coordinates') + raise ValueError( + "cannot create a Dataset from a DataArray with " + "the same name as one of its coordinates" + ) # use private APIs for speed: this is called by _to_temp_dataset(), # which is used in the guts of a lot of operations (e.g., reindex) variables = self._coords.copy() @@ -405,11 +444,7 @@ def _to_dataset_whole( dataset = Dataset._from_vars_and_coord_names(variables, coord_names) return dataset - def to_dataset( - self, - dim: Hashable = None, - name: Hashable = None, - ) -> Dataset: + def to_dataset(self, dim: Hashable = None, name: Hashable = None) -> Dataset: """Convert a DataArray to a Dataset. Parameters @@ -427,16 +462,19 @@ def to_dataset( dataset : Dataset """ if dim is not None and dim not in self.dims: - warnings.warn('the order of the arguments on DataArray.to_dataset ' - 'has changed; you now need to supply ``name`` as ' - 'a keyword argument', - FutureWarning, stacklevel=2) + warnings.warn( + "the order of the arguments on DataArray.to_dataset " + "has changed; you now need to supply ``name`` as " + "a keyword argument", + FutureWarning, + stacklevel=2, + ) name = dim dim = None if dim is not None: if name is not None: - raise TypeError('cannot supply both dim and name arguments') + raise TypeError("cannot supply both dim and name arguments") return self._to_dataset_split(dim) else: return self._to_dataset_whole(name) @@ -520,8 +558,10 @@ def dims(self) -> Tuple[Hashable, ...]: @dims.setter def dims(self, value): - raise AttributeError('you cannot assign dims on a DataArray. Use ' - '.rename() or .swap_dims() instead.') + raise AttributeError( + "you cannot assign dims on a DataArray. Use " + ".rename() or .swap_dims() instead." + ) def _item_key_to_dict(self, key: Any) -> Mapping[Hashable, Any]: if utils.is_dict_like(key): @@ -531,7 +571,7 @@ def _item_key_to_dict(self, key: Any) -> Mapping[Hashable, Any]: return dict(zip(self.dims, key)) @property - def _level_coords(self) -> 'OrderedDict[Any, Hashable]': + def _level_coords(self) -> "OrderedDict[Any, Hashable]": """Return a mapping of all MultiIndex levels and their corresponding coordinate name. 
""" @@ -553,11 +593,12 @@ def _getitem_coord(self, key): except KeyError: dim_sizes = dict(zip(self.dims, self.shape)) _, key, var = _get_virtual_variable( - self._coords, key, self._level_coords, dim_sizes) + self._coords, key, self._level_coords, dim_sizes + ) return self._replace_maybe_drop_dims(var, name=key) - def __getitem__(self, key: Any) -> 'DataArray': + def __getitem__(self, key: Any) -> "DataArray": if isinstance(key, str): return self._getitem_coord(key) else: @@ -575,8 +616,10 @@ def __setitem__(self, key: Any, value: Any) -> None: if isinstance(value, DataArray): assert_coordinate_consistent(value, obj.coords.variables) # DataArray key -> Variable key - key = {k: v.variable if isinstance(v, DataArray) else v - for k, v in self._item_key_to_dict(key).items()} + key = { + k: v.variable if isinstance(v, DataArray) else v + for k, v in self._item_key_to_dict(key).items() + } self.variable[key] = value def __delitem__(self, key: Any) -> None: @@ -592,8 +635,11 @@ def _attr_sources(self) -> List[Mapping[Hashable, Any]]: def _item_sources(self) -> List[Mapping[Hashable, Any]]: """List of places to look-up items for key-completion """ - return [self.coords, {d: self.coords[d] for d in self.dims}, - LevelCoordinatesSource(self)] + return [ + self.coords, + {d: self.coords[d] for d in self.dims}, + LevelCoordinatesSource(self), + ] def __contains__(self, key: Any) -> bool: return key in self.data @@ -605,7 +651,7 @@ def loc(self) -> _LocIndexer: return _LocIndexer(self) @property - def attrs(self) -> 'OrderedDict[Any, Any]': + def attrs(self) -> "OrderedDict[Any, Any]": """Dictionary storing arbitrary metadata with this array.""" return self.variable.attrs @@ -615,7 +661,7 @@ def attrs(self, value: Mapping[Hashable, Any]) -> None: self.variable.attrs = value # type: ignore @property - def encoding(self) -> 'OrderedDict[Any, Any]': + def encoding(self) -> "OrderedDict[Any, Any]": """Dictionary of format-specific settings for how this array should be serialized.""" return self.variable.encoding @@ -638,10 +684,12 @@ def coords(self) -> DataArrayCoordinates: """ return DataArrayCoordinates(self) - def reset_coords(self, - names: Union[Iterable[Hashable], Hashable, None] = None, - drop: bool = False, inplace: bool = None - ) -> Union[None, 'DataArray', Dataset]: + def reset_coords( + self, + names: Union[Iterable[Hashable], Hashable, None] = None, + drop: bool = False, + inplace: bool = None, + ) -> Union[None, "DataArray", Dataset]: """Given names of coordinates, reset them to become variables. 
Parameters @@ -663,8 +711,10 @@ def reset_coords(self, """ inplace = _check_inplace(inplace) if inplace and not drop: - raise ValueError('cannot reset coordinates in-place on a ' - 'DataArray without ``drop == True``') + raise ValueError( + "cannot reset coordinates in-place on a " + "DataArray without ``drop == True``" + ) if names is None: names = set(self.coords) - set(self.dims) dataset = self.coords.to_dataset().reset_coords(names, drop) @@ -676,8 +726,9 @@ def reset_coords(self, return self._replace(coords=dataset._variables) else: if self.name is None: - raise ValueError('cannot reset_coords with drop=False ' - 'on an unnamed DataArrray') + raise ValueError( + "cannot reset_coords with drop=False " "on an unnamed DataArrray" + ) dataset[self.name] = self.variable return dataset @@ -713,7 +764,7 @@ def _dask_finalize(results, func, args, name): coords = ds._variables return DataArray(variable, coords, name=name, fastpath=True) - def load(self, **kwargs) -> 'DataArray': + def load(self, **kwargs) -> "DataArray": """Manually trigger loading of this array's data from disk or a remote source into memory and return this array. @@ -737,7 +788,7 @@ def load(self, **kwargs) -> 'DataArray': self._coords = new._coords return self - def compute(self, **kwargs) -> 'DataArray': + def compute(self, **kwargs) -> "DataArray": """Manually trigger loading of this array's data from disk or a remote source into memory and return a new array. The original is left unaltered. @@ -759,7 +810,7 @@ def compute(self, **kwargs) -> 'DataArray': new = self.copy(deep=False) return new.load(**kwargs) - def persist(self, **kwargs) -> 'DataArray': + def persist(self, **kwargs) -> "DataArray": """ Trigger computation in constituent dask arrays This keeps them as dask arrays but encourages them to keep data in @@ -778,11 +829,7 @@ def persist(self, **kwargs) -> 'DataArray': ds = self._to_temp_dataset().persist(**kwargs) return self._from_temp_dataset(ds) - def copy( - self, - deep: bool = True, - data: Any = None, - ) -> 'DataArray': + def copy(self, deep: bool = True, data: Any = None) -> "DataArray": """Returns a copy of this array. If `deep=True`, a deep copy is made of the data array. @@ -853,14 +900,13 @@ def copy( pandas.DataFrame.copy """ variable = self.variable.copy(deep=deep, data=data) - coords = OrderedDict((k, v.copy(deep=deep)) - for k, v in self._coords.items()) + coords = OrderedDict((k, v.copy(deep=deep)) for k, v in self._coords.items()) return self._replace(variable, coords) - def __copy__(self) -> 'DataArray': + def __copy__(self) -> "DataArray": return self.copy(deep=False) - def __deepcopy__(self, memo=None) -> 'DataArray': + def __deepcopy__(self, memo=None) -> "DataArray": # memo does nothing but is required for compatibility with # copy.deepcopy return self.copy(deep=True) @@ -885,10 +931,10 @@ def chunk( Tuple[Tuple[Number, ...], ...], Mapping[Hashable, Union[None, Number, Tuple[Number, ...]]], ] = None, - name_prefix: str = 'xarray-', + name_prefix: str = "xarray-", token: str = None, - lock: bool = False - ) -> 'DataArray': + lock: bool = False, + ) -> "DataArray": """Coerce this array's data into a dask arrays with the given chunks. 
If this variable is a non-dask array, it will be converted to dask @@ -919,8 +965,9 @@ def chunk( if isinstance(chunks, (tuple, list)): chunks = dict(zip(self.dims, chunks)) - ds = self._to_temp_dataset().chunk(chunks, name_prefix=name_prefix, - token=token, lock=lock) + ds = self._to_temp_dataset().chunk( + chunks, name_prefix=name_prefix, token=token, lock=lock + ) return self._from_temp_dataset(ds) def isel( @@ -928,7 +975,7 @@ def isel( indexers: Mapping[Hashable, Any] = None, drop: bool = False, **indexers_kwargs: Any - ) -> 'DataArray': + ) -> "DataArray": """Return a new DataArray whose data is given by integer indexing along the specified dimension(s). @@ -937,7 +984,7 @@ def isel( Dataset.isel DataArray.sel """ - indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'isel') + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel") ds = self._to_temp_dataset().isel(drop=drop, indexers=indexers) return self._from_temp_dataset(ds) @@ -948,7 +995,7 @@ def sel( tolerance=None, drop: bool = False, **indexers_kwargs: Any - ) -> 'DataArray': + ) -> "DataArray": """Return a new DataArray whose data is given by selecting index labels along the specified dimension(s). @@ -971,11 +1018,15 @@ def sel( """ ds = self._to_temp_dataset().sel( - indexers=indexers, drop=drop, method=method, tolerance=tolerance, - **indexers_kwargs) + indexers=indexers, + drop=drop, + method=method, + tolerance=tolerance, + **indexers_kwargs + ) return self._from_temp_dataset(ds) - def isel_points(self, dim='points', **indexers) -> 'DataArray': + def isel_points(self, dim="points", **indexers) -> "DataArray": """Return a new DataArray whose data is given by pointwise integer indexing along the specified dimension(s). @@ -986,8 +1037,9 @@ def isel_points(self, dim='points', **indexers) -> 'DataArray': ds = self._to_temp_dataset().isel_points(dim=dim, **indexers) return self._from_temp_dataset(ds) - def sel_points(self, dim='points', method=None, tolerance=None, - **indexers) -> 'DataArray': + def sel_points( + self, dim="points", method=None, tolerance=None, **indexers + ) -> "DataArray": """Return a new DataArray whose dataset is given by pointwise selection of index labels along the specified dimension(s). @@ -996,12 +1048,13 @@ def sel_points(self, dim='points', method=None, tolerance=None, Dataset.sel_points """ ds = self._to_temp_dataset().sel_points( - dim=dim, method=method, tolerance=tolerance, **indexers) + dim=dim, method=method, tolerance=tolerance, **indexers + ) return self._from_temp_dataset(ds) - def broadcast_like(self, - other: Union['DataArray', Dataset], - exclude: Iterable[Hashable] = None) -> 'DataArray': + def broadcast_like( + self, other: Union["DataArray", Dataset], exclude: Iterable[Hashable] = None + ) -> "DataArray": """Broadcast this DataArray against another Dataset or DataArray. 
This is equivalent to xr.broadcast(other, self)[1] @@ -1057,16 +1110,20 @@ def broadcast_like(self, exclude = set() else: exclude = set(exclude) - args = align(other, self, join='outer', copy=False, exclude=exclude) + args = align(other, self, join="outer", copy=False, exclude=exclude) - dims_map, common_coords = _get_broadcast_dims_map_common_coords( - args, exclude) + dims_map, common_coords = _get_broadcast_dims_map_common_coords(args, exclude) return _broadcast_helper(args[1], exclude, dims_map, common_coords) - def reindex_like(self, other: Union['DataArray', Dataset], - method: str = None, tolerance=None, - copy: bool = True, fill_value=dtypes.NA) -> 'DataArray': + def reindex_like( + self, + other: Union["DataArray", Dataset], + method: str = None, + tolerance=None, + copy: bool = True, + fill_value=dtypes.NA, + ) -> "DataArray": """Conform this object onto the indexes of another object, filling in missing values with ``fill_value``. The default fill value is NaN. @@ -1120,10 +1177,15 @@ def reindex_like(self, other: Union['DataArray', Dataset], fill_value=fill_value, ) - def reindex(self, indexers: Mapping[Hashable, Any] = None, - method: str = None, tolerance=None, - copy: bool = True, fill_value=dtypes.NA, **indexers_kwargs: Any - ) -> 'DataArray': + def reindex( + self, + indexers: Mapping[Hashable, Any] = None, + method: str = None, + tolerance=None, + copy: bool = True, + fill_value=dtypes.NA, + **indexers_kwargs: Any + ) -> "DataArray": """Conform this object onto the indexes of another object, filling in missing values with ``fill_value``. The default fill value is NaN. @@ -1169,17 +1231,24 @@ def reindex(self, indexers: Mapping[Hashable, Any] = None, DataArray.reindex_like align """ - indexers = either_dict_or_kwargs( - indexers, indexers_kwargs, 'reindex') + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "reindex") ds = self._to_temp_dataset().reindex( - indexers=indexers, method=method, tolerance=tolerance, copy=copy, - fill_value=fill_value) + indexers=indexers, + method=method, + tolerance=tolerance, + copy=copy, + fill_value=fill_value, + ) return self._from_temp_dataset(ds) - def interp(self, coords: Mapping[Hashable, Any] = None, - method: str = 'linear', assume_sorted: bool = False, - kwargs: Mapping[str, Any] = None, - **coords_kwargs: Any) -> 'DataArray': + def interp( + self, + coords: Mapping[Hashable, Any] = None, + method: str = "linear", + assume_sorted: bool = False, + kwargs: Mapping[str, Any] = None, + **coords_kwargs: Any + ) -> "DataArray": """ Multidimensional interpolation of variables. coords : dict, optional @@ -1223,17 +1292,27 @@ def interp(self, coords: Mapping[Hashable, Any] = None, Coordinates: x float64 0.5 """ - if self.dtype.kind not in 'uifc': - raise TypeError('interp only works for a numeric type array. ' - 'Given {}.'.format(self.dtype)) + if self.dtype.kind not in "uifc": + raise TypeError( + "interp only works for a numeric type array. 
" + "Given {}.".format(self.dtype) + ) ds = self._to_temp_dataset().interp( - coords, method=method, kwargs=kwargs, assume_sorted=assume_sorted, - **coords_kwargs) + coords, + method=method, + kwargs=kwargs, + assume_sorted=assume_sorted, + **coords_kwargs + ) return self._from_temp_dataset(ds) - def interp_like(self, other: Union['DataArray', Dataset], - method: str = 'linear', assume_sorted: bool = False, - kwargs: Mapping[str, Any] = None) -> 'DataArray': + def interp_like( + self, + other: Union["DataArray", Dataset], + method: str = "linear", + assume_sorted: bool = False, + kwargs: Mapping[str, Any] = None, + ) -> "DataArray": """Interpolate this object onto the coordinates of another object, filling out of range values with NaN. @@ -1272,19 +1351,21 @@ def interp_like(self, other: Union['DataArray', Dataset], DataArray.interp DataArray.reindex_like """ - if self.dtype.kind not in 'uifc': - raise TypeError('interp only works for a numeric type array. ' - 'Given {}.'.format(self.dtype)) + if self.dtype.kind not in "uifc": + raise TypeError( + "interp only works for a numeric type array. " + "Given {}.".format(self.dtype) + ) ds = self._to_temp_dataset().interp_like( - other, method=method, kwargs=kwargs, assume_sorted=assume_sorted) + other, method=method, kwargs=kwargs, assume_sorted=assume_sorted + ) return self._from_temp_dataset(ds) def rename( self, - new_name_or_name_dict: - Union[Hashable, Mapping[Hashable, Hashable]] = None, + new_name_or_name_dict: Union[Hashable, Mapping[Hashable, Hashable]] = None, **names: Hashable - ) -> 'DataArray': + ) -> "DataArray": """Returns a new DataArray with renamed coordinates or a new name. Parameters @@ -1309,17 +1390,17 @@ def rename( DataArray.swap_dims """ if names or utils.is_dict_like(new_name_or_name_dict): - new_name_or_name_dict = cast(Mapping[Hashable, Hashable], - new_name_or_name_dict) - name_dict = either_dict_or_kwargs( - new_name_or_name_dict, names, 'rename') + new_name_or_name_dict = cast( + Mapping[Hashable, Hashable], new_name_or_name_dict + ) + name_dict = either_dict_or_kwargs(new_name_or_name_dict, names, "rename") dataset = self._to_temp_dataset().rename(name_dict) return self._from_temp_dataset(dataset) else: new_name_or_name_dict = cast(Hashable, new_name_or_name_dict) return self._replace(name=new_name_or_name_dict) - def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> 'DataArray': + def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray": """Returns a new DataArray with swapped dimensions. Parameters @@ -1343,9 +1424,12 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> 'DataArray': ds = self._to_temp_dataset().swap_dims(dims_dict) return self._from_temp_dataset(ds) - def expand_dims(self, dim: Union[None, Hashable, Sequence[Hashable], - Mapping[Hashable, Any]] = None, - axis=None, **dim_kwargs: Any) -> 'DataArray': + def expand_dims( + self, + dim: Union[None, Hashable, Sequence[Hashable], Mapping[Hashable, Any]] = None, + axis=None, + **dim_kwargs: Any + ) -> "DataArray": """Return a new object with an additional axis (or axes) inserted at the corresponding position in the array shape. The new object is a view into the underlying array, not a copy. @@ -1384,11 +1468,12 @@ def expand_dims(self, dim: Union[None, Hashable, Sequence[Hashable], This object, but with an additional dimension(s). 
""" if isinstance(dim, int): - raise TypeError('dim should be hashable or sequence/mapping of ' - 'hashables') + raise TypeError( + "dim should be hashable or sequence/mapping of " "hashables" + ) elif isinstance(dim, Sequence) and not isinstance(dim, str): if len(dim) != len(set(dim)): - raise ValueError('dims should not contain duplicate values.') + raise ValueError("dims should not contain duplicate values.") dim = OrderedDict((d, 1) for d in dim) elif dim is not None and not isinstance(dim, Mapping): dim = OrderedDict(((cast(Hashable, dim), 1),)) @@ -1403,7 +1488,7 @@ def expand_dims(self, dim: Union[None, Hashable, Sequence[Hashable], raise ValueError("dim_kwargs isn't available for python <3.6") dim_kwargs = OrderedDict(dim_kwargs) - dim = either_dict_or_kwargs(dim, dim_kwargs, 'expand_dims') + dim = either_dict_or_kwargs(dim, dim_kwargs, "expand_dims") ds = self._to_temp_dataset().expand_dims(dim, axis) return self._from_temp_dataset(ds) @@ -1413,7 +1498,7 @@ def set_index( append: bool = False, inplace: bool = None, **indexes_kwargs: Union[Hashable, Sequence[Hashable]] - ) -> Optional['DataArray']: + ) -> Optional["DataArray"]: """Set DataArray (multi-)indexes using one or more existing coordinates. @@ -1444,7 +1529,7 @@ def set_index( DataArray.reset_index """ inplace = _check_inplace(inplace) - indexes = either_dict_or_kwargs(indexes, indexes_kwargs, 'set_index') + indexes = either_dict_or_kwargs(indexes, indexes_kwargs, "set_index") coords, _ = merge_indexes(indexes, self._coords, set(), append=append) if inplace: self._coords = coords @@ -1456,8 +1541,8 @@ def reset_index( self, dims_or_levels: Union[Hashable, Sequence[Hashable]], drop: bool = False, - inplace: bool = None - ) -> Optional['DataArray']: + inplace: bool = None, + ) -> Optional["DataArray"]: """Reset the specified index(es) or multi-index level(s). Parameters @@ -1483,8 +1568,9 @@ def reset_index( DataArray.set_index """ inplace = _check_inplace(inplace) - coords, _ = split_indexes(dims_or_levels, self._coords, set(), - self._level_coords, drop=drop) + coords, _ = split_indexes( + dims_or_levels, self._coords, set(), self._level_coords, drop=drop + ) if inplace: self._coords = coords return None @@ -1496,7 +1582,7 @@ def reorder_levels( dim_order: Mapping[Hashable, Sequence[int]] = None, inplace: bool = None, **dim_order_kwargs: Sequence[int] - ) -> Optional['DataArray']: + ) -> Optional["DataArray"]: """Rearrange index levels using input order. Parameters @@ -1519,16 +1605,14 @@ def reorder_levels( coordinates. If ``inplace == True``, return None. """ inplace = _check_inplace(inplace) - dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, - 'reorder_levels') + dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, "reorder_levels") replace_coords = {} for dim, order in dim_order.items(): coord = self._coords[dim] index = coord.to_index() if not isinstance(index, pd.MultiIndex): raise ValueError("coordinate %r has no MultiIndex" % dim) - replace_coords[dim] = IndexVariable(coord.dims, - index.reorder_levels(order)) + replace_coords[dim] = IndexVariable(coord.dims, index.reorder_levels(order)) coords = self._coords.copy() coords.update(replace_coords) if inplace: @@ -1541,7 +1625,7 @@ def stack( self, dimensions: Mapping[Hashable, Sequence[Hashable]] = None, **dimensions_kwargs: Sequence[Hashable] - ) -> 'DataArray': + ) -> "DataArray": """ Stack any number of existing dimensions into a single new dimension. 
@@ -1587,8 +1671,9 @@ def stack( ds = self._to_temp_dataset().stack(dimensions, **dimensions_kwargs) return self._from_temp_dataset(ds) - def unstack(self, dim: Union[Hashable, Sequence[Hashable], None] = None - ) -> 'DataArray': + def unstack( + self, dim: Union[Hashable, Sequence[Hashable], None] = None + ) -> "DataArray": """ Unstack existing dimensions corresponding to MultiIndexes into multiple new dimensions. @@ -1700,9 +1785,7 @@ def to_unstacked_dataset(self, dim, level=0): # unstacked dataset return Dataset(data_dict) - def transpose(self, - *dims: Hashable, - transpose_coords: bool = None) -> 'DataArray': + def transpose(self, *dims: Hashable, transpose_coords: bool = None) -> "DataArray": """Return a new DataArray object with transposed dimensions. Parameters @@ -1731,9 +1814,10 @@ def transpose(self, """ if dims: if set(dims) ^ set(self.dims): - raise ValueError('arguments to transpose (%s) must be ' - 'permuted array dimensions (%s)' - % (dims, tuple(self.dims))) + raise ValueError( + "arguments to transpose (%s) must be " + "permuted array dimensions (%s)" % (dims, tuple(self.dims)) + ) variable = self.variable.transpose(*dims) if transpose_coords: @@ -1743,41 +1827,36 @@ def transpose(self, coords[name] = coord.variable.transpose(*coord_dims) return self._replace(variable, coords) else: - if transpose_coords is None \ - and any(self[c].ndim > 1 for c in self.coords): - warnings.warn('This DataArray contains multi-dimensional ' - 'coordinates. In the future, these coordinates ' - 'will be transposed as well unless you specify ' - 'transpose_coords=False.', - FutureWarning, stacklevel=2) + if transpose_coords is None and any(self[c].ndim > 1 for c in self.coords): + warnings.warn( + "This DataArray contains multi-dimensional " + "coordinates. In the future, these coordinates " + "will be transposed as well unless you specify " + "transpose_coords=False.", + FutureWarning, + stacklevel=2, + ) return self._replace(variable) @property - def T(self) -> 'DataArray': + def T(self) -> "DataArray": return self.transpose() # Drop coords @overload def drop( - self, - labels: Union[Hashable, Iterable[Hashable]], - *, - errors: str = 'raise' - ) -> 'DataArray': + self, labels: Union[Hashable, Iterable[Hashable]], *, errors: str = "raise" + ) -> "DataArray": ... # Drop index labels along dimension @overload # noqa: F811 def drop( - self, - labels: Any, # array-like - dim: Hashable, - *, - errors: str = 'raise' - ) -> 'DataArray': + self, labels: Any, dim: Hashable, *, errors: str = "raise" # array-like + ) -> "DataArray": ... - def drop(self, labels, dim=None, *, errors='raise'): # noqa: F811 + def drop(self, labels, dim=None, *, errors="raise"): # noqa: F811 """Drop coordinates or index labels from this DataArray. Parameters @@ -1800,8 +1879,9 @@ def drop(self, labels, dim=None, *, errors='raise'): # noqa: F811 ds = self._to_temp_dataset().drop(labels, dim, errors=errors) return self._from_temp_dataset(ds) - def dropna(self, dim: Hashable, how: str = 'any', - thresh: int = None) -> 'DataArray': + def dropna( + self, dim: Hashable, how: str = "any", thresh: int = None + ) -> "DataArray": """Returns a new array with dropped labels for missing values along the provided dimension. @@ -1823,7 +1903,7 @@ def dropna(self, dim: Hashable, how: str = 'any', ds = self._to_temp_dataset().dropna(dim, how=how, thresh=thresh) return self._from_temp_dataset(ds) - def fillna(self, value: Any) -> 'DataArray': + def fillna(self, value: Any) -> "DataArray": """Fill missing values in this object. 
This operation follows the normal broadcasting and alignment rules that @@ -1843,15 +1923,21 @@ def fillna(self, value: Any) -> 'DataArray': DataArray """ if utils.is_dict_like(value): - raise TypeError('cannot provide fill value as a dictionary with ' - 'fillna on a DataArray') + raise TypeError( + "cannot provide fill value as a dictionary with " + "fillna on a DataArray" + ) out = ops.fillna(self, value) return out - def interpolate_na(self, dim=None, method: str = 'linear', - limit: int = None, - use_coordinate: Union[bool, str] = True, - **kwargs: Any) -> 'DataArray': + def interpolate_na( + self, + dim=None, + method: str = "linear", + limit: int = None, + use_coordinate: Union[bool, str] = True, + **kwargs: Any + ) -> "DataArray": """Interpolate values according to different methods. Parameters @@ -1891,10 +1977,17 @@ def interpolate_na(self, dim=None, method: str = 'linear', scipy.interpolate """ from .missing import interp_na - return interp_na(self, dim=dim, method=method, limit=limit, - use_coordinate=use_coordinate, **kwargs) - def ffill(self, dim: Hashable, limit: int = None) -> 'DataArray': + return interp_na( + self, + dim=dim, + method=method, + limit=limit, + use_coordinate=use_coordinate, + **kwargs + ) + + def ffill(self, dim: Hashable, limit: int = None) -> "DataArray": """Fill NaN values by propogating values forward *Requires bottleneck.* @@ -1915,9 +2008,10 @@ def ffill(self, dim: Hashable, limit: int = None) -> 'DataArray': DataArray """ from .missing import ffill + return ffill(self, dim, limit=limit) - def bfill(self, dim: Hashable, limit: int = None) -> 'DataArray': + def bfill(self, dim: Hashable, limit: int = None) -> "DataArray": """Fill NaN values by propogating values backward *Requires bottleneck.* @@ -1938,9 +2032,10 @@ def bfill(self, dim: Hashable, limit: int = None) -> 'DataArray': DataArray """ from .missing import bfill + return bfill(self, dim, limit=limit) - def combine_first(self, other: 'DataArray') -> 'DataArray': + def combine_first(self, other: "DataArray") -> "DataArray": """Combine two DataArray objects, with union of coordinates. This operation follows the normal broadcasting and alignment rules of @@ -1958,12 +2053,15 @@ def combine_first(self, other: 'DataArray') -> 'DataArray': """ return ops.fillna(self, other, join="outer") - def reduce(self, func: Callable[..., Any], - dim: Union[None, Hashable, Sequence[Hashable]] = None, - axis: Union[None, int, Sequence[int]] = None, - keep_attrs: bool = None, - keepdims: bool = False, - **kwargs: Any) -> 'DataArray': + def reduce( + self, + func: Callable[..., Any], + dim: Union[None, Hashable, Sequence[Hashable]] = None, + axis: Union[None, int, Sequence[int]] = None, + keep_attrs: bool = None, + keepdims: bool = False, + **kwargs: Any + ) -> "DataArray": """Reduce this array by applying `func` along some dimension(s). Parameters @@ -1997,11 +2095,10 @@ def reduce(self, func: Callable[..., Any], summarized data and the indicated dimension(s) removed. """ - var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, - **kwargs) + var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs) return self._replace_maybe_drop_dims(var) - def to_pandas(self) -> Union['DataArray', pd.Series, pd.DataFrame]: + def to_pandas(self) -> Union["DataArray", pd.Series, pd.DataFrame]: """Convert this array into a pandas object with the same shape. 
The type of the returned object depends on the number of DataArray @@ -2018,22 +2115,23 @@ def to_pandas(self) -> Union['DataArray', pd.Series, pd.DataFrame]: """ # TODO: consolidate the info about pandas constructors and the # attributes that correspond to their indexes into a separate module? - constructors = {0: lambda x: x, - 1: pd.Series, - 2: pd.DataFrame, - 3: pdcompat.Panel} + constructors = { + 0: lambda x: x, + 1: pd.Series, + 2: pd.DataFrame, + 3: pdcompat.Panel, + } try: constructor = constructors[self.ndim] except KeyError: - raise ValueError('cannot convert arrays with %s dimensions into ' - 'pandas objects' % self.ndim) + raise ValueError( + "cannot convert arrays with %s dimensions into " + "pandas objects" % self.ndim + ) indexes = [self.get_index(dim) for dim in self.dims] return constructor(self.values, *indexes) - def to_dataframe( - self, - name: Hashable = None, - ) -> pd.DataFrame: + def to_dataframe(self, name: Hashable = None) -> pd.DataFrame: """Convert this array and its coordinates into a tidy pandas.DataFrame. The DataFrame is indexed by the Cartesian product of index coordinates @@ -2044,8 +2142,10 @@ def to_dataframe( if name is None: name = self.name if name is None: - raise ValueError('cannot convert an unnamed DataArray to a ' - 'DataFrame: use the ``name`` parameter') + raise ValueError( + "cannot convert an unnamed DataArray to a " + "DataFrame: use the ``name`` parameter" + ) dims = OrderedDict(zip(self.dims, self.shape)) # By using a unique name, we can convert a DataArray into a DataFrame @@ -2053,11 +2153,10 @@ def to_dataframe( # I would normally use unique_name = object() but that results in a # dataframe with columns in the wrong order, for reasons I have not # been able to debug (possibly a pandas bug?). - unique_name = '__unique_name_identifier_z98xfz98xugfg73ho__' + unique_name = "__unique_name_identifier_z98xfz98xugfg73ho__" ds = self._to_dataset_whole(name=unique_name) df = ds._to_dataframe(dims) - df.columns = [name if c == unique_name else c - for c in df.columns] + df.columns = [name if c == unique_name else c for c in df.columns] return df def to_series(self) -> pd.Series: @@ -2087,7 +2186,7 @@ def to_masked_array(self, copy: bool = True) -> np.ma.MaskedArray: isnull = pd.isnull(values) return np.ma.MaskedArray(data=values, mask=isnull, copy=copy) - def to_netcdf(self, *args, **kwargs) -> Optional['Delayed']: + def to_netcdf(self, *args, **kwargs) -> Optional["Delayed"]: """Write DataArray contents to a netCDF file. All parameters are passed directly to `xarray.Dataset.to_netcdf`. 
@@ -2136,13 +2235,13 @@ def to_dict(self, data: bool = True) -> dict: DataArray.from_dict """ d = self.variable.to_dict(data=data) - d.update({'coords': {}, 'name': self.name}) + d.update({"coords": {}, "name": self.name}) for k in self.coords: - d['coords'][k] = self.coords[k].variable.to_dict(data=data) + d["coords"][k] = self.coords[k].variable.to_dict(data=data) return d @classmethod - def from_dict(cls, d: dict) -> 'DataArray': + def from_dict(cls, d: dict) -> "DataArray": """ Convert a dictionary into an xarray.DataArray @@ -2174,27 +2273,29 @@ def from_dict(cls, d: dict) -> 'DataArray': Dataset.from_dict """ coords = None - if 'coords' in d: + if "coords" in d: try: - coords = OrderedDict([(k, (v['dims'], - v['data'], - v.get('attrs'))) - for k, v in d['coords'].items()]) + coords = OrderedDict( + [ + (k, (v["dims"], v["data"], v.get("attrs"))) + for k, v in d["coords"].items() + ] + ) except KeyError as e: raise ValueError( "cannot convert dict when coords are missing the key " - "'{dims_data}'".format(dims_data=str(e.args[0]))) + "'{dims_data}'".format(dims_data=str(e.args[0])) + ) try: - data = d['data'] + data = d["data"] except KeyError: raise ValueError("cannot convert dict without the key 'data''") else: - obj = cls(data, coords, d.get('dims'), d.get('name'), - d.get('attrs')) + obj = cls(data, coords, d.get("dims"), d.get("name"), d.get("attrs")) return obj @classmethod - def from_series(cls, series: pd.Series) -> 'DataArray': + def from_series(cls, series: pd.Series) -> "DataArray": """Convert a pandas.Series into an xarray.DataArray. If the series's index is a MultiIndex, it will be expanded into a @@ -2208,42 +2309,48 @@ def from_series(cls, series: pd.Series) -> 'DataArray': ds = Dataset.from_dataframe(df) return ds[name] - def to_cdms2(self) -> 'cdms2_Variable': + def to_cdms2(self) -> "cdms2_Variable": """Convert this array into a cdms2.Variable """ from ..convert import to_cdms2 + return to_cdms2(self) @classmethod - def from_cdms2(cls, variable: 'cdms2_Variable') -> 'DataArray': + def from_cdms2(cls, variable: "cdms2_Variable") -> "DataArray": """Convert a cdms2.Variable into an xarray.DataArray """ from ..convert import from_cdms2 + return from_cdms2(variable) - def to_iris(self) -> 'iris_Cube': + def to_iris(self) -> "iris_Cube": """Convert this array into a iris.cube.Cube """ from ..convert import to_iris + return to_iris(self) @classmethod - def from_iris(cls, cube: 'iris_Cube') -> 'DataArray': + def from_iris(cls, cube: "iris_Cube") -> "DataArray": """Convert a iris.cube.Cube into an xarray.DataArray """ from ..convert import from_iris + return from_iris(cube) - def _all_compat(self, other: 'DataArray', compat_str: str) -> bool: + def _all_compat(self, other: "DataArray", compat_str: str) -> bool: """Helper function for equals, broadcast_equals, and identical """ + def compat(x, y): return getattr(x.variable, compat_str)(y.variable) - return (utils.dict_equiv(self.coords, other.coords, compat=compat) and - compat(self, other)) + return utils.dict_equiv(self.coords, other.coords, compat=compat) and compat( + self, other + ) - def broadcast_equals(self, other: 'DataArray') -> bool: + def broadcast_equals(self, other: "DataArray") -> bool: """Two DataArrays are broadcast equal if they are equal after broadcasting them against each other such that they have the same dimensions. 
@@ -2254,11 +2361,11 @@ def broadcast_equals(self, other: 'DataArray') -> bool: DataArray.identical """ try: - return self._all_compat(other, 'broadcast_equals') + return self._all_compat(other, "broadcast_equals") except (TypeError, AttributeError): return False - def equals(self, other: 'DataArray') -> bool: + def equals(self, other: "DataArray") -> bool: """True if two DataArrays have the same dimensions, coordinates and values; otherwise False. @@ -2274,11 +2381,11 @@ def equals(self, other: 'DataArray') -> bool: DataArray.identical """ try: - return self._all_compat(other, 'equals') + return self._all_compat(other, "equals") except (TypeError, AttributeError): return False - def identical(self, other: 'DataArray') -> bool: + def identical(self, other: "DataArray") -> bool: """Like equals, but also checks the array name and attributes, and attributes on all coordinates. @@ -2288,8 +2395,7 @@ def identical(self, other: 'DataArray') -> bool: DataArray.equal """ try: - return (self.name == other.name and - self._all_compat(other, 'identical')) + return self.name == other.name and self._all_compat(other, "identical") except (TypeError, AttributeError): return False @@ -2298,13 +2404,13 @@ def identical(self, other: 'DataArray') -> bool: def _result_name(self, other: Any = None) -> Optional[Hashable]: # use the same naming heuristics as pandas: # https://github.com/ContinuumIO/blaze/issues/458#issuecomment-51936356 - other_name = getattr(other, 'name', self.__default_name) + other_name = getattr(other, "name", self.__default_name) if other_name is self.__default_name or other_name == self.name: return self.name else: return None - def __array_wrap__(self, obj, context=None) -> 'DataArray': + def __array_wrap__(self, obj, context=None) -> "DataArray": new_var = self.variable.__array_wrap__(obj, context) return self._replace(new_var) @@ -2317,36 +2423,36 @@ def __rmatmul__(self, other): return computation.dot(other, self) @staticmethod - def _unary_op(f: Callable[..., Any] - ) -> Callable[..., 'DataArray']: + def _unary_op(f: Callable[..., Any]) -> Callable[..., "DataArray"]: @functools.wraps(f) def func(self, *args, **kwargs): - with np.errstate(all='ignore'): - return self.__array_wrap__(f(self.variable.data, *args, - **kwargs)) + with np.errstate(all="ignore"): + return self.__array_wrap__(f(self.variable.data, *args, **kwargs)) return func @staticmethod - def _binary_op(f: Callable[..., Any], - reflexive: bool = False, - join: str = None, # see xarray.align - **ignored_kwargs - ) -> Callable[..., 'DataArray']: + def _binary_op( + f: Callable[..., Any], + reflexive: bool = False, + join: str = None, # see xarray.align + **ignored_kwargs + ) -> Callable[..., "DataArray"]: @functools.wraps(f) def func(self, other): if isinstance(other, (Dataset, groupby.GroupBy)): return NotImplemented if isinstance(other, DataArray): - align_type = (OPTIONS['arithmetic_join'] - if join is None else join) + align_type = OPTIONS["arithmetic_join"] if join is None else join self, other = align(self, other, join=align_type, copy=False) - other_variable = getattr(other, 'variable', other) - other_coords = getattr(other, 'coords', None) + other_variable = getattr(other, "variable", other) + other_coords = getattr(other, "coords", None) - variable = (f(self.variable, other_variable) - if not reflexive - else f(other_variable, self.variable)) + variable = ( + f(self.variable, other_variable) + if not reflexive + else f(other_variable, self.variable) + ) coords = self.coords._merge_raw(other_coords) name = 
self._result_name(other) @@ -2355,26 +2461,27 @@ def func(self, other): return func @staticmethod - def _inplace_binary_op(f: Callable) -> Callable[..., 'DataArray']: + def _inplace_binary_op(f: Callable) -> Callable[..., "DataArray"]: @functools.wraps(f) def func(self, other): if isinstance(other, groupby.GroupBy): - raise TypeError('in-place operations between a DataArray and ' - 'a grouped object are not permitted') + raise TypeError( + "in-place operations between a DataArray and " + "a grouped object are not permitted" + ) # n.b. we can't align other to self (with other.reindex_like(self)) # because `other` may be converted into floats, which would cause # in-place arithmetic to fail unpredictably. Instead, we simply # don't support automatic alignment with in-place arithmetic. - other_coords = getattr(other, 'coords', None) - other_variable = getattr(other, 'variable', other) + other_coords = getattr(other, "coords", None) + other_variable = getattr(other, "variable", other) with self.coords._merge_inplace(other_coords): f(self.variable, other_variable) return self return func - def _copy_attrs_from(self, other: Union['DataArray', Dataset, Variable] - ) -> None: + def _copy_attrs_from(self, other: Union["DataArray", Dataset, Variable]) -> None: self.attrs = other.attrs @property @@ -2413,17 +2520,17 @@ def _title_for_slice(self, truncate: int = 50) -> str: one_dims = [] for dim, coord in self.coords.items(): if coord.size == 1: - one_dims.append('{dim} = {v}'.format( - dim=dim, v=format_item(coord.values))) + one_dims.append( + "{dim} = {v}".format(dim=dim, v=format_item(coord.values)) + ) - title = ', '.join(one_dims) + title = ", ".join(one_dims) if len(title) > truncate: - title = title[:(truncate - 3)] + '...' + title = title[: (truncate - 3)] + "..." return title - def diff(self, dim: Hashable, n: int = 1, label: Hashable = 'upper' - ) -> 'DataArray': + def diff(self, dim: Hashable, n: int = 1, label: Hashable = "upper") -> "DataArray": """Calculate the n-th order discrete difference along given axis. Parameters @@ -2464,9 +2571,12 @@ def diff(self, dim: Hashable, n: int = 1, label: Hashable = 'upper' ds = self._to_temp_dataset().diff(n=n, dim=dim, label=label) return self._from_temp_dataset(ds) - def shift(self, shifts: Mapping[Hashable, int] = None, - fill_value: Any = dtypes.NA, **shifts_kwargs: int - ) -> 'DataArray': + def shift( + self, + shifts: Mapping[Hashable, int] = None, + fill_value: Any = dtypes.NA, + **shifts_kwargs: int + ) -> "DataArray": """Shift this array by an offset along one or more dimensions. Only the data is moved; coordinates stay in place. Values shifted from @@ -2506,12 +2616,16 @@ def shift(self, shifts: Mapping[Hashable, int] = None, * x (x) int64 0 1 2 """ variable = self.variable.shift( - shifts=shifts, fill_value=fill_value, **shifts_kwargs) + shifts=shifts, fill_value=fill_value, **shifts_kwargs + ) return self._replace(variable=variable) - def roll(self, shifts: Mapping[Hashable, int] = None, - roll_coords: bool = None, - **shifts_kwargs: int) -> 'DataArray': + def roll( + self, + shifts: Mapping[Hashable, int] = None, + roll_coords: bool = None, + **shifts_kwargs: int + ) -> "DataArray": """Roll this array by an offset along one or more dimensions. 
Unlike shift, roll may rotate all variables, including coordinates @@ -2552,20 +2666,21 @@ def roll(self, shifts: Mapping[Hashable, int] = None, * x (x) int64 2 0 1 """ ds = self._to_temp_dataset().roll( - shifts=shifts, roll_coords=roll_coords, **shifts_kwargs) + shifts=shifts, roll_coords=roll_coords, **shifts_kwargs + ) return self._from_temp_dataset(ds) @property - def real(self) -> 'DataArray': + def real(self) -> "DataArray": return self._replace(self.variable.real) @property - def imag(self) -> 'DataArray': + def imag(self) -> "DataArray": return self._replace(self.variable.imag) - def dot(self, other: 'DataArray', - dims: Union[Hashable, Sequence[Hashable], None] = None - ) -> 'DataArray': + def dot( + self, other: "DataArray", dims: Union[Hashable, Sequence[Hashable], None] = None + ) -> "DataArray": """Perform dot product of two DataArrays along their shared dims. Equivalent to taking taking tensordot over all shared dims. @@ -2606,16 +2721,19 @@ def dot(self, other: 'DataArray', ('x', 'y') """ if isinstance(other, Dataset): - raise NotImplementedError('dot products are not yet supported ' - 'with Dataset objects.') + raise NotImplementedError( + "dot products are not yet supported " "with Dataset objects." + ) if not isinstance(other, DataArray): - raise TypeError('dot only operates on DataArrays.') + raise TypeError("dot only operates on DataArrays.") return computation.dot(self, other, dims=dims) - def sortby(self, variables: Union[Hashable, 'DataArray', - Sequence[Union[Hashable, 'DataArray']]], - ascending: bool = True) -> 'DataArray': + def sortby( + self, + variables: Union[Hashable, "DataArray", Sequence[Union[Hashable, "DataArray"]]], + ascending: bool = True, + ) -> "DataArray": """Sort object by labels or values (along an axis). Sorts the dataarray, either along specified dimensions, @@ -2667,10 +2785,13 @@ def sortby(self, variables: Union[Hashable, 'DataArray', ds = self._to_temp_dataset().sortby(variables, ascending=ascending) return self._from_temp_dataset(ds) - def quantile(self, q: Any, - dim: Union[Hashable, Sequence[Hashable], None] = None, - interpolation: str = 'linear', - keep_attrs: bool = None) -> 'DataArray': + def quantile( + self, + q: Any, + dim: Union[Hashable, Sequence[Hashable], None] = None, + interpolation: str = "linear", + keep_attrs: bool = None, + ) -> "DataArray": """Compute the qth quantile of the data along the specified dimension. Returns the qth quantiles(s) of the array elements. @@ -2713,11 +2834,13 @@ def quantile(self, q: Any, """ ds = self._to_temp_dataset().quantile( - q, dim=dim, keep_attrs=keep_attrs, interpolation=interpolation) + q, dim=dim, keep_attrs=keep_attrs, interpolation=interpolation + ) return self._from_temp_dataset(ds) - def rank(self, dim: Hashable, pct: bool = False, keep_attrs: bool = None - ) -> 'DataArray': + def rank( + self, dim: Hashable, pct: bool = False, keep_attrs: bool = None + ) -> "DataArray": """Ranks the data. Equal values are assigned a rank that is the average of the ranks that @@ -2757,8 +2880,9 @@ def rank(self, dim: Hashable, pct: bool = False, keep_attrs: bool = None ds = self._to_temp_dataset().rank(dim, pct=pct, keep_attrs=keep_attrs) return self._from_temp_dataset(ds) - def differentiate(self, coord: Hashable, edge_order: int = 1, - datetime_unit: str = None) -> 'DataArray': + def differentiate( + self, coord: Hashable, edge_order: int = 1, datetime_unit: str = None + ) -> "DataArray": """ Differentiate the array with the second order accurate central differences. 
@@ -2809,12 +2933,12 @@ def differentiate(self, coord: Hashable, edge_order: int = 1, * x (x) float64 0.0 0.1 1.1 1.2 Dimensions without coordinates: y """ - ds = self._to_temp_dataset().differentiate( - coord, edge_order, datetime_unit) + ds = self._to_temp_dataset().differentiate(coord, edge_order, datetime_unit) return self._from_temp_dataset(ds) - def integrate(self, dim: Union[Hashable, Sequence[Hashable]], - datetime_unit: str = None) -> 'DataArray': + def integrate( + self, dim: Union[Hashable, Sequence[Hashable]], datetime_unit: str = None + ) -> "DataArray": """ integrate the array with the trapezoidal rule. .. note:: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 52e4c0f82d3..a85f015cfa8 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6,36 +6,77 @@ from distutils.version import LooseVersion from numbers import Number from pathlib import Path -from typing import (Any, Callable, DefaultDict, Dict, Hashable, Iterable, - Iterator, List, Mapping, MutableMapping, Optional, - Sequence, Set, Tuple, Union, cast, overload, - TYPE_CHECKING) +from typing import ( + Any, + Callable, + DefaultDict, + Dict, + Hashable, + Iterable, + Iterator, + List, + Mapping, + MutableMapping, + Optional, + Sequence, + Set, + Tuple, + Union, + cast, + overload, + TYPE_CHECKING, +) import numpy as np import pandas as pd import xarray as xr from ..coding.cftimeindex import _parse_array_of_cftime_strings -from . import (alignment, dtypes, duck_array_ops, formatting, groupby, - indexing, ops, pdcompat, resample, rolling, utils) -from .alignment import (align, _broadcast_helper, - _get_broadcast_dims_map_common_coords) -from .common import (ALL_DIMS, DataWithCoords, ImplementsDatasetReduce, - _contains_datetime_like_objects) -from .coordinates import (DatasetCoordinates, LevelCoordinatesSource, - assert_coordinate_consistent, remap_label_indexers) -from .duck_array_ops import datetime_to_numeric -from .indexes import ( - Indexes, default_indexes, isel_variable_and_index, roll_index, +from . 
import ( + alignment, + dtypes, + duck_array_ops, + formatting, + groupby, + indexing, + ops, + pdcompat, + resample, + rolling, + utils, +) +from .alignment import align, _broadcast_helper, _get_broadcast_dims_map_common_coords +from .common import ( + ALL_DIMS, + DataWithCoords, + ImplementsDatasetReduce, + _contains_datetime_like_objects, +) +from .coordinates import ( + DatasetCoordinates, + LevelCoordinatesSource, + assert_coordinate_consistent, + remap_label_indexers, ) +from .duck_array_ops import datetime_to_numeric +from .indexes import Indexes, default_indexes, isel_variable_and_index, roll_index from .merge import ( - dataset_merge_method, dataset_update_method, merge_data_and_coords, - merge_variables) + dataset_merge_method, + dataset_update_method, + merge_data_and_coords, + merge_variables, +) from .options import OPTIONS, _get_keep_attrs from .pycompat import dask_array_type -from .utils import (Frozen, SortedKeysDict, _check_inplace, - decode_numpy_dict_values, either_dict_or_kwargs, hashable, - maybe_wrap_array) +from .utils import ( + Frozen, + SortedKeysDict, + _check_inplace, + decode_numpy_dict_values, + either_dict_or_kwargs, + hashable, + maybe_wrap_array, +) from .variable import IndexVariable, Variable, as_variable, broadcast_variables from ..plot.dataset_plot import _Dataset_PlotMethods @@ -43,6 +84,7 @@ from ..backends import AbstractDataStore, ZarrStore from .dataarray import DataArray from .merge import DatasetLike + try: from dask.delayed import Delayed except ImportError: @@ -50,16 +92,27 @@ # list of attributes of pd.DatetimeIndex that are ndarrays of time info -_DATETIMEINDEX_COMPONENTS = ['year', 'month', 'day', 'hour', 'minute', - 'second', 'microsecond', 'nanosecond', 'date', - 'time', 'dayofyear', 'weekofyear', 'dayofweek', - 'quarter'] - - -def _get_virtual_variable(variables, key: Hashable, - level_vars: Mapping = None, - dim_sizes: Mapping = None, - ) -> Tuple[Hashable, Hashable, Variable]: +_DATETIMEINDEX_COMPONENTS = [ + "year", + "month", + "day", + "hour", + "minute", + "second", + "microsecond", + "nanosecond", + "date", + "time", + "dayofyear", + "weekofyear", + "dayofweek", + "quarter", +] + + +def _get_virtual_variable( + variables, key: Hashable, level_vars: Mapping = None, dim_sizes: Mapping = None +) -> Tuple[Hashable, Hashable, Variable]: """Get a virtual variable (e.g., 'time.year' or a MultiIndex level) from a dict of xarray.Variable objects (if possible) """ @@ -76,7 +129,7 @@ def _get_virtual_variable(variables, key: Hashable, if not isinstance(key, str): raise KeyError(key) - split_key = key.split('.', 1) + split_key = key.split(".", 1) if len(split_key) == 2: ref_name, var_name = split_key # type: str, Optional[str] elif len(split_key) == 1: @@ -104,9 +157,7 @@ def _get_virtual_variable(variables, key: Hashable, return ref_name, var_name, virtual_var -def calculate_dimensions( - variables: Mapping[Hashable, Variable] -) -> 'Dict[Any, int]': +def calculate_dimensions(variables: Mapping[Hashable, Variable]) -> "Dict[Any, int]": """Calculate the dimensions corresponding to a set of variables. Returns dictionary mapping from dimension names to sizes. 
Raises ValueError @@ -118,15 +169,18 @@ def calculate_dimensions( for k, var in variables.items(): for dim, size in zip(var.dims, var.shape): if dim in scalar_vars: - raise ValueError('dimension %r already exists as a scalar ' - 'variable' % dim) + raise ValueError( + "dimension %r already exists as a scalar " "variable" % dim + ) if dim not in dims: dims[dim] = size last_used[dim] = k elif dims[dim] != size: - raise ValueError('conflicting sizes for dimension %r: ' - 'length %s on %r and length %s on %r' % - (dim, size, k, dims[dim], last_used[dim])) + raise ValueError( + "conflicting sizes for dimension %r: " + "length %s on %r and length %s on %r" + % (dim, size, k, dims[dim], last_used[dim]) + ) return dims @@ -134,8 +188,8 @@ def merge_indexes( indexes: Mapping[Hashable, Union[Hashable, Sequence[Hashable]]], variables: Mapping[Hashable, Variable], coord_names: Set[Hashable], - append: bool = False -) -> 'Tuple[OrderedDict[Any, Variable], Set[Hashable]]': + append: bool = False, +) -> "Tuple[OrderedDict[Any, Variable], Set[Hashable]]": """Merge variables into multi-indexes. Not public API. Used in Dataset and DataArray set_index @@ -153,11 +207,14 @@ def merge_indexes( for n in var_names: var = variables[n] - if (current_index_variable is not None - and var.dims != current_index_variable.dims): + if ( + current_index_variable is not None + and var.dims != current_index_variable.dims + ): raise ValueError( "dimension mismatch between %r %s and %r %s" - % (dim, current_index_variable.dims, n, var.dims)) + % (dim, current_index_variable.dims, n, var.dims) + ) if current_index_variable is not None and append: current_index = current_index_variable.to_index() @@ -171,7 +228,7 @@ def merge_indexes( codes.extend(current_codes) levels.extend(current_index.levels) else: - names.append('%s_level_0' % dim) + names.append("%s_level_0" % dim) cat = pd.Categorical(current_index.values, ordered=True) codes.append(cat.codes) levels.append(cat.categories) @@ -192,8 +249,9 @@ def merge_indexes( vars_to_replace[dim] = IndexVariable(dim, idx) vars_to_remove.extend(var_names) - new_variables = OrderedDict([(k, v) for k, v in variables.items() - if k not in vars_to_remove]) + new_variables = OrderedDict( + [(k, v) for k, v in variables.items() if k not in vars_to_remove] + ) new_variables.update(vars_to_replace) new_coord_names = coord_names | set(vars_to_replace) new_coord_names -= set(vars_to_remove) @@ -207,14 +265,13 @@ def split_indexes( coord_names: Set[Hashable], level_coords: Mapping[Hashable, Hashable], drop: bool = False, -) -> 'Tuple[OrderedDict[Any, Variable], Set[Hashable]]': +) -> "Tuple[OrderedDict[Any, Variable], Set[Hashable]]": """Extract (multi-)indexes (levels) as variables. Not public API. Used in Dataset and DataArray reset_index methods. 
""" - if (isinstance(dims_or_levels, str) or - not isinstance(dims_or_levels, Sequence)): + if isinstance(dims_or_levels, str) or not isinstance(dims_or_levels, Sequence): dims_or_levels = [dims_or_levels] dim_levels = defaultdict(list) # type: DefaultDict[Any, List[Hashable]] @@ -236,7 +293,7 @@ def split_indexes( else: vars_to_remove.append(d) if not drop: - vars_to_create[str(d) + '_'] = Variable(d, index) + vars_to_create[str(d) + "_"] = Variable(d, index) for d, levs in dim_levels.items(): index = variables[d].to_index() @@ -260,40 +317,42 @@ def split_indexes( return new_variables, new_coord_names -def _assert_empty(args: tuple, msg: str = '%s') -> None: +def _assert_empty(args: tuple, msg: str = "%s") -> None: if args: raise ValueError(msg % args) -def as_dataset(obj: Any) -> 'Dataset': +def as_dataset(obj: Any) -> "Dataset": """Cast the given object to a Dataset. Handles Datasets, DataArrays and dictionaries of variables. A new Dataset object is only created if the provided object is not already one. """ - if hasattr(obj, 'to_dataset'): + if hasattr(obj, "to_dataset"): obj = obj.to_dataset() if not isinstance(obj, Dataset): obj = Dataset(obj) return obj -class DataVariables(Mapping[Hashable, 'Union[DataArray, Dataset]']): - def __init__(self, dataset: 'Dataset'): +class DataVariables(Mapping[Hashable, "Union[DataArray, Dataset]"]): + def __init__(self, dataset: "Dataset"): self._dataset = dataset def __iter__(self) -> Iterator[Hashable]: - return (key for key in self._dataset._variables - if key not in self._dataset._coord_names) + return ( + key + for key in self._dataset._variables + if key not in self._dataset._coord_names + ) def __len__(self) -> int: return len(self._dataset._variables) - len(self._dataset._coord_names) def __contains__(self, key) -> bool: - return (key in self._dataset._variables - and key not in self._dataset._coord_names) + return key in self._dataset._variables and key not in self._dataset._coord_names - def __getitem__(self, key) -> 'Union[DataArray, Dataset]': + def __getitem__(self, key) -> "Union[DataArray, Dataset]": if key not in self._dataset._coord_names: return self._dataset[key] else: @@ -309,17 +368,20 @@ def variables(self) -> Mapping[Hashable, Variable]: def _ipython_key_completions_(self): """Provide method for the key-autocompletions in IPython. """ - return [key for key in self._dataset._ipython_key_completions_() - if key not in self._dataset._coord_names] + return [ + key + for key in self._dataset._ipython_key_completions_() + if key not in self._dataset._coord_names + ] class _LocIndexer: - def __init__(self, dataset: 'Dataset'): + def __init__(self, dataset: "Dataset"): self.dataset = dataset - def __getitem__(self, key: Mapping[Hashable, Any]) -> 'Dataset': + def __getitem__(self, key: Mapping[Hashable, Any]) -> "Dataset": if not utils.is_dict_like(key): - raise TypeError('can only lookup dictionaries from Dataset.loc') + raise TypeError("can only lookup dictionaries from Dataset.loc") return self.dataset.sel(key) @@ -336,6 +398,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords): One dimensional variables with name equal to their dimension are index coordinates used for label based indexing. 
""" + _groupby_cls = groupby.DatasetGroupBy _rolling_cls = rolling.DatasetRolling _coarsen_cls = rolling.DatasetCoarsen @@ -345,12 +408,15 @@ def __init__( self, # could make a VariableArgs to use more generally, and refine these # categories - data_vars: Mapping[Hashable, Union[ - 'DataArray', - Variable, - Tuple[Hashable, Any], - Tuple[Sequence[Hashable], Any], - ]] = None, + data_vars: Mapping[ + Hashable, + Union[ + "DataArray", + Variable, + Tuple[Hashable, Any], + Tuple[Sequence[Hashable], Any], + ], + ] = None, coords: Mapping[Hashable, Any] = None, attrs: Mapping = None, compat=None, @@ -404,15 +470,16 @@ def __init__( """ if compat is not None: warnings.warn( - 'The `compat` argument to Dataset is deprecated and will be ' - 'removed in 0.13.' - 'Instead, use `merge` to control how variables are combined', - FutureWarning, stacklevel=2) + "The `compat` argument to Dataset is deprecated and will be " + "removed in 0.13." + "Instead, use `merge` to control how variables are combined", + FutureWarning, + stacklevel=2, + ) else: - compat = 'broadcast_equals' + compat = "broadcast_equals" - self._variables = \ - OrderedDict() # type: OrderedDict[Any, Variable] + self._variables = OrderedDict() # type: OrderedDict[Any, Variable] self._coord_names = set() # type: Set[Hashable] self._dims = {} # type: Dict[Any, int] self._attrs = None # type: Optional[OrderedDict] @@ -436,21 +503,24 @@ def _set_init_vars_and_dims(self, data_vars, coords, compat): """ both_data_and_coords = [k for k in data_vars if k in coords] if both_data_and_coords: - raise ValueError('variables %r are found in both data_vars and ' - 'coords' % both_data_and_coords) + raise ValueError( + "variables %r are found in both data_vars and " + "coords" % both_data_and_coords + ) if isinstance(coords, Dataset): coords = coords.variables variables, coord_names, dims = merge_data_and_coords( - data_vars, coords, compat=compat) + data_vars, coords, compat=compat + ) self._variables = variables self._coord_names = coord_names self._dims = dims @classmethod - def load_store(cls, store, decoder=None) -> 'Dataset': + def load_store(cls, store, decoder=None) -> "Dataset": """Create a new dataset from the contents of a backends.*DataStore object """ @@ -473,7 +543,7 @@ def variables(self) -> Mapping[Hashable, Variable]: return Frozen(self._variables) @property - def attrs(self) -> 'OrderedDict[Any, Any]': + def attrs(self) -> "OrderedDict[Any, Any]": """Dictionary of global attributes on this dataset """ if self._attrs is None: @@ -523,7 +593,7 @@ def sizes(self) -> Mapping[Hashable, int]: """ return self.dims - def load(self, **kwargs) -> 'Dataset': + def load(self, **kwargs) -> "Dataset": """Manually trigger loading of this dataset's data from disk or a remote source into memory and return this dataset. 
@@ -542,8 +612,11 @@ def load(self, **kwargs) -> 'Dataset': dask.array.compute """ # access .data to coerce everything to numpy or dask arrays - lazy_data = {k: v._data for k, v in self.variables.items() - if isinstance(v._data, dask_array_type)} + lazy_data = { + k: v._data + for k, v in self.variables.items() + if isinstance(v._data, dask_array_type) + } if lazy_data: import dask.array as da @@ -568,36 +641,55 @@ def __dask_graph__(self): else: try: from dask.highlevelgraph import HighLevelGraph + return HighLevelGraph.merge(*graphs.values()) except ImportError: from dask import sharedict + return sharedict.merge(*graphs.values()) def __dask_keys__(self): import dask - return [v.__dask_keys__() for v in self.variables.values() - if dask.is_dask_collection(v)] + + return [ + v.__dask_keys__() + for v in self.variables.values() + if dask.is_dask_collection(v) + ] def __dask_layers__(self): import dask - return sum([v.__dask_layers__() for v in self.variables.values() if - dask.is_dask_collection(v)], ()) + + return sum( + [ + v.__dask_layers__() + for v in self.variables.values() + if dask.is_dask_collection(v) + ], + (), + ) @property def __dask_optimize__(self): import dask.array as da + return da.Array.__dask_optimize__ @property def __dask_scheduler__(self): import dask.array as da + return da.Array.__dask_scheduler__ def __dask_postcompute__(self): import dask - info = [(True, k, v.__dask_postcompute__()) - if dask.is_dask_collection(v) else - (False, k, v) for k, v in self._variables.items()] + + info = [ + (True, k, v.__dask_postcompute__()) + if dask.is_dask_collection(v) + else (False, k, v) + for k, v in self._variables.items() + ] args = ( info, self._coord_names, @@ -611,9 +703,13 @@ def __dask_postcompute__(self): def __dask_postpersist__(self): import dask - info = [(True, k, v.__dask_postpersist__()) - if dask.is_dask_collection(v) else - (False, k, v) for k, v in self._variables.items()] + + info = [ + (True, k, v.__dask_postpersist__()) + if dask.is_dask_collection(v) + else (False, k, v) + for k, v in self._variables.items() + ] args = ( info, self._coord_names, @@ -654,7 +750,7 @@ def _dask_postpersist(dsk, info, *args): return Dataset._construct_direct(variables, *args) - def compute(self, **kwargs) -> 'Dataset': + def compute(self, **kwargs) -> "Dataset": """Manually trigger loading of this dataset's data from disk or a remote source into memory and return a new dataset. The original is left unaltered. 
@@ -676,12 +772,15 @@ def compute(self, **kwargs) -> 'Dataset': new = self.copy(deep=False) return new.load(**kwargs) - def _persist_inplace(self, **kwargs) -> 'Dataset': + def _persist_inplace(self, **kwargs) -> "Dataset": """Persist all Dask arrays in memory """ # access .data to coerce everything to numpy or dask arrays - lazy_data = {k: v._data for k, v in self.variables.items() - if isinstance(v._data, dask_array_type)} + lazy_data = { + k: v._data + for k, v in self.variables.items() + if isinstance(v._data, dask_array_type) + } if lazy_data: import dask @@ -693,7 +792,7 @@ def _persist_inplace(self, **kwargs) -> 'Dataset': return self - def persist(self, **kwargs) -> 'Dataset': + def persist(self, **kwargs) -> "Dataset": """ Trigger computation, keeping data as dask arrays This operation can be used to trigger computation on underlying dask @@ -715,8 +814,16 @@ def persist(self, **kwargs) -> 'Dataset': return new._persist_inplace(**kwargs) @classmethod - def _construct_direct(cls, variables, coord_names, dims, attrs=None, - indexes=None, encoding=None, file_obj=None): + def _construct_direct( + cls, + variables, + coord_names, + dims, + attrs=None, + indexes=None, + encoding=None, + file_obj=None, + ): """Shortcut around __init__ for internal use when we want to skip costly validation """ @@ -743,14 +850,14 @@ def _from_vars_and_coord_names(cls, variables, coord_names, attrs=None): # https://github.com/python/mypy/issues/1803 def _replace( # type: ignore self, - variables: 'OrderedDict[Any, Variable]' = None, + variables: "OrderedDict[Any, Variable]" = None, coord_names: Set[Hashable] = None, dims: Dict[Any, int] = None, - attrs: 'Optional[OrderedDict]' = __default, - indexes: 'Optional[OrderedDict[Any, pd.Index]]' = __default, + attrs: "Optional[OrderedDict]" = __default, + indexes: "Optional[OrderedDict[Any, pd.Index]]" = __default, encoding: Optional[dict] = __default, inplace: bool = False, - ) -> 'Dataset': + ) -> "Dataset": """Fastpath constructor for internal use. Returns an object with optionally with replaced attributes. @@ -787,30 +894,32 @@ def _replace( # type: ignore if encoding is self.__default: encoding = copy.copy(self._encoding) obj = self._construct_direct( - variables, coord_names, dims, attrs, indexes, encoding) + variables, coord_names, dims, attrs, indexes, encoding + ) return obj def _replace_with_new_dims( # type: ignore self, - variables: 'OrderedDict[Any, Variable]', + variables: "OrderedDict[Any, Variable]", coord_names: set = None, - attrs: Optional['OrderedDict'] = __default, - indexes: 'OrderedDict[Any, pd.Index]' = __default, + attrs: Optional["OrderedDict"] = __default, + indexes: "OrderedDict[Any, pd.Index]" = __default, inplace: bool = False, - ) -> 'Dataset': + ) -> "Dataset": """Replace variables with recalculated dimensions.""" dims = calculate_dimensions(variables) return self._replace( - variables, coord_names, dims, attrs, indexes, inplace=inplace) + variables, coord_names, dims, attrs, indexes, inplace=inplace + ) def _replace_vars_and_dims( # type: ignore self, - variables: 'OrderedDict[Any, Variable]', + variables: "OrderedDict[Any, Variable]", coord_names: set = None, dims: Dict[Any, int] = None, - attrs: 'OrderedDict' = __default, + attrs: "OrderedDict" = __default, inplace: bool = False, - ) -> 'Dataset': + ) -> "Dataset": """Deprecated version of _replace_with_new_dims(). 
Unlike _replace_with_new_dims(), this method always recalculates @@ -819,9 +928,10 @@ def _replace_vars_and_dims( # type: ignore if dims is None: dims = calculate_dimensions(variables) return self._replace( - variables, coord_names, dims, attrs, indexes=None, inplace=inplace) + variables, coord_names, dims, attrs, indexes=None, inplace=inplace + ) - def _overwrite_indexes(self, indexes: Mapping[Any, pd.Index]) -> 'Dataset': + def _overwrite_indexes(self, indexes: Mapping[Any, pd.Index]) -> "Dataset": if not indexes: return self @@ -841,7 +951,7 @@ def _overwrite_indexes(self, indexes: Mapping[Any, pd.Index]) -> 'Dataset': obj = obj.rename(dim_names) return obj - def copy(self, deep: bool = False, data: Mapping = None) -> 'Dataset': + def copy(self, deep: bool = False, data: Mapping = None) -> "Dataset": """Returns a copy of this dataset. If `deep=True`, a deep copy is made of each of the component variables. @@ -935,34 +1045,37 @@ def copy(self, deep: bool = False, data: Mapping = None) -> 'Dataset': pandas.DataFrame.copy """ # noqa if data is None: - variables = OrderedDict((k, v.copy(deep=deep)) - for k, v in self._variables.items()) + variables = OrderedDict( + (k, v.copy(deep=deep)) for k, v in self._variables.items() + ) elif not utils.is_dict_like(data): - raise ValueError('Data must be dict-like') + raise ValueError("Data must be dict-like") else: var_keys = set(self.data_vars.keys()) data_keys = set(data.keys()) keys_not_in_vars = data_keys - var_keys if keys_not_in_vars: raise ValueError( - 'Data must only contain variables in original ' - 'dataset. Extra variables: {}' - .format(keys_not_in_vars)) + "Data must only contain variables in original " + "dataset. Extra variables: {}".format(keys_not_in_vars) + ) keys_missing_from_data = var_keys - data_keys if keys_missing_from_data: raise ValueError( - 'Data must contain all variables in original ' - 'dataset. Data is missing {}' - .format(keys_missing_from_data)) - variables = OrderedDict((k, v.copy(deep=deep, data=data.get(k))) - for k, v in self._variables.items()) + "Data must contain all variables in original " + "dataset. Data is missing {}".format(keys_missing_from_data) + ) + variables = OrderedDict( + (k, v.copy(deep=deep, data=data.get(k))) + for k, v in self._variables.items() + ) attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs) return self._replace(variables, attrs=attrs) @property - def _level_coords(self) -> 'OrderedDict[str, Hashable]': + def _level_coords(self) -> "OrderedDict[str, Hashable]": """Return a mapping of all MultiIndex levels and their corresponding coordinate name. """ @@ -974,7 +1087,7 @@ def _level_coords(self) -> 'OrderedDict[str, Hashable]': level_coords.update({lname: dim for lname in level_names}) return level_coords - def _copy_listed(self, names: Iterable[Hashable]) -> 'Dataset': + def _copy_listed(self, names: Iterable[Hashable]) -> "Dataset": """Create a new Dataset with the listed variables from this dataset and the all relevant coordinates. Skips all validation. 
""" @@ -987,7 +1100,8 @@ def _copy_listed(self, names: Iterable[Hashable]) -> 'Dataset': variables[name] = self._variables[name] except KeyError: ref_name, var_name, var = _get_virtual_variable( - self._variables, name, self._level_coords, self.dims) + self._variables, name, self._level_coords, self.dims + ) variables[var_name] = var if ref_name in self._coord_names or ref_name in self.dims: coord_names.add(var_name) @@ -1009,7 +1123,7 @@ def _copy_listed(self, names: Iterable[Hashable]) -> 'Dataset': return self._replace(variables, coord_names, dims, indexes=indexes) - def _construct_dataarray(self, name: Hashable) -> 'DataArray': + def _construct_dataarray(self, name: Hashable) -> "DataArray": """Construct a DataArray by indexing this dataset """ from .dataarray import DataArray @@ -1018,7 +1132,8 @@ def _construct_dataarray(self, name: Hashable) -> 'DataArray': variable = self._variables[name] except KeyError: _, name, variable = _get_virtual_variable( - self._variables, name, self._level_coords, self.dims) + self._variables, name, self._level_coords, self.dims + ) needed_dims = set(variable.dims) @@ -1030,16 +1145,16 @@ def _construct_dataarray(self, name: Hashable) -> 'DataArray': if self._indexes is None: indexes = None else: - indexes = OrderedDict((k, v) for k, v in self._indexes.items() - if k in coords) + indexes = OrderedDict( + (k, v) for k, v in self._indexes.items() if k in coords + ) - return DataArray(variable, coords, name=name, indexes=indexes, - fastpath=True) + return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) - def __copy__(self) -> 'Dataset': + def __copy__(self) -> "Dataset": return self.copy(deep=False) - def __deepcopy__(self, memo=None) -> 'Dataset': + def __deepcopy__(self, memo=None) -> "Dataset": # memo does nothing but is required for compatibility with # copy.deepcopy return self.copy(deep=True) @@ -1054,8 +1169,12 @@ def _attr_sources(self) -> List[Mapping[Hashable, Any]]: def _item_sources(self) -> List[Mapping[Hashable, Any]]: """List of places to look-up items for key-completion """ - return [self.data_vars, self.coords, {d: self[d] for d in self.dims}, - LevelCoordinatesSource(self)] + return [ + self.data_vars, + self.coords, + {d: self[d] for d in self.dims}, + LevelCoordinatesSource(self), + ] def __contains__(self, key: object) -> bool: """The 'in' operator will return true or false depending on whether @@ -1073,10 +1192,12 @@ def __iter__(self) -> Iterator[Hashable]: return iter(self.data_vars) def __array__(self, dtype=None): - raise TypeError('cannot directly convert an xarray.Dataset into a ' - 'numpy array. Instead, create an xarray.DataArray ' - 'first, either with indexing on the Dataset or by ' - 'invoking the `to_array()` method.') + raise TypeError( + "cannot directly convert an xarray.Dataset into a " + "numpy array. Instead, create an xarray.DataArray " + "first, either with indexing on the Dataset or by " + "invoking the `to_array()` method." + ) @property def nbytes(self) -> int: @@ -1089,7 +1210,7 @@ def loc(self) -> _LocIndexer: """ return _LocIndexer(self) - def __getitem__(self, key: object) -> 'Union[DataArray, Dataset]': + def __getitem__(self, key: object) -> "Union[DataArray, Dataset]": """Access variables or coordinates this dataset as a :py:class:`~xarray.DataArray`. @@ -1115,8 +1236,9 @@ def __setitem__(self, key: Hashable, value) -> None: variable. 
""" if utils.is_dict_like(key): - raise NotImplementedError('cannot yet use a dictionary as a key ' - 'to set Dataset values') + raise NotImplementedError( + "cannot yet use a dictionary as a key " "to set Dataset values" + ) self.update({key: value}) @@ -1131,7 +1253,7 @@ def __delitem__(self, key: Hashable) -> None: # https://github.com/python/mypy/issues/4266 __hash__ = None # type: ignore - def _all_compat(self, other: 'Dataset', compat_str: str) -> bool: + def _all_compat(self, other: "Dataset", compat_str: str) -> bool: """Helper function for equals and identical """ @@ -1140,11 +1262,11 @@ def _all_compat(self, other: 'Dataset', compat_str: str) -> bool: def compat(x: Variable, y: Variable) -> bool: return getattr(x, compat_str)(y) - return (self._coord_names == other._coord_names and - utils.dict_equiv(self._variables, other._variables, - compat=compat)) + return self._coord_names == other._coord_names and utils.dict_equiv( + self._variables, other._variables, compat=compat + ) - def broadcast_equals(self, other: 'Dataset') -> bool: + def broadcast_equals(self, other: "Dataset") -> bool: """Two Datasets are broadcast equal if they are equal after broadcasting all variables against each other. @@ -1158,11 +1280,11 @@ def broadcast_equals(self, other: 'Dataset') -> bool: Dataset.identical """ try: - return self._all_compat(other, 'broadcast_equals') + return self._all_compat(other, "broadcast_equals") except (TypeError, AttributeError): return False - def equals(self, other: 'Dataset') -> bool: + def equals(self, other: "Dataset") -> bool: """Two Datasets are equal if they have matching variables and coordinates, all of which are equal. @@ -1178,11 +1300,11 @@ def equals(self, other: 'Dataset') -> bool: Dataset.identical """ try: - return self._all_compat(other, 'equals') + return self._all_compat(other, "equals") except (TypeError, AttributeError): return False - def identical(self, other: 'Dataset') -> bool: + def identical(self, other: "Dataset") -> bool: """Like equals, but also checks all dataset attributes and the attributes on all variables and coordinates. 
@@ -1192,13 +1314,14 @@ def identical(self, other: 'Dataset') -> bool: Dataset.equals """ try: - return (utils.dict_equiv(self.attrs, other.attrs) - and self._all_compat(other, 'identical')) + return utils.dict_equiv(self.attrs, other.attrs) and self._all_compat( + other, "identical" + ) except (TypeError, AttributeError): return False @property - def indexes(self) -> 'Mapping[Any, pd.Index]': + def indexes(self) -> "Mapping[Any, pd.Index]": """Mapping of pandas.Index objects used for label based indexing """ if self._indexes is None: @@ -1219,10 +1342,8 @@ def data_vars(self) -> DataVariables: return DataVariables(self) def set_coords( - self, - names: 'Union[Hashable, Iterable[Hashable]]', - inplace: bool = None - ) -> 'Dataset': + self, names: "Union[Hashable, Iterable[Hashable]]", inplace: bool = None + ) -> "Dataset": """Given names of one or more variables, set them as coordinates Parameters @@ -1257,10 +1378,10 @@ def set_coords( def reset_coords( self, - names: 'Union[Hashable, Iterable[Hashable], None]' = None, + names: "Union[Hashable, Iterable[Hashable], None]" = None, drop: bool = False, - inplace: bool = None - ) -> 'Dataset': + inplace: bool = None, + ) -> "Dataset": """Given names of coordinates, reset them to become variables Parameters @@ -1291,8 +1412,8 @@ def reset_coords( bad_coords = set(names) & set(self.dims) if bad_coords: raise ValueError( - 'cannot remove index coordinates with reset_coords: %s' - % bad_coords) + "cannot remove index coordinates with reset_coords: %s" % bad_coords + ) obj = self if inplace else self.copy() obj._coord_names.difference_update(names) if drop: @@ -1300,10 +1421,11 @@ def reset_coords( del obj._variables[name] return obj - def dump_to_store(self, store: 'AbstractDataStore', **kwargs) -> None: + def dump_to_store(self, store: "AbstractDataStore", **kwargs) -> None: """Store dataset contents to a backends.*DataStore object. """ from ..backends.api import dump_to_store + # TODO: rename and/or cleanup this method to make it more consistent # with to_netcdf() dump_to_store(self, store, **kwargs) @@ -1311,14 +1433,14 @@ def dump_to_store(self, store: 'AbstractDataStore', **kwargs) -> None: def to_netcdf( self, path=None, - mode: str = 'w', + mode: str = "w", format: str = None, group: str = None, engine: str = None, encoding: Mapping = None, unlimited_dims: Iterable[Hashable] = None, compute: bool = True, - ) -> Union[bytes, 'Delayed', None]: + ) -> Union[bytes, "Delayed", None]: """Write dataset contents to a netCDF file. Parameters @@ -1385,10 +1507,18 @@ def to_netcdf( if encoding is None: encoding = {} from ..backends.api import to_netcdf - return to_netcdf(self, path, mode, format=format, group=group, - engine=engine, encoding=encoding, - unlimited_dims=unlimited_dims, - compute=compute) + + return to_netcdf( + self, + path, + mode, + format=format, + group=group, + engine=engine, + encoding=encoding, + unlimited_dims=unlimited_dims, + compute=compute, + ) def to_zarr( self, @@ -1399,8 +1529,8 @@ def to_zarr( encoding: Mapping = None, compute: bool = True, consolidated: bool = False, - append_dim: Hashable = None - ) -> 'ZarrStore': + append_dim: Hashable = None, + ) -> "ZarrStore": """Write dataset contents to a zarr group. .. 
note:: Experimental @@ -1441,24 +1571,34 @@ def to_zarr( """ if encoding is None: encoding = {} - if (mode == 'a') or (append_dim is not None): + if (mode == "a") or (append_dim is not None): if mode is None: - mode = 'a' - elif mode != 'a': + mode = "a" + elif mode != "a": raise ValueError( "append_dim was set along with mode='{}', either set " "mode='a' or don't set it.".format(mode) ) elif mode is None: - mode = 'w-' - if mode not in ['w', 'w-', 'a']: + mode = "w-" + if mode not in ["w", "w-", "a"]: # TODO: figure out how to handle 'r+' - raise ValueError("The only supported options for mode are 'w'," - "'w-' and 'a'.") + raise ValueError( + "The only supported options for mode are 'w'," "'w-' and 'a'." + ) from ..backends.api import to_zarr - return to_zarr(self, store=store, mode=mode, synchronizer=synchronizer, - group=group, encoding=encoding, compute=compute, - consolidated=consolidated, append_dim=append_dim) + + return to_zarr( + self, + store=store, + mode=mode, + synchronizer=synchronizer, + group=group, + encoding=encoding, + compute=compute, + consolidated=consolidated, + append_dim=append_dim, + ) def __repr__(self) -> str: return formatting.dataset_repr(self) @@ -1480,24 +1620,24 @@ def info(self, buf=None) -> None: buf = sys.stdout lines = [] - lines.append('xarray.Dataset {') - lines.append('dimensions:') + lines.append("xarray.Dataset {") + lines.append("dimensions:") for name, size in self.dims.items(): - lines.append('\t{name} = {size} ;'.format(name=name, size=size)) - lines.append('\nvariables:') + lines.append("\t{name} = {size} ;".format(name=name, size=size)) + lines.append("\nvariables:") for name, da in self.variables.items(): - dims = ', '.join(da.dims) - lines.append('\t{type} {name}({dims}) ;'.format( - type=da.dtype, name=name, dims=dims)) + dims = ", ".join(da.dims) + lines.append( + "\t{type} {name}({dims}) ;".format(type=da.dtype, name=name, dims=dims) + ) for k, v in da.attrs.items(): - lines.append('\t\t{name}:{k} = {v} ;'.format(name=name, k=k, - v=v)) - lines.append('\n// global attributes:') + lines.append("\t\t{name}:{k} = {v} ;".format(name=name, k=k, v=v)) + lines.append("\n// global attributes:") for k, v in self.attrs.items(): - lines.append('\t:{k} = {v} ;'.format(k=k, v=v)) - lines.append('}') + lines.append("\t:{k} = {v} ;".format(k=k, v=v)) + lines.append("}") - buf.write('\n'.join(lines)) + buf.write("\n".join(lines)) @property def chunks(self) -> Mapping[Hashable, Tuple[int, ...]]: @@ -1509,21 +1649,19 @@ def chunks(self) -> Mapping[Hashable, Tuple[int, ...]]: if v.chunks is not None: for dim, c in zip(v.dims, v.chunks): if dim in chunks and c != chunks[dim]: - raise ValueError('inconsistent chunks') + raise ValueError("inconsistent chunks") chunks[dim] = c return Frozen(SortedKeysDict(chunks)) def chunk( self, chunks: Union[ - None, - Number, - Mapping[Hashable, Union[None, Number, Tuple[Number, ...]]] + None, Number, Mapping[Hashable, Union[None, Number, Tuple[Number, ...]]] ] = None, - name_prefix: str = 'xarray-', + name_prefix: str = "xarray-", token: str = None, - lock: bool = False - ) -> 'Dataset': + lock: bool = False, + ) -> "Dataset": """Coerce all arrays in this dataset into dask arrays with the given chunks. 
@@ -1556,7 +1694,8 @@ def chunk( except ImportError: # raise the usual error if dask is entirely missing import dask # noqa - raise ImportError('xarray requires dask version 0.9 or newer') + + raise ImportError("xarray requires dask version 0.9 or newer") if isinstance(chunks, Number): chunks = dict.fromkeys(self.dims, chunks) @@ -1564,8 +1703,10 @@ def chunk( if chunks is not None: bad_dims = chunks.keys() - self.dims.keys() if bad_dims: - raise ValueError('some chunks keys are not dimensions on this ' - 'object: %s' % bad_dims) + raise ValueError( + "some chunks keys are not dimensions on this " + "object: %s" % bad_dims + ) def selkeys(dict_, keys): if dict_ is None: @@ -1578,17 +1719,18 @@ def maybe_chunk(name, var, chunks): chunks = None if var.ndim > 0: token2 = tokenize(name, token if token else var._data) - name2 = '%s%s-%s' % (name_prefix, name, token2) + name2 = "%s%s-%s" % (name_prefix, name, token2) return var.chunk(chunks, name=name2, lock=lock) else: return var - variables = OrderedDict([(k, maybe_chunk(k, v, chunks)) - for k, v in self.variables.items()]) + variables = OrderedDict( + [(k, maybe_chunk(k, v, chunks)) for k, v in self.variables.items()] + ) return self._replace(variables) def _validate_indexers( - self, indexers: Mapping, + self, indexers: Mapping ) -> List[Tuple[Any, Union[slice, Variable]]]: """ Here we make sure + indexer has a valid keys @@ -1616,16 +1758,16 @@ def _validate_indexers( elif isinstance(v, tuple): v = as_variable(v) elif isinstance(v, Dataset): - raise TypeError('cannot use a Dataset as an indexer') + raise TypeError("cannot use a Dataset as an indexer") elif isinstance(v, Sequence) and len(v) == 0: - v = IndexVariable((k, ), np.zeros((0,), dtype='int64')) + v = IndexVariable((k,), np.zeros((0,), dtype="int64")) else: v = np.asarray(v) - if v.dtype.kind == 'U' or v.dtype.kind == 'S': + if v.dtype.kind == "U" or v.dtype.kind == "S": index = self.indexes[k] if isinstance(index, pd.DatetimeIndex): - v = v.astype('datetime64[ns]') + v = v.astype("datetime64[ns]") elif isinstance(index, xr.CFTimeIndex): v = _parse_array_of_cftime_strings(v, index.date_type) @@ -1636,7 +1778,8 @@ def _validate_indexers( else: raise IndexError( "Unlabeled multi-dimensional array cannot be " - "used for indexing: {}".format(k)) + "used for indexing: {}".format(k) + ) if v.ndim == 1: v = v.to_index_variable() @@ -1660,12 +1803,13 @@ def _get_indexers_coords_and_indexes(self, indexers): for k, v in indexers.items(): if isinstance(v, DataArray): v_coords = v.coords - if v.dtype.kind == 'b': + if v.dtype.kind == "b": if v.ndim != 1: # we only support 1-d boolean array raise ValueError( - '{:d}d-boolean array is used for indexing along ' - 'dimension {!r}, but only 1d boolean arrays are ' - 'supported.'.format(v.ndim, k)) + "{:d}d-boolean array is used for indexing along " + "dimension {!r}, but only 1d boolean arrays are " + "supported.".format(v.ndim, k) + ) # Make sure in case of boolean DataArray, its # coordinate also should be indexed. v_coords = v[v.values.nonzero()[0]].coords @@ -1693,7 +1837,7 @@ def isel( indexers: Mapping[Hashable, Any] = None, drop: bool = False, **indexers_kwargs: Any - ) -> 'Dataset': + ) -> "Dataset": """Returns a new dataset with each array indexed along the specified dimension(s). 
@@ -1734,7 +1878,7 @@ def isel( DataArray.isel """ - indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'isel') + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel") indexers_list = self._validate_indexers(indexers) @@ -1748,7 +1892,8 @@ def isel( if name in self.indexes: new_var, new_index = isel_variable_and_index( - name, var, self.indexes[name], var_indexers) + name, var, self.indexes[name], var_indexers + ) if new_index is not None: indexes[name] = new_index else: @@ -1757,19 +1902,14 @@ def isel( variables[name] = new_var coord_names = set(variables).intersection(self._coord_names) - selected = self._replace_with_new_dims( - variables, coord_names, indexes) + selected = self._replace_with_new_dims(variables, coord_names, indexes) # Extract coordinates from indexers - coord_vars, new_indexes = ( - selected._get_indexers_coords_and_indexes(indexers)) + coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(indexers) variables.update(coord_vars) indexes.update(new_indexes) - coord_names = (set(variables) - .intersection(self._coord_names) - .union(coord_vars)) - return self._replace_with_new_dims( - variables, coord_names, indexes=indexes) + coord_names = set(variables).intersection(self._coord_names).union(coord_vars) + return self._replace_with_new_dims(variables, coord_names, indexes=indexes) def sel( self, @@ -1778,7 +1918,7 @@ def sel( tolerance: Number = None, drop: bool = False, **indexers_kwargs: Any - ) -> 'Dataset': + ) -> "Dataset": """Returns a new dataset with each array indexed by tick labels along the specified dimension(s). @@ -1841,13 +1981,14 @@ def sel( Dataset.isel DataArray.sel """ - indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'sel') + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel") pos_indexers, new_indexes = remap_label_indexers( - self, indexers=indexers, method=method, tolerance=tolerance) + self, indexers=indexers, method=method, tolerance=tolerance + ) result = self.isel(indexers=pos_indexers, drop=drop) return result._overwrite_indexes(new_indexes) - def isel_points(self, dim: Any = 'points', **indexers: Any) -> 'Dataset': + def isel_points(self, dim: Any = "points", **indexers: Any) -> "Dataset": """Returns a new dataset with each array indexed pointwise along the specified dimension(s). 
@@ -1885,15 +2026,18 @@ def isel_points(self, dim: Any = 'points', **indexers: Any) -> 'Dataset': Dataset.sel_points DataArray.isel_points """ # noqa - warnings.warn('Dataset.isel_points is deprecated: use Dataset.isel()' - 'instead.', DeprecationWarning, stacklevel=2) + warnings.warn( + "Dataset.isel_points is deprecated: use Dataset.isel()" "instead.", + DeprecationWarning, + stacklevel=2, + ) indexer_dims = set(indexers) def take(variable, slices): # Note: remove helper function when once when numpy # supports vindex https://github.com/numpy/numpy/pull/6075 - if hasattr(variable.data, 'vindex'): + if hasattr(variable.data, "vindex"): # Special case for dask backed arrays to use vectorised list # indexing sel = variable.data.vindex[slices] @@ -1903,12 +2047,15 @@ def take(variable, slices): return sel def relevant_keys(mapping): - return [k for k, v in mapping.items() - if any(d in indexer_dims for d in v.dims)] + return [ + k for k, v in mapping.items() if any(d in indexer_dims for d in v.dims) + ] coords = relevant_keys(self.coords) - indexers = [(k, np.asarray(v)) # type: ignore - for k, v in indexers.items()] + indexers = [ + (k, np.asarray(v)) # type: ignore + for k, v in indexers.items() + ] indexers_dict = dict(indexers) non_indexed_dims = set(self.dims) - indexer_dims non_indexed_coords = set(self.coords) - set(coords) @@ -1918,56 +2065,58 @@ def relevant_keys(mapping): for k, v in indexers: if k not in self.dims: raise ValueError("dimension %s does not exist" % k) - if v.dtype.kind != 'i': # type: ignore - raise TypeError('Indexers must be integers') + if v.dtype.kind != "i": # type: ignore + raise TypeError("Indexers must be integers") if v.ndim != 1: # type: ignore - raise ValueError('Indexers must be 1 dimensional') + raise ValueError("Indexers must be 1 dimensional") # all the indexers should have the same length lengths = {len(v) for k, v in indexers} if len(lengths) > 1: - raise ValueError('All indexers must be the same length') + raise ValueError("All indexers must be the same length") # Existing dimensions are not valid choices for the dim argument if isinstance(dim, str): if dim in self.dims: # dim is an invalid string - raise ValueError('Existing dimension names are not valid ' - 'choices for the dim argument in sel_points') + raise ValueError( + "Existing dimension names are not valid " + "choices for the dim argument in sel_points" + ) - elif hasattr(dim, 'dims'): + elif hasattr(dim, "dims"): # dim is a DataArray or Coordinate if dim.name in self.dims: # dim already exists - raise ValueError('Existing dimensions are not valid choices ' - 'for the dim argument in sel_points') + raise ValueError( + "Existing dimensions are not valid choices " + "for the dim argument in sel_points" + ) # Set the new dim_name, and optionally the new dim coordinate # dim is either an array-like or a string if not utils.is_scalar(dim): # dim is array like get name or assign 'points', get as variable - dim_name = 'points' if not hasattr(dim, 'name') else dim.name + dim_name = "points" if not hasattr(dim, "name") else dim.name dim_coord = as_variable(dim, name=dim_name) else: # dim is a string dim_name = dim dim_coord = None # type: ignore - reordered = self.transpose( - *list(indexer_dims), *list(non_indexed_dims)) + reordered = self.transpose(*list(indexer_dims), *list(non_indexed_dims)) variables = OrderedDict() # type: ignore for name, var in reordered.variables.items(): - if name in indexers_dict or any( - d in indexer_dims for d in var.dims): + if name in indexers_dict or any(d 
in indexer_dims for d in var.dims): # slice if var is an indexer or depends on an indexed dim - slc = [indexers_dict[k] - if k in indexers_dict - else slice(None) for k in var.dims] + slc = [ + indexers_dict[k] if k in indexers_dict else slice(None) + for k in var.dims + ] - var_dims = [dim_name] + [d for d in var.dims - if d in non_indexed_dims] + var_dims = [dim_name] + [d for d in var.dims if d in non_indexed_dims] selection = take(var, tuple(slc)) var_subset = type(var)(var_dims, selection, var.attrs) variables[name] = var_subset @@ -1985,9 +2134,13 @@ def relevant_keys(mapping): dset.coords[dim_name] = dim_coord return dset - def sel_points(self, dim: Any = 'points', method: str = None, - tolerance: Number = None, - **indexers: Any): + def sel_points( + self, + dim: Any = "points", + method: str = None, + tolerance: Number = None, + **indexers: Any + ): """Returns a new dataset with each array indexed pointwise by tick labels along the specified dimension(s). @@ -2039,17 +2192,20 @@ def sel_points(self, dim: Any = 'points', method: str = None, Dataset.isel_points DataArray.sel_points """ # noqa - warnings.warn('Dataset.sel_points is deprecated: use Dataset.sel()' - 'instead.', DeprecationWarning, stacklevel=2) + warnings.warn( + "Dataset.sel_points is deprecated: use Dataset.sel()" "instead.", + DeprecationWarning, + stacklevel=2, + ) pos_indexers, _ = indexing.remap_label_indexers( self, indexers, method=method, tolerance=tolerance ) return self.isel_points(dim=dim, **pos_indexers) - def broadcast_like(self, - other: Union['Dataset', 'DataArray'], - exclude: Iterable[Hashable] = None) -> 'Dataset': + def broadcast_like( + self, other: Union["Dataset", "DataArray"], exclude: Iterable[Hashable] = None + ) -> "Dataset": """Broadcast this DataArray against another Dataset or DataArray. This is equivalent to xr.broadcast(other, self)[1] @@ -2065,21 +2221,20 @@ def broadcast_like(self, exclude = set() else: exclude = set(exclude) - args = align(other, self, join='outer', copy=False, exclude=exclude) + args = align(other, self, join="outer", copy=False, exclude=exclude) - dims_map, common_coords = _get_broadcast_dims_map_common_coords( - args, exclude) + dims_map, common_coords = _get_broadcast_dims_map_common_coords(args, exclude) return _broadcast_helper(args[1], exclude, dims_map, common_coords) def reindex_like( - self, - other: Union['Dataset', 'DataArray'], - method: str = None, - tolerance: Number = None, - copy: bool = True, - fill_value: Any = dtypes.NA - ) -> 'Dataset': + self, + other: Union["Dataset", "DataArray"], + method: str = None, + tolerance: Number = None, + copy: bool = True, + fill_value: Any = dtypes.NA, + ) -> "Dataset": """Conform this object onto the indexes of another object, filling in missing values with ``fill_value``. The default fill value is NaN. @@ -2125,8 +2280,13 @@ def reindex_like( align """ indexers = alignment.reindex_like_indexers(self, other) - return self.reindex(indexers=indexers, method=method, copy=copy, - fill_value=fill_value, tolerance=tolerance) + return self.reindex( + indexers=indexers, + method=method, + copy=copy, + fill_value=fill_value, + tolerance=tolerance, + ) def reindex( self, @@ -2136,7 +2296,7 @@ def reindex( copy: bool = True, fill_value: Any = dtypes.NA, **indexers_kwargs: Any - ) -> 'Dataset': + ) -> "Dataset": """Conform this object onto a new set of indexes, filling in missing values with ``fill_value``. The default fill value is NaN. 
@@ -2183,29 +2343,34 @@ def reindex( align pandas.Index.get_indexer """ - indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, - 'reindex') + indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex") bad_dims = [d for d in indexers if d not in self.dims] if bad_dims: - raise ValueError('invalid reindex dimensions: %s' % bad_dims) + raise ValueError("invalid reindex dimensions: %s" % bad_dims) variables, indexes = alignment.reindex_variables( - self.variables, self.sizes, self.indexes, indexers, method, - tolerance, copy=copy, fill_value=fill_value) + self.variables, + self.sizes, + self.indexes, + indexers, + method, + tolerance, + copy=copy, + fill_value=fill_value, + ) coord_names = set(self._coord_names) coord_names.update(indexers) - return self._replace_with_new_dims( - variables, coord_names, indexes=indexes) + return self._replace_with_new_dims(variables, coord_names, indexes=indexes) def interp( self, coords: Mapping[Hashable, Any] = None, - method: str = 'linear', + method: str = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] = None, **coords_kwargs: Any - ) -> 'Dataset': + ) -> "Dataset": """ Multidimensional interpolation of Dataset. Parameters @@ -2248,7 +2413,7 @@ def interp( if kwargs is None: kwargs = {} - coords = either_dict_or_kwargs(coords, coords_kwargs, 'interp') + coords = either_dict_or_kwargs(coords, coords_kwargs, "interp") indexers = OrderedDict(self._validate_indexers(coords)) obj = self if assume_sorted else self.sortby([k for k in coords]) @@ -2264,36 +2429,42 @@ def _validate_interp_indexer(x, new_x): # In the case of datetimes, the restrictions placed on indexers # used with interp are stronger than those which are placed on # isel, so we need an additional check after _validate_indexers. - if (_contains_datetime_like_objects(x) - and not _contains_datetime_like_objects(new_x)): - raise TypeError('When interpolating over a datetime-like ' - 'coordinate, the coordinates to ' - 'interpolate to must be either datetime ' - 'strings or datetimes. ' - 'Instead got\n{}'.format(new_x)) + if _contains_datetime_like_objects( + x + ) and not _contains_datetime_like_objects(new_x): + raise TypeError( + "When interpolating over a datetime-like " + "coordinate, the coordinates to " + "interpolate to must be either datetime " + "strings or datetimes. 
" + "Instead got\n{}".format(new_x) + ) else: return (x, new_x) variables = OrderedDict() # type: OrderedDict[Hashable, Variable] for name, var in obj._variables.items(): if name not in indexers: - if var.dtype.kind in 'uifc': + if var.dtype.kind in "uifc": var_indexers = { k: _validate_interp_indexer(maybe_variable(obj, k), v) for k, v in indexers.items() if k in var.dims } variables[name] = missing.interp( - var, var_indexers, method, **kwargs) + var, var_indexers, method, **kwargs + ) elif all(d not in indexers for d in var.dims): # keep unrelated object array variables[name] = var coord_names = set(variables).intersection(obj._coord_names) indexes = OrderedDict( - (k, v) for k, v in obj.indexes.items() if k not in indexers) + (k, v) for k, v in obj.indexes.items() if k not in indexers + ) selected = self._replace_with_new_dims( - variables.copy(), coord_names, indexes=indexes) + variables.copy(), coord_names, indexes=indexes + ) # attach indexer as coordinate variables.update(indexers) @@ -2303,24 +2474,20 @@ def _validate_interp_indexer(x, new_x): indexes[k] = v.to_index() # Extract coordinates from indexers - coord_vars, new_indexes = ( - selected._get_indexers_coords_and_indexes(coords)) + coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(coords) variables.update(coord_vars) indexes.update(new_indexes) - coord_names = (set(variables) - .intersection(obj._coord_names) - .union(coord_vars)) - return self._replace_with_new_dims( - variables, coord_names, indexes=indexes) + coord_names = set(variables).intersection(obj._coord_names).union(coord_vars) + return self._replace_with_new_dims(variables, coord_names, indexes=indexes) def interp_like( self, - other: Union['Dataset', 'DataArray'], - method: str = 'linear', + other: Union["Dataset", "DataArray"], + method: str = "linear", assume_sorted: bool = False, - kwargs: Mapping[str, Any] = None - ) -> 'Dataset': + kwargs: Mapping[str, Any] = None, + ) -> "Dataset": """Interpolate this object onto the coordinates of another object, filling the out of range values with NaN. @@ -2366,7 +2533,7 @@ def interp_like( numeric_coords = OrderedDict() # type: OrderedDict[Hashable, pd.Index] object_coords = OrderedDict() # type: OrderedDict[Hashable, pd.Index] for k, v in coords.items(): - if v.dtype.kind in 'uifcMm': + if v.dtype.kind in "uifcMm": numeric_coords[k] = v else: object_coords[k] = v @@ -2387,7 +2554,7 @@ def _rename_vars(self, name_dict, dims_dict): var.dims = tuple(dims_dict.get(dim, dim) for dim in v.dims) name = name_dict.get(k, k) if name in variables: - raise ValueError('the new name %r conflicts' % (name,)) + raise ValueError("the new name %r conflicts" % (name,)) variables[name] = var if k in self._coord_names: coord_names.add(name) @@ -2404,8 +2571,13 @@ def _rename_indexes(self, name_dict): new_name = name_dict.get(k, k) if isinstance(v, pd.MultiIndex): new_names = [name_dict.get(k, k) for k in v.names] - index = pd.MultiIndex(v.levels, v.labels, v.sortorder, - names=new_names, verify_integrity=False) + index = pd.MultiIndex( + v.levels, + v.labels, + v.sortorder, + names=new_names, + verify_integrity=False, + ) else: index = pd.Index(v, name=new_name) indexes[new_name] = index @@ -2422,7 +2594,7 @@ def rename( name_dict: Mapping[Hashable, Hashable] = None, inplace: bool = None, **names: Hashable - ) -> 'Dataset': + ) -> "Dataset": """Returns a new object with renamed variables and dimensions. 
Parameters @@ -2450,22 +2622,24 @@ def rename( DataArray.rename """ inplace = _check_inplace(inplace) - name_dict = either_dict_or_kwargs(name_dict, names, 'rename') + name_dict = either_dict_or_kwargs(name_dict, names, "rename") for k in name_dict.keys(): if k not in self and k not in self.dims: - raise ValueError("cannot rename %r because it is not a " - "variable or dimension in this dataset" % k) + raise ValueError( + "cannot rename %r because it is not a " + "variable or dimension in this dataset" % k + ) variables, coord_names, dims, indexes = self._rename_all( - name_dict=name_dict, dims_dict=name_dict) - return self._replace(variables, coord_names, dims=dims, - indexes=indexes, inplace=inplace) + name_dict=name_dict, dims_dict=name_dict + ) + return self._replace( + variables, coord_names, dims=dims, indexes=indexes, inplace=inplace + ) def rename_dims( - self, - dims_dict: Mapping[Hashable, Hashable] = None, - **dims: Hashable - ) -> 'Dataset': + self, dims_dict: Mapping[Hashable, Hashable] = None, **dims: Hashable + ) -> "Dataset": """Returns a new object with renamed dimensions only. Parameters @@ -2489,22 +2663,22 @@ def rename_dims( Dataset.rename_vars DataArray.rename """ - dims_dict = either_dict_or_kwargs(dims_dict, dims, 'rename_dims') + dims_dict = either_dict_or_kwargs(dims_dict, dims, "rename_dims") for k in dims_dict: if k not in self.dims: - raise ValueError("cannot rename %r because it is not a " - "dimension in this dataset" % k) + raise ValueError( + "cannot rename %r because it is not a " + "dimension in this dataset" % k + ) variables, coord_names, sizes, indexes = self._rename_all( - name_dict={}, dims_dict=dims_dict) - return self._replace( - variables, coord_names, dims=sizes, indexes=indexes) + name_dict={}, dims_dict=dims_dict + ) + return self._replace(variables, coord_names, dims=sizes, indexes=indexes) def rename_vars( - self, - name_dict: Mapping[Hashable, Hashable] = None, - **names: Hashable - ) -> 'Dataset': + self, name_dict: Mapping[Hashable, Hashable] = None, **names: Hashable + ) -> "Dataset": """Returns a new object with renamed variables including coordinates Parameters @@ -2528,21 +2702,21 @@ def rename_vars( Dataset.rename_dims DataArray.rename """ - name_dict = either_dict_or_kwargs(name_dict, names, 'rename_vars') + name_dict = either_dict_or_kwargs(name_dict, names, "rename_vars") for k in name_dict: if k not in self: - raise ValueError("cannot rename %r because it is not a " - "variable or coordinate in this dataset" % k) + raise ValueError( + "cannot rename %r because it is not a " + "variable or coordinate in this dataset" % k + ) variables, coord_names, dims, indexes = self._rename_all( - name_dict=name_dict, dims_dict={}) - return self._replace(variables, coord_names, dims=dims, - indexes=indexes) + name_dict=name_dict, dims_dict={} + ) + return self._replace(variables, coord_names, dims=dims, indexes=indexes) def swap_dims( - self, - dims_dict: Mapping[Hashable, Hashable], - inplace: bool = None - ) -> 'Dataset': + self, dims_dict: Mapping[Hashable, Hashable], inplace: bool = None + ) -> "Dataset": """Returns a new object with swapped dimensions. 
Parameters @@ -2571,12 +2745,15 @@ def swap_dims( inplace = _check_inplace(inplace) for k, v in dims_dict.items(): if k not in self.dims: - raise ValueError('cannot swap from dimension %r because it is ' - 'not an existing dimension' % k) + raise ValueError( + "cannot swap from dimension %r because it is " + "not an existing dimension" % k + ) if self.variables[v].dims != (k,): - raise ValueError('replacement dimension %r is not a 1D ' - 'variable along the old dimension %r' - % (v, k)) + raise ValueError( + "replacement dimension %r is not a 1D " + "variable along the old dimension %r" % (v, k) + ) result_dims = {dims_dict.get(dim, dim) for dim in self.dims} @@ -2598,16 +2775,16 @@ def swap_dims( var.dims = dims variables[k] = var - return self._replace_with_new_dims(variables, coord_names, - indexes=indexes, inplace=inplace) + return self._replace_with_new_dims( + variables, coord_names, indexes=indexes, inplace=inplace + ) def expand_dims( self, - dim: Union[None, Hashable, Sequence[Hashable], - Mapping[Hashable, Any]] = None, + dim: Union[None, Hashable, Sequence[Hashable], Mapping[Hashable, Any]] = None, axis: Union[None, int, Sequence[int]] = None, **dim_kwargs: Any - ) -> 'Dataset': + ) -> "Dataset": """Return a new object with an additional axis (or axes) inserted at the corresponding position in the array shape. The new object is a view into the underlying array, not a copy. @@ -2654,7 +2831,7 @@ def expand_dims( """ # TODO: get rid of the below code block when python 3.5 is no longer # supported. - if sys.version < '3.6': + if sys.version < "3.6": if isinstance(dim, Mapping) and not isinstance(dim, OrderedDict): raise TypeError("dim must be an OrderedDict for python <3.6") if dim_kwargs: @@ -2674,10 +2851,10 @@ def expand_dims( dim = OrderedDict(((dim, 1),)) elif isinstance(dim, Sequence): if len(dim) != len(set(dim)): - raise ValueError('dims should not contain duplicate values.') + raise ValueError("dims should not contain duplicate values.") dim = OrderedDict((d, 1) for d in dim) - dim = either_dict_or_kwargs(dim, dim_kwargs, 'expand_dims') + dim = either_dict_or_kwargs(dim, dim_kwargs, "expand_dims") assert isinstance(dim, MutableMapping) if axis is None: @@ -2686,16 +2863,15 @@ def expand_dims( axis = [axis] if len(dim) != len(axis): - raise ValueError('lengths of dim and axis should be identical.') + raise ValueError("lengths of dim and axis should be identical.") for d in dim: if d in self.dims: + raise ValueError("Dimension {dim} already exists.".format(dim=d)) + if d in self._variables and not utils.is_scalar(self._variables[d]): raise ValueError( - 'Dimension {dim} already exists.'.format(dim=d)) - if (d in self._variables - and not utils.is_scalar(self._variables[d])): - raise ValueError( - '{dim} already exists as coordinate or' - ' variable name.'.format(dim=d)) + "{dim} already exists as coordinate or" + " variable name.".format(dim=d) + ) variables = OrderedDict() # type: OrderedDict[Hashable, Variable] coord_names = self._coord_names.copy() @@ -2713,8 +2889,10 @@ def expand_dims( elif isinstance(v, int): pass # Do nothing if the dimensions value is just an int else: - raise TypeError('The value of new dimension {k} must be ' - 'an iterable or an int'.format(k=k)) + raise TypeError( + "The value of new dimension {k} must be " + "an iterable or an int".format(k=k) + ) for k, v in self._variables.items(): if k not in dim: @@ -2725,15 +2903,15 @@ def expand_dims( for a in axis: if a < -result_ndim or result_ndim - 1 < a: raise IndexError( - 'Axis {a} is out of bounds 
of the expanded' - ' dimension size {dim}.'.format( - a=a, v=k, dim=result_ndim)) + "Axis {a} is out of bounds of the expanded" + " dimension size {dim}.".format( + a=a, v=k, dim=result_ndim + ) + ) - axis_pos = [a if a >= 0 else result_ndim + a - for a in axis] + axis_pos = [a if a >= 0 else result_ndim + a for a in axis] if len(axis_pos) != len(set(axis_pos)): - raise ValueError('axis should not contain duplicate' - ' values.') + raise ValueError("axis should not contain duplicate" " values.") # We need to sort them to make sure `axis` equals to the # axis positions of the result array. zip_axis_dim = sorted(zip(axis_pos, dim.items())) @@ -2751,7 +2929,8 @@ def expand_dims( new_dims.update(dim) return self._replace_vars_and_dims( - variables, dims=new_dims, coord_names=coord_names) + variables, dims=new_dims, coord_names=coord_names + ) def set_index( self, @@ -2759,7 +2938,7 @@ def set_index( append: bool = False, inplace: bool = None, **indexes_kwargs: Union[Hashable, Sequence[Hashable]] - ) -> 'Dataset': + ) -> "Dataset": """Set Dataset (multi-)indexes using one or more existing coordinates or variables. @@ -2790,19 +2969,20 @@ def set_index( Dataset.swap_dims """ inplace = _check_inplace(inplace) - indexes = either_dict_or_kwargs(indexes, indexes_kwargs, 'set_index') - variables, coord_names = merge_indexes(indexes, self._variables, - self._coord_names, - append=append) - return self._replace_vars_and_dims(variables, coord_names=coord_names, - inplace=inplace) + indexes = either_dict_or_kwargs(indexes, indexes_kwargs, "set_index") + variables, coord_names = merge_indexes( + indexes, self._variables, self._coord_names, append=append + ) + return self._replace_vars_and_dims( + variables, coord_names=coord_names, inplace=inplace + ) def reset_index( self, dims_or_levels: Union[Hashable, Sequence[Hashable]], drop: bool = False, inplace: bool = None, - ) -> 'Dataset': + ) -> "Dataset": """Reset the specified index(es) or multi-index level(s). Parameters @@ -2834,15 +3014,16 @@ def reset_index( cast(Mapping[Hashable, Hashable], self._level_coords), drop=drop, ) - return self._replace_vars_and_dims(variables, coord_names=coord_names, - inplace=inplace) + return self._replace_vars_and_dims( + variables, coord_names=coord_names, inplace=inplace + ) def reorder_levels( self, dim_order: Mapping[Hashable, Sequence[int]] = None, inplace: bool = None, **dim_order_kwargs: Sequence[int] - ) -> 'Dataset': + ) -> "Dataset": """Rearrange index levels using input order. Parameters @@ -2865,8 +3046,7 @@ def reorder_levels( coordinates. """ inplace = _check_inplace(inplace) - dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, - 'reorder_levels') + dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, "reorder_levels") variables = self._variables.copy() indexes = OrderedDict(self.indexes) for dim, order in dim_order.items(): @@ -2896,30 +3076,30 @@ def _stack_once(self, dims, new_dim): # consider dropping levels that are unused? 
levels = [self.get_index(dim) for dim in dims] - if LooseVersion(pd.__version__) < LooseVersion('0.19.0'): + if LooseVersion(pd.__version__) < LooseVersion("0.19.0"): # RangeIndex levels in a MultiIndex are broken for appending in # pandas before v0.19.0 - levels = [pd.Int64Index(level) - if isinstance(level, pd.RangeIndex) - else level - for level in levels] + levels = [ + pd.Int64Index(level) if isinstance(level, pd.RangeIndex) else level + for level in levels + ] idx = utils.multiindex_from_product_levels(levels, names=dims) variables[new_dim] = IndexVariable(new_dim, idx) coord_names = set(self._coord_names) - set(dims) | {new_dim} - indexes = OrderedDict((k, v) for k, v in self.indexes.items() - if k not in dims) + indexes = OrderedDict((k, v) for k, v in self.indexes.items() if k not in dims) indexes[new_dim] = idx return self._replace_with_new_dims( - variables, coord_names=coord_names, indexes=indexes) + variables, coord_names=coord_names, indexes=indexes + ) def stack( self, dimensions: Mapping[Hashable, Sequence[Hashable]] = None, **dimensions_kwargs: Sequence[Hashable] - ) -> 'Dataset': + ) -> "Dataset": """ Stack any number of existing dimensions into a single new dimension. @@ -2944,8 +3124,7 @@ def stack( -------- Dataset.unstack """ - dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, - 'stack') + dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, "stack") result = self for new_dim, dims in dimensions.items(): result = result._stack_once(dims, new_dim) @@ -2955,9 +3134,9 @@ def to_stacked_array( self, new_dim: Hashable, sample_dims: Sequence[Hashable], - variable_dim: str = 'variable', - name: Hashable = None - ) -> 'DataArray': + variable_dim: str = "variable", + name: Hashable = None, + ) -> "DataArray": """Combine variables of differing dimensionality into a DataArray without broadcasting. @@ -3022,8 +3201,7 @@ def to_stacked_array( Dimensions without coordinates: x """ - stacking_dims = tuple(dim for dim in self.dims - if dim not in sample_dims) + stacking_dims = tuple(dim for dim in self.dims if dim not in sample_dims) for variable in self: dims = self[variable].dims @@ -3045,22 +3223,23 @@ def ensure_stackable(val): # must be list for .expand_dims expand_dims = list(expand_dims) - return (val.assign_coords(**assign_coords) - .expand_dims(expand_dims) - .stack({new_dim: (variable_dim,) + stacking_dims})) + return ( + val.assign_coords(**assign_coords) + .expand_dims(expand_dims) + .stack({new_dim: (variable_dim,) + stacking_dims}) + ) # concatenate the arrays - stackable_vars = [ensure_stackable(self[key]) - for key in self.data_vars] + stackable_vars = [ensure_stackable(self[key]) for key in self.data_vars] data_array = xr.concat(stackable_vars, dim=new_dim) # coerce the levels of the MultiIndex to have the same type as the # input dimensions. This code is messy, so it might be better to just # input a dummy value for the singleton dimension. idx = data_array.indexes[new_dim] - levels = ([idx.levels[0]] - + [level.astype(self[level.name].dtype) - for level in idx.levels[1:]]) + levels = [idx.levels[0]] + [ + level.astype(self[level.name].dtype) for level in idx.levels[1:] + ] new_idx = idx.set_levels(levels) data_array[new_dim] = IndexVariable(new_dim, new_idx) @@ -3069,7 +3248,7 @@ def ensure_stackable(val): return data_array - def _unstack_once(self, dim: Hashable) -> 'Dataset': + def _unstack_once(self, dim: Hashable) -> "Dataset": index = self.get_index(dim) # GH2619. For MultiIndex, we need to call remove_unused. 
if LooseVersion(pd.__version__) >= "0.20": @@ -3089,8 +3268,7 @@ def _unstack_once(self, dim: Hashable) -> 'Dataset': new_dim_sizes = [lev.size for lev in index.levels] variables = OrderedDict() # type: OrderedDict[Hashable, Variable] - indexes = OrderedDict( - (k, v) for k, v in self.indexes.items() if k != dim) + indexes = OrderedDict((k, v) for k, v in self.indexes.items() if k != dim) for name, var in obj.variables.items(): if name != dim: @@ -3107,12 +3285,10 @@ def _unstack_once(self, dim: Hashable) -> 'Dataset': coord_names = set(self._coord_names) - {dim} | set(new_dim_names) return self._replace_with_new_dims( - variables, coord_names=coord_names, indexes=indexes) + variables, coord_names=coord_names, indexes=indexes + ) - def unstack( - self, - dim: Union[Hashable, Iterable[Hashable]] = None - ) -> 'Dataset': + def unstack(self, dim: Union[Hashable, Iterable[Hashable]] = None) -> "Dataset": """ Unstack existing dimensions corresponding to MultiIndexes into multiple new dimensions. @@ -3136,8 +3312,7 @@ def unstack( """ if dim is None: dims = [ - d for d in self.dims - if isinstance(self.get_index(d), pd.MultiIndex) + d for d in self.dims if isinstance(self.get_index(d), pd.MultiIndex) ] else: if isinstance(dim, str) or not isinstance(dim, Iterable): @@ -3147,21 +3322,25 @@ def unstack( missing_dims = [d for d in dims if d not in self.dims] if missing_dims: - raise ValueError('Dataset does not contain the dimensions: %s' - % missing_dims) + raise ValueError( + "Dataset does not contain the dimensions: %s" % missing_dims + ) - non_multi_dims = [d for d in dims if not - isinstance(self.get_index(d), pd.MultiIndex)] + non_multi_dims = [ + d for d in dims if not isinstance(self.get_index(d), pd.MultiIndex) + ] if non_multi_dims: - raise ValueError('cannot unstack dimensions that do not ' - 'have a MultiIndex: %s' % non_multi_dims) + raise ValueError( + "cannot unstack dimensions that do not " + "have a MultiIndex: %s" % non_multi_dims + ) result = self.copy(deep=False) for dim in dims: result = result._unstack_once(dim) return result - def update(self, other: 'DatasetLike', inplace: bool = None) -> 'Dataset': + def update(self, other: "DatasetLike", inplace: bool = None) -> "Dataset": """Update this dataset's variables with those from another dataset. Parameters @@ -3193,18 +3372,19 @@ def update(self, other: 'DatasetLike', inplace: bool = None) -> 'Dataset': inplace = _check_inplace(inplace, default=True) variables, coord_names, dims = dataset_update_method(self, other) - return self._replace_vars_and_dims(variables, coord_names, dims, - inplace=inplace) + return self._replace_vars_and_dims( + variables, coord_names, dims, inplace=inplace + ) def merge( self, - other: 'DatasetLike', + other: "DatasetLike", inplace: bool = None, overwrite_vars: Union[Hashable, Iterable[Hashable]] = frozenset(), - compat: str = 'no_conflicts', - join: str = 'outer', - fill_value: Any = dtypes.NA - ) -> 'Dataset': + compat: str = "no_conflicts", + join: str = "outer", + fill_value: Any = dtypes.NA, + ) -> "Dataset": """Merge the arrays of two datasets into a single dataset. 
This method generally does not allow for overriding data, with the @@ -3257,43 +3437,45 @@ def merge( """ inplace = _check_inplace(inplace) variables, coord_names, dims = dataset_merge_method( - self, other, overwrite_vars=overwrite_vars, compat=compat, - join=join, fill_value=fill_value) + self, + other, + overwrite_vars=overwrite_vars, + compat=compat, + join=join, + fill_value=fill_value, + ) - return self._replace_vars_and_dims(variables, coord_names, dims, - inplace=inplace) + return self._replace_vars_and_dims( + variables, coord_names, dims, inplace=inplace + ) - def _assert_all_in_dataset(self, names: Iterable[Hashable], - virtual_okay: bool = False) -> None: + def _assert_all_in_dataset( + self, names: Iterable[Hashable], virtual_okay: bool = False + ) -> None: bad_names = set(names) - set(self._variables) if virtual_okay: bad_names -= self.virtual_variables if bad_names: - raise ValueError('One or more of the specified variables ' - 'cannot be found in this dataset') + raise ValueError( + "One or more of the specified variables " + "cannot be found in this dataset" + ) # Drop variables @overload def drop( - self, - labels: Union[Hashable, Iterable[Hashable]], - *, - errors: str = 'raise' - ) -> 'Dataset': + self, labels: Union[Hashable, Iterable[Hashable]], *, errors: str = "raise" + ) -> "Dataset": ... # Drop index labels along dimension @overload # noqa: F811 def drop( - self, - labels: Any, # array-like - dim: Hashable, - *, - errors: str = 'raise' - ) -> 'Dataset': + self, labels: Any, dim: Hashable, *, errors: str = "raise" # array-like + ) -> "Dataset": ... - def drop(self, labels, dim=None, *, errors='raise'): # noqa: F811 + def drop(self, labels, dim=None, *, errors="raise"): # noqa: F811 """Drop variables or index labels from this dataset. Parameters @@ -3314,7 +3496,7 @@ def drop(self, labels, dim=None, *, errors='raise'): # noqa: F811 ------- dropped : Dataset """ - if errors not in ['raise', 'ignore']: + if errors not in ["raise", "ignore"]: raise ValueError('errors must be either "raise" or "ignore"') if dim is None: @@ -3334,33 +3516,26 @@ def drop(self, labels, dim=None, *, errors='raise'): # noqa: F811 try: index = self.indexes[dim] except KeyError: - raise ValueError( - 'dimension %r does not have coordinate labels' % dim) + raise ValueError("dimension %r does not have coordinate labels" % dim) new_index = index.drop(labels, errors=errors) return self.loc[{dim: new_index}] - def _drop_vars( - self, - names: set, - errors: str = 'raise' - ) -> 'Dataset': - if errors == 'raise': + def _drop_vars(self, names: set, errors: str = "raise") -> "Dataset": + if errors == "raise": self._assert_all_in_dataset(names) - variables = OrderedDict((k, v) for k, v in self._variables.items() - if k not in names) + variables = OrderedDict( + (k, v) for k, v in self._variables.items() if k not in names + ) coord_names = {k for k in self._coord_names if k in variables} - indexes = OrderedDict((k, v) for k, v in self.indexes.items() - if k not in names) + indexes = OrderedDict((k, v) for k, v in self.indexes.items() if k not in names) return self._replace_with_new_dims( - variables, coord_names=coord_names, indexes=indexes) + variables, coord_names=coord_names, indexes=indexes + ) def drop_dims( - self, - drop_dims: Union[Hashable, Iterable[Hashable]], - *, - errors: str = 'raise' - ) -> 'Dataset': + self, drop_dims: Union[Hashable, Iterable[Hashable]], *, errors: str = "raise" + ) -> "Dataset": """Drop dimensions and associated variables from this dataset. 
Parameters @@ -3383,7 +3558,7 @@ def drop_dims( in the dataset. If 'ignore', any given dimensions that are in the dataset are dropped and no error is raised. """ - if errors not in ['raise', 'ignore']: + if errors not in ["raise", "ignore"]: raise ValueError('errors must be either "raise" or "ignore"') if isinstance(drop_dims, str) or not isinstance(drop_dims, Iterable): @@ -3391,19 +3566,17 @@ def drop_dims( else: drop_dims = set(drop_dims) - if errors == 'raise': + if errors == "raise": missing_dims = drop_dims - set(self.dims) if missing_dims: - raise ValueError('Dataset does not contain the dimensions: %s' - % missing_dims) + raise ValueError( + "Dataset does not contain the dimensions: %s" % missing_dims + ) - drop_vars = { - k for k, v in self._variables.items() - if set(v.dims) & drop_dims - } + drop_vars = {k for k, v in self._variables.items() if set(v.dims) & drop_dims} return self._drop_vars(drop_vars) - def transpose(self, *dims: Hashable) -> 'Dataset': + def transpose(self, *dims: Hashable) -> "Dataset": """Return a new Dataset object with all array dimensions transposed. Although the order of dimensions on each array will change, the dataset @@ -3434,9 +3607,10 @@ def transpose(self, *dims: Hashable) -> 'Dataset': """ if dims: if set(dims) ^ set(self.dims): - raise ValueError('arguments to transpose (%s) must be ' - 'permuted dataset dimensions (%s)' - % (dims, tuple(self.dims))) + raise ValueError( + "arguments to transpose (%s) must be " + "permuted dataset dimensions (%s)" % (dims, tuple(self.dims)) + ) ds = self.copy() for name, var in self._variables.items(): var_dims = tuple(dim for dim in dims if dim in var.dims) @@ -3446,9 +3620,9 @@ def transpose(self, *dims: Hashable) -> 'Dataset': def dropna( self, dim: Hashable, - how: str = 'any', + how: str = "any", thresh: int = None, - subset: Iterable[Hashable] = None + subset: Iterable[Hashable] = None, ): """Returns a new dataset with dropped labels for missing values along the provided dimension. @@ -3476,7 +3650,7 @@ def dropna( # depending on the order of the supplied axes. if dim not in self.dims: - raise ValueError('%s must be a single dataset dimension' % dim) + raise ValueError("%s must be a single dataset dimension" % dim) if subset is None: subset = iter(self.data_vars) @@ -3493,18 +3667,18 @@ def dropna( if thresh is not None: mask = count >= thresh - elif how == 'any': + elif how == "any": mask = count == size - elif how == 'all': + elif how == "all": mask = count > 0 elif how is not None: - raise ValueError('invalid how option: %s' % how) + raise ValueError("invalid how option: %s" % how) else: - raise TypeError('must specify how or thresh') + raise TypeError("must specify how or thresh") return self.isel({dim: mask}) - def fillna(self, value: Any) -> 'Dataset': + def fillna(self, value: Any) -> "Dataset": """Fill missing values in this object. 
This operation follows the normal broadcasting and alignment rules that @@ -3526,21 +3700,23 @@ def fillna(self, value: Any) -> 'Dataset': Dataset """ if utils.is_dict_like(value): - value_keys = getattr(value, 'data_vars', value).keys() + value_keys = getattr(value, "data_vars", value).keys() if not set(value_keys) <= set(self.data_vars.keys()): - raise ValueError('all variables in the argument to `fillna` ' - 'must be contained in the original dataset') + raise ValueError( + "all variables in the argument to `fillna` " + "must be contained in the original dataset" + ) out = ops.fillna(self, value) return out def interpolate_na( self, dim: Hashable = None, - method: str = 'linear', + method: str = "linear", limit: int = None, use_coordinate: Union[bool, Hashable] = True, **kwargs: Any - ) -> 'Dataset': + ) -> "Dataset": """Interpolate values according to different methods. Parameters @@ -3583,13 +3759,18 @@ def interpolate_na( """ from .missing import interp_na, _apply_over_vars_with_dim - new = _apply_over_vars_with_dim(interp_na, self, dim=dim, - method=method, limit=limit, - use_coordinate=use_coordinate, - **kwargs) + new = _apply_over_vars_with_dim( + interp_na, + self, + dim=dim, + method=method, + limit=limit, + use_coordinate=use_coordinate, + **kwargs + ) return new - def ffill(self, dim: Hashable, limit: int = None) -> 'Dataset': + def ffill(self, dim: Hashable, limit: int = None) -> "Dataset": """Fill NaN values by propogating values forward *Requires bottleneck.* @@ -3614,7 +3795,7 @@ def ffill(self, dim: Hashable, limit: int = None) -> 'Dataset': new = _apply_over_vars_with_dim(ffill, self, dim=dim, limit=limit) return new - def bfill(self, dim: Hashable, limit: int = None) -> 'Dataset': + def bfill(self, dim: Hashable, limit: int = None) -> "Dataset": """Fill NaN values by propogating values backward *Requires bottleneck.* @@ -3639,7 +3820,7 @@ def bfill(self, dim: Hashable, limit: int = None) -> 'Dataset': new = _apply_over_vars_with_dim(bfill, self, dim=dim, limit=limit) return new - def combine_first(self, other: 'Dataset') -> 'Dataset': + def combine_first(self, other: "Dataset") -> "Dataset": """Combine two Datasets, default to data_vars of self. The new coordinates follow the normal broadcasting and alignment rules @@ -3667,7 +3848,7 @@ def reduce( numeric_only: bool = False, allow_lazy: bool = False, **kwargs: Any - ) -> 'Dataset': + ) -> "Dataset": """Reduce this dataset by applying `func` along some dimension(s). 
Parameters @@ -3709,25 +3890,25 @@ def reduce( missing_dimensions = [d for d in dims if d not in self.dims] if missing_dimensions: - raise ValueError('Dataset does not contain the dimensions: %s' - % missing_dimensions) + raise ValueError( + "Dataset does not contain the dimensions: %s" % missing_dimensions + ) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) variables = OrderedDict() # type: OrderedDict[Hashable, Variable] for name, var in self._variables.items(): - reduce_dims = [ - d for d in var.dims - if d in dims - ] + reduce_dims = [d for d in var.dims if d in dims] if name in self.coords: if not reduce_dims: variables[name] = var else: - if (not numeric_only - or np.issubdtype(var.dtype, np.number) - or (var.dtype == np.bool_)): + if ( + not numeric_only + or np.issubdtype(var.dtype, np.number) + or (var.dtype == np.bool_) + ): if len(reduce_dims) == 1: # unpack dimensions for the benefit of functions # like np.argmin which can't handle tuple arguments @@ -3737,18 +3918,21 @@ def reduce( # axis=(0, 1) if they will be equivalent, because # the former is often more efficient reduce_dims = None # type: ignore - variables[name] = var.reduce(func, dim=reduce_dims, - keep_attrs=keep_attrs, - keepdims=keepdims, - allow_lazy=allow_lazy, - **kwargs) + variables[name] = var.reduce( + func, + dim=reduce_dims, + keep_attrs=keep_attrs, + keepdims=keepdims, + allow_lazy=allow_lazy, + **kwargs + ) coord_names = {k for k in self.coords if k in variables} - indexes = OrderedDict((k, v) for k, v in self.indexes.items() - if k in variables) + indexes = OrderedDict((k, v) for k, v in self.indexes.items() if k in variables) attrs = self.attrs if keep_attrs else None return self._replace_with_new_dims( - variables, coord_names=coord_names, attrs=attrs, indexes=indexes) + variables, coord_names=coord_names, attrs=attrs, indexes=indexes + ) def apply( self, @@ -3756,7 +3940,7 @@ def apply( keep_attrs: bool = None, args: Iterable[Any] = (), **kwargs: Any - ) -> 'Dataset': + ) -> "Dataset": """Apply a function over the data variables in this dataset. Parameters @@ -3800,17 +3984,16 @@ def apply( """ # noqa variables = OrderedDict( (k, maybe_wrap_array(v, func(v, *args, **kwargs))) - for k, v in self.data_vars.items()) + for k, v in self.data_vars.items() + ) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) attrs = self.attrs if keep_attrs else None return type(self)(variables, attrs=attrs) def assign( - self, - variables: Mapping[Hashable, Any] = None, - **variables_kwargs: Hashable - ) -> 'Dataset': + self, variables: Mapping[Hashable, Any] = None, **variables_kwargs: Hashable + ) -> "Dataset": """Assign new data variables to a Dataset, returning a new object with all the original variables in addition to the new ones. @@ -3843,8 +4026,7 @@ def assign( -------- pandas.DataFrame.assign """ - variables = either_dict_or_kwargs( - variables, variables_kwargs, 'assign') + variables = either_dict_or_kwargs(variables, variables_kwargs, "assign") data = self.copy() # do all calculations first... 
results = data._calc_assign_results(variables) @@ -3852,7 +4034,7 @@ def assign( data.update(results) return data - def to_array(self, dim='variable', name=None): + def to_array(self, dim="variable", name=None): """Convert this dataset into an xarray.DataArray The data variables of this dataset will be broadcast against each other @@ -3885,8 +4067,10 @@ def to_array(self, dim='variable', name=None): def _to_dataframe(self, ordered_dims): columns = [k for k in self.variables if k not in self.dims] - data = [self._variables[k].set_dims(ordered_dims).values.reshape(-1) - for k in columns] + data = [ + self._variables[k].set_dims(ordered_dims).values.reshape(-1) + for k in columns + ] index = self.coords.to_index(ordered_dims) return pd.DataFrame(OrderedDict(zip(columns, data)), index=index) @@ -3916,8 +4100,7 @@ def from_dataframe(cls, dataframe): # even if some variables have different dimensionality. if not dataframe.columns.is_unique: - raise ValueError( - 'cannot convert DataFrame with non-unique columns') + raise ValueError("cannot convert DataFrame with non-unique columns") idx = dataframe.index obj = cls() @@ -3927,13 +4110,15 @@ def from_dataframe(cls, dataframe): # expand the DataFrame to include the product of all levels full_idx = pd.MultiIndex.from_product(idx.levels, names=idx.names) dataframe = dataframe.reindex(full_idx) - dims = [name if name is not None else 'level_%i' % n - for n, name in enumerate(idx.names)] + dims = [ + name if name is not None else "level_%i" % n + for n, name in enumerate(idx.names) + ] for dim, lev in zip(dims, idx.levels): obj[dim] = (dim, lev) shape = [lev.size for lev in idx.levels] else: - dims = (idx.name if idx.name is not None else 'index',) + dims = (idx.name if idx.name is not None else "index",) obj[dims[0]] = (dims, idx) shape = -1 @@ -3978,8 +4163,9 @@ def to_dask_dataframe(self, dim_order=None, set_index=False): dim_order = list(self.dims) elif set(dim_order) != set(self.dims): raise ValueError( - 'dim_order {} does not match the set of dimensions on this ' - 'Dataset: {}'.format(dim_order, list(self.dims))) + "dim_order {} does not match the set of dimensions on this " + "Dataset: {}".format(dim_order, list(self.dims)) + ) ordered_dims = OrderedDict((k, self.dims[k]) for k in dim_order) @@ -4037,12 +4223,16 @@ def to_dict(self, data=True): -------- Dataset.from_dict """ - d = {'coords': {}, 'attrs': decode_numpy_dict_values(self.attrs), - 'dims': dict(self.dims), 'data_vars': {}} + d = { + "coords": {}, + "attrs": decode_numpy_dict_values(self.attrs), + "dims": dict(self.dims), + "data_vars": {}, + } for k in self.coords: - d['coords'].update({k: self[k].variable.to_dict(data=data)}) + d["coords"].update({k: self[k].variable.to_dict(data=data)}) for k in self.data_vars: - d['data_vars'].update({k: self[k].variable.to_dict(data=data)}) + d["data_vars"].update({k: self[k].variable.to_dict(data=data)}) return d @classmethod @@ -4082,28 +4272,30 @@ def from_dict(cls, d): DataArray.from_dict """ - if not {'coords', 'data_vars'}.issubset(set(d)): + if not {"coords", "data_vars"}.issubset(set(d)): variables = d.items() else: import itertools - variables = itertools.chain(d.get('coords', {}).items(), - d.get('data_vars', {}).items()) + + variables = itertools.chain( + d.get("coords", {}).items(), d.get("data_vars", {}).items() + ) try: - variable_dict = OrderedDict([(k, (v['dims'], - v['data'], - v.get('attrs'))) for - k, v in variables]) + variable_dict = OrderedDict( + [(k, (v["dims"], v["data"], v.get("attrs"))) for k, v in variables] + ) 
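The `to_dict`/`from_dict` hunks above are likewise cosmetic; the round-trip they implement is simply (a sketch, reusing the hypothetical `ds`):

    d = ds.to_dict()              # {"coords": ..., "attrs": ..., "dims": ..., "data_vars": ...}
    ds2 = xr.Dataset.from_dict(d)
    assert ds2.identical(ds)      # lossless for plain in-memory numpy data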
except KeyError as e: raise ValueError( "cannot convert dict without the key " - "'{dims_data}'".format(dims_data=str(e.args[0]))) + "'{dims_data}'".format(dims_data=str(e.args[0])) + ) obj = cls(variable_dict) # what if coords aren't dims? - coords = set(d.get('coords', {})) - set(d.get('dims', {})) + coords = set(d.get("coords", {})) - set(d.get("dims", {})) obj = obj.set_coords(coords) - obj.attrs.update(d.get('attrs', {})) + obj.attrs.update(d.get("attrs", {})) return obj @@ -4130,7 +4322,7 @@ def func(self, other): if isinstance(other, groupby.GroupBy): return NotImplemented - align_type = OPTIONS['arithmetic_join'] if join is None else join + align_type = OPTIONS["arithmetic_join"] if join is None else join if isinstance(other, (DataArray, Dataset)): self, other = align(self, other, join=align_type, copy=False) g = f if not reflexive else lambda x, y: f(y, x) @@ -4146,28 +4338,35 @@ def func(self, other): from .dataarray import DataArray if isinstance(other, groupby.GroupBy): - raise TypeError('in-place operations between a Dataset and ' - 'a grouped object are not permitted') + raise TypeError( + "in-place operations between a Dataset and " + "a grouped object are not permitted" + ) # we don't actually modify arrays in-place with in-place Dataset # arithmetic -- this lets us automatically align things if isinstance(other, (DataArray, Dataset)): other = other.reindex_like(self, copy=False) g = ops.inplace_to_noninplace_op(f) ds = self._calculate_binary_op(g, other, inplace=True) - self._replace_with_new_dims(ds._variables, ds._coord_names, - attrs=ds._attrs, indexes=ds._indexes, - inplace=True) + self._replace_with_new_dims( + ds._variables, + ds._coord_names, + attrs=ds._attrs, + indexes=ds._indexes, + inplace=True, + ) return self return func - def _calculate_binary_op(self, f, other, join='inner', - inplace=False): + def _calculate_binary_op(self, f, other, join="inner", inplace=False): def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars): if inplace and set(lhs_data_vars) != set(rhs_data_vars): - raise ValueError('datasets must have the same data variables ' - 'for in-place arithmetic operations: %s, %s' - % (list(lhs_data_vars), list(rhs_data_vars))) + raise ValueError( + "datasets must have the same data variables " + "for in-place arithmetic operations: %s, %s" + % (list(lhs_data_vars), list(rhs_data_vars)) + ) dest_vars = OrderedDict() @@ -4184,20 +4383,23 @@ def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars): if utils.is_dict_like(other) and not isinstance(other, Dataset): # can't use our shortcut of doing the binary operation with # Variable objects, so apply over our data vars instead. 
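The binary-op hunks above pass `OPTIONS["arithmetic_join"]` through to `align`, so the user-visible knob is how non-matching indexes combine. A sketch with two hypothetical single-variable Datasets:

    a = xr.Dataset({"v": ("x", [1, 2])}, coords={"x": [0, 1]})
    b = xr.Dataset({"v": ("x", [10, 20])}, coords={"x": [1, 2]})
    a + b                                      # default "inner" join: only x=1 survives
    with xr.set_options(arithmetic_join="outer"):
        a + b                                  # x=0 and x=2 are kept and filled with NaN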
- new_data_vars = apply_over_both(self.data_vars, other, - self.data_vars, other) + new_data_vars = apply_over_both( + self.data_vars, other, self.data_vars, other + ) return Dataset(new_data_vars) - other_coords = getattr(other, 'coords', None) + other_coords = getattr(other, "coords", None) ds = self.coords.merge(other_coords) if isinstance(other, Dataset): - new_vars = apply_over_both(self.data_vars, other.data_vars, - self.variables, other.variables) + new_vars = apply_over_both( + self.data_vars, other.data_vars, self.variables, other.variables + ) else: - other_variable = getattr(other, 'variable', other) - new_vars = OrderedDict((k, f(self.variables[k], other_variable)) - for k in self.data_vars) + other_variable = getattr(other, "variable", other) + new_vars = OrderedDict( + (k, f(self.variables[k], other_variable)) for k in self.data_vars + ) ds._variables.update(new_vars) ds._dims = calculate_dimensions(ds._variables) return ds @@ -4208,7 +4410,7 @@ def _copy_attrs_from(self, other): if v in self.variables: self.variables[v].attrs = other.variables[v].attrs - def diff(self, dim, n=1, label='upper'): + def diff(self, dim, n=1, label="upper"): """Calculate the n-th order discrete difference along given axis. Parameters @@ -4253,30 +4455,28 @@ def diff(self, dim, n=1, label='upper'): if n == 0: return self if n < 0: - raise ValueError( - 'order `n` must be non-negative but got {}'.format(n) - ) + raise ValueError("order `n` must be non-negative but got {}".format(n)) # prepare slices kwargs_start = {dim: slice(None, -1)} kwargs_end = {dim: slice(1, None)} # prepare new coordinate - if label == 'upper': + if label == "upper": kwargs_new = kwargs_end - elif label == 'lower': + elif label == "lower": kwargs_new = kwargs_start else: - raise ValueError('The \'label\' argument has to be either ' - '\'upper\' or \'lower\'') + raise ValueError( + "The 'label' argument has to be either " "'upper' or 'lower'" + ) variables = OrderedDict() for name, var in self.variables.items(): if dim in var.dims: if name in self.data_vars: - variables[name] = (var.isel(**kwargs_end) - - var.isel(**kwargs_start)) + variables[name] = var.isel(**kwargs_end) - var.isel(**kwargs_start) else: variables[name] = var.isel(**kwargs_new) else: @@ -4333,7 +4533,7 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): Data variables: foo (x) object nan nan 'a' 'b' 'c' """ - shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'shift') + shifts = either_dict_or_kwargs(shifts, shifts_kwargs, "shift") invalid = [k for k in shifts if k not in self.dims] if invalid: raise ValueError("dimensions %r do not exist" % invalid) @@ -4341,10 +4541,8 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): variables = OrderedDict() for name, var in self.variables.items(): if name in self.data_vars: - var_shifts = {k: v for k, v in shifts.items() - if k in var.dims} - variables[name] = var.shift( - fill_value=fill_value, shifts=var_shifts) + var_shifts = {k: v for k, v in shifts.items() if k in var.dims} + variables[name] = var.shift(fill_value=fill_value, shifts=var_shifts) else: variables[name] = var @@ -4394,15 +4592,18 @@ def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): Data variables: foo (x) object 'd' 'e' 'a' 'b' 'c' """ - shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'roll') + shifts = either_dict_or_kwargs(shifts, shifts_kwargs, "roll") invalid = [k for k in shifts if k not in self.dims] if invalid: raise ValueError("dimensions %r do not exist" % invalid) if roll_coords 
is None: - warnings.warn("roll_coords will be set to False in the future." - " Explicitly set roll_coords to silence warning.", - FutureWarning, stacklevel=2) + warnings.warn( + "roll_coords will be set to False in the future." + " Explicitly set roll_coords to silence warning.", + FutureWarning, + stacklevel=2, + ) roll_coords = True unrolled_vars = () if roll_coords else self.coords @@ -4410,8 +4611,9 @@ def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): variables = OrderedDict() for k, v in self.variables.items(): if k not in unrolled_vars: - variables[k] = v.roll(**{k: s for k, s in shifts.items() - if k in v.dims}) + variables[k] = v.roll( + **{k: s for k, s in shifts.items() if k in v.dims} + ) else: variables[k] = v @@ -4464,20 +4666,21 @@ def sortby(self, variables, ascending=True): variables = [variables] else: variables = variables - variables = [v if isinstance(v, DataArray) else self[v] - for v in variables] - aligned_vars = align(self, *variables, join='left') + variables = [v if isinstance(v, DataArray) else self[v] for v in variables] + aligned_vars = align(self, *variables, join="left") aligned_self = aligned_vars[0] aligned_other_vars = aligned_vars[1:] vars_by_dim = defaultdict(list) for data_array in aligned_other_vars: if data_array.ndim != 1: raise ValueError("Input DataArray is not 1-D.") - if (data_array.dtype == object - and LooseVersion(np.__version__) < LooseVersion('1.11.0')): + if data_array.dtype == object and LooseVersion( + np.__version__ + ) < LooseVersion("1.11.0"): raise NotImplementedError( - 'sortby uses np.lexsort under the hood, which requires ' - 'numpy 1.11.0 or later to support object data-type.') + "sortby uses np.lexsort under the hood, which requires " + "numpy 1.11.0 or later to support object data-type." + ) (key,) = data_array.dims vars_by_dim[key].append(data_array) @@ -4487,8 +4690,9 @@ def sortby(self, variables, ascending=True): indices[key] = order if ascending else order[::-1] return aligned_self.isel(**indices) - def quantile(self, q, dim=None, interpolation='linear', - numeric_only=False, keep_attrs=None): + def quantile( + self, q, dim=None, interpolation="linear", numeric_only=False, keep_attrs=None + ): """Compute the qth quantile of the data along the specified dimension. 
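The `shift`/`roll`/`sortby` hunks above are also formatting-only; their different treatments of the coordinate are easiest to see side by side (sketch, hypothetical data):

    ds = xr.Dataset({"foo": ("x", [10, 20, 30])}, coords={"x": [2, 0, 1]})
    ds.shift(x=1)                    # data shifts, the leading slot becomes NaN, coords stay put
    ds.roll(x=1, roll_coords=True)   # data and the x coordinate rotate together
    ds.sortby("x")                   # everything is reordered by the values of x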
Returns the qth quantiles(s) of the array elements for each variable @@ -4540,8 +4744,10 @@ def quantile(self, q, dim=None, interpolation='linear', else: dims = set(dim) - _assert_empty([d for d in dims if d not in self.dims], - 'Dataset does not contain the dimensions: %s') + _assert_empty( + [d for d in dims if d not in self.dims], + "Dataset does not contain the dimensions: %s", + ) q = np.asarray(q, dtype=np.float64) @@ -4550,34 +4756,36 @@ def quantile(self, q, dim=None, interpolation='linear', reduce_dims = [d for d in var.dims if d in dims] if reduce_dims or not var.dims: if name not in self.coords: - if (not numeric_only + if ( + not numeric_only or np.issubdtype(var.dtype, np.number) - or var.dtype == np.bool_): + or var.dtype == np.bool_ + ): if len(reduce_dims) == var.ndim: # prefer to aggregate over axis=None rather than # axis=(0, 1) if they will be equivalent, because # the former is often more efficient reduce_dims = None variables[name] = var.quantile( - q, dim=reduce_dims, interpolation=interpolation) + q, dim=reduce_dims, interpolation=interpolation + ) else: variables[name] = var # construct the new dataset coord_names = {k for k in self.coords if k in variables} - indexes = OrderedDict( - (k, v) for k, v in self.indexes.items() if k in variables - ) + indexes = OrderedDict((k, v) for k, v in self.indexes.items() if k in variables) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) attrs = self.attrs if keep_attrs else None new = self._replace_with_new_dims( - variables, coord_names=coord_names, attrs=attrs, indexes=indexes) - if 'quantile' in new.dims: - new.coords['quantile'] = Variable('quantile', q) + variables, coord_names=coord_names, attrs=attrs, indexes=indexes + ) + if "quantile" in new.dims: + new.coords["quantile"] = Variable("quantile", q) else: - new.coords['quantile'] = q + new.coords["quantile"] = q return new def rank(self, dim, pct=False, keep_attrs=None): @@ -4609,8 +4817,7 @@ def rank(self, dim, pct=False, keep_attrs=None): Variables that do not depend on `dim` are dropped. 
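A usage sketch for the `quantile` and `rank` hunks above (hypothetical Dataset with a `time` dimension; `rank` needs bottleneck):

    ds = xr.Dataset({"t": ("time", np.random.randn(100))})
    ds.quantile([0.25, 0.5, 0.75], dim="time")   # result gains a "quantile" dimension, as handled above
    ds.rank("time", pct=True)                    # per-element rank along time, as a fraction in (0, 1]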
""" if dim not in self.dims: - raise ValueError( - 'Dataset does not contain the dimension: %s' % dim) + raise ValueError("Dataset does not contain the dimension: %s" % dim) variables = OrderedDict() for name, var in self.variables.items(): @@ -4655,30 +4862,31 @@ def differentiate(self, coord, edge_order=1, datetime_unit=None): from .variable import Variable if coord not in self.variables and coord not in self.dims: - raise ValueError('Coordinate {} does not exist.'.format(coord)) + raise ValueError("Coordinate {} does not exist.".format(coord)) coord_var = self[coord].variable if coord_var.ndim != 1: - raise ValueError('Coordinate {} must be 1 dimensional but is {}' - ' dimensional'.format(coord, coord_var.ndim)) + raise ValueError( + "Coordinate {} must be 1 dimensional but is {}" + " dimensional".format(coord, coord_var.ndim) + ) dim = coord_var.dims[0] if _contains_datetime_like_objects(coord_var): - if coord_var.dtype.kind in 'mM' and datetime_unit is None: + if coord_var.dtype.kind in "mM" and datetime_unit is None: datetime_unit, _ = np.datetime_data(coord_var.dtype) elif datetime_unit is None: - datetime_unit = 's' # Default to seconds for cftime objects + datetime_unit = "s" # Default to seconds for cftime objects coord_var = coord_var._to_numeric(datetime_unit=datetime_unit) variables = OrderedDict() for k, v in self.variables.items(): - if (k in self.data_vars and dim in v.dims - and k not in self.coords): + if k in self.data_vars and dim in v.dims and k not in self.coords: if _contains_datetime_like_objects(v): v = v._to_numeric(datetime_unit=datetime_unit) grad = duck_array_ops.gradient( - v.data, coord_var, edge_order=edge_order, - axis=v.get_axis_num(dim)) + v.data, coord_var, edge_order=edge_order, axis=v.get_axis_num(dim) + ) variables[k] = Variable(v.dims, grad) else: variables[k] = v @@ -4710,7 +4918,7 @@ def integrate(self, coord, datetime_unit=None): numpy.trapz: corresponding numpy function """ if not isinstance(coord, (list, tuple)): - coord = (coord, ) + coord = (coord,) result = self for c in coord: result = result._integrate_one(c, datetime_unit=datetime_unit) @@ -4720,21 +4928,22 @@ def _integrate_one(self, coord, datetime_unit=None): from .variable import Variable if coord not in self.variables and coord not in self.dims: - raise ValueError('Coordinate {} does not exist.'.format(coord)) + raise ValueError("Coordinate {} does not exist.".format(coord)) coord_var = self[coord].variable if coord_var.ndim != 1: - raise ValueError('Coordinate {} must be 1 dimensional but is {}' - ' dimensional'.format(coord, coord_var.ndim)) + raise ValueError( + "Coordinate {} must be 1 dimensional but is {}" + " dimensional".format(coord, coord_var.ndim) + ) dim = coord_var.dims[0] if _contains_datetime_like_objects(coord_var): - if coord_var.dtype.kind in 'mM' and datetime_unit is None: + if coord_var.dtype.kind in "mM" and datetime_unit is None: datetime_unit, _ = np.datetime_data(coord_var.dtype) elif datetime_unit is None: - datetime_unit = 's' # Default to seconds for cftime objects - coord_var = datetime_to_numeric( - coord_var, datetime_unit=datetime_unit) + datetime_unit = "s" # Default to seconds for cftime objects + coord_var = datetime_to_numeric(coord_var, datetime_unit=datetime_unit) variables = OrderedDict() coord_names = set() @@ -4748,27 +4957,25 @@ def _integrate_one(self, coord, datetime_unit=None): if _contains_datetime_like_objects(v): v = datetime_to_numeric(v, datetime_unit=datetime_unit) integ = duck_array_ops.trapz( - v.data, coord_var.data, 
axis=v.get_axis_num(dim)) + v.data, coord_var.data, axis=v.get_axis_num(dim) + ) v_dims = list(v.dims) v_dims.remove(dim) variables[k] = Variable(v_dims, integ) else: variables[k] = v - indexes = OrderedDict( - (k, v) for k, v in self.indexes.items() if k in variables - ) + indexes = OrderedDict((k, v) for k, v in self.indexes.items() if k in variables) return self._replace_with_new_dims( - variables, coord_names=coord_names, indexes=indexes) + variables, coord_names=coord_names, indexes=indexes + ) @property def real(self): - return self._unary_op(lambda x: x.real, - keep_attrs=True)(self) + return self._unary_op(lambda x: x.real, keep_attrs=True)(self) @property def imag(self): - return self._unary_op(lambda x: x.imag, - keep_attrs=True)(self) + return self._unary_op(lambda x: x.imag, keep_attrs=True)(self) @property def plot(self): @@ -4864,8 +5071,7 @@ def filter_by_attrs(self, **kwargs): has_value_flag = False for attr_name, pattern in kwargs.items(): attr_value = variable.attrs.get(attr_name) - if ((callable(pattern) and pattern(attr_value)) - or attr_value == pattern): + if (callable(pattern) and pattern(attr_value)) or attr_value == pattern: has_value_flag = True else: has_value_flag = False diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index 794f2b62183..4db2990accc 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -5,7 +5,7 @@ from . import utils # Use as a sentinel value to indicate a dtype appropriate NA value. -NA = utils.ReprObject('') +NA = utils.ReprObject("") @functools.total_ordering @@ -61,7 +61,7 @@ def maybe_promote(dtype): # See https://github.com/numpy/numpy/issues/10685 # np.timedelta64 is a subclass of np.integer # Check np.timedelta64 before np.integer - fill_value = np.timedelta64('NaT') + fill_value = np.timedelta64("NaT") elif np.issubdtype(dtype, np.integer): if dtype.itemsize <= 2: dtype = np.float32 @@ -71,14 +71,14 @@ def maybe_promote(dtype): elif np.issubdtype(dtype, np.complexfloating): fill_value = np.nan + np.nan * 1j elif np.issubdtype(dtype, np.datetime64): - fill_value = np.datetime64('NaT') + fill_value = np.datetime64("NaT") else: dtype = object fill_value = np.nan return np.dtype(dtype), fill_value -NAT_TYPES = (np.datetime64('NaT'), np.timedelta64('NaT')) +NAT_TYPES = (np.datetime64("NaT"), np.timedelta64("NaT")) def get_fill_value(dtype): @@ -139,8 +139,7 @@ def get_neg_infinity(dtype): def is_datetime_like(dtype): """Check if a dtype is a subclass of the numpy datetime types """ - return (np.issubdtype(dtype, np.datetime64) or - np.issubdtype(dtype, np.timedelta64)) + return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) def result_type(*arrays_and_dtypes): @@ -162,8 +161,9 @@ def result_type(*arrays_and_dtypes): types = {np.result_type(t).type for t in arrays_and_dtypes} for left, right in PROMOTE_TO_OBJECT: - if (any(issubclass(t, left) for t in types) and - any(issubclass(t, right) for t in types)): + if any(issubclass(t, left) for t in types) and any( + issubclass(t, right) for t in types + ): return np.dtype(object) return np.result_type(*arrays_and_dtypes) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index f78ecb969a1..b2b6077c67f 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -23,11 +23,17 @@ dask_array_compat = None # type: ignore -def _dask_or_eager_func(name, eager_module=np, dask_module=dask_array, - list_of_args=False, array_args=slice(1), - requires_dask=None): +def _dask_or_eager_func( + name, + 
eager_module=np, + dask_module=dask_array, + list_of_args=False, + array_args=slice(1), + requires_dask=None, +): """Create a function that dispatches to dask for dask array inputs.""" if dask_module is not None: + def f(*args, **kwargs): if list_of_args: dispatch_args = args[0] @@ -37,21 +43,23 @@ def f(*args, **kwargs): try: wrapped = getattr(dask_module, name) except AttributeError as e: - raise AttributeError("%s: requires dask >=%s" % - (e, requires_dask)) + raise AttributeError("%s: requires dask >=%s" % (e, requires_dask)) else: wrapped = getattr(eager_module, name) return wrapped(*args, **kwargs) + else: + def f(*args, **kwargs): return getattr(eager_module, name)(*args, **kwargs) + return f def fail_on_dask_array_input(values, msg=None, func_name=None): if isinstance(values, dask_array_type): if msg is None: - msg = '%r is not yet a valid method on dask arrays' + msg = "%r is not yet a valid method on dask arrays" if func_name is None: func_name = inspect.stack()[1][3] raise NotImplementedError(msg % func_name) @@ -61,22 +69,23 @@ def fail_on_dask_array_input(values, msg=None, func_name=None): # https://github.com/dask/dask/pull/4822 moveaxis = npcompat.moveaxis -around = _dask_or_eager_func('around') -isclose = _dask_or_eager_func('isclose') +around = _dask_or_eager_func("around") +isclose = _dask_or_eager_func("isclose") -if hasattr(np, 'isnat') and ( - dask_array is None or hasattr(dask_array_type, '__array_ufunc__')): +if hasattr(np, "isnat") and ( + dask_array is None or hasattr(dask_array_type, "__array_ufunc__") +): # np.isnat is available since NumPy 1.13, so __array_ufunc__ is always # supported. isnat = np.isnat else: - isnat = _dask_or_eager_func('isnull', eager_module=pd) -isnan = _dask_or_eager_func('isnan') -zeros_like = _dask_or_eager_func('zeros_like') + isnat = _dask_or_eager_func("isnull", eager_module=pd) +isnan = _dask_or_eager_func("isnan") +zeros_like = _dask_or_eager_func("zeros_like") -pandas_isnull = _dask_or_eager_func('isnull', eager_module=pd) +pandas_isnull = _dask_or_eager_func("isnull", eager_module=pd) def isnull(data): @@ -90,9 +99,7 @@ def isnull(data): elif issubclass(scalar_type, np.inexact): # float types use NaN for null return isnan(data) - elif issubclass( - scalar_type, (np.bool_, np.integer, np.character, np.void) - ): + elif issubclass(scalar_type, (np.bool_, np.integer, np.character, np.void)): # these types cannot represent missing values return zeros_like(data, dtype=bool) else: @@ -111,52 +118,53 @@ def notnull(data): return ~isnull(data) -transpose = _dask_or_eager_func('transpose') -_where = _dask_or_eager_func('where', array_args=slice(3)) -isin = _dask_or_eager_func('isin', eager_module=npcompat, - dask_module=dask_array_compat, array_args=slice(2)) -take = _dask_or_eager_func('take') -broadcast_to = _dask_or_eager_func('broadcast_to') +transpose = _dask_or_eager_func("transpose") +_where = _dask_or_eager_func("where", array_args=slice(3)) +isin = _dask_or_eager_func( + "isin", eager_module=npcompat, dask_module=dask_array_compat, array_args=slice(2) +) +take = _dask_or_eager_func("take") +broadcast_to = _dask_or_eager_func("broadcast_to") -_concatenate = _dask_or_eager_func('concatenate', list_of_args=True) -_stack = _dask_or_eager_func('stack', list_of_args=True) +_concatenate = _dask_or_eager_func("concatenate", list_of_args=True) +_stack = _dask_or_eager_func("stack", list_of_args=True) -array_all = _dask_or_eager_func('all') -array_any = _dask_or_eager_func('any') +array_all = _dask_or_eager_func("all") +array_any = 
_dask_or_eager_func("any") -tensordot = _dask_or_eager_func('tensordot', array_args=slice(2)) -einsum = _dask_or_eager_func('einsum', array_args=slice(1, None), - requires_dask='0.17.3') +tensordot = _dask_or_eager_func("tensordot", array_args=slice(2)) +einsum = _dask_or_eager_func( + "einsum", array_args=slice(1, None), requires_dask="0.17.3" +) def gradient(x, coord, axis, edge_order): if isinstance(x, dask_array_type): - return dask_array_compat.gradient( - x, coord, axis=axis, edge_order=edge_order) + return dask_array_compat.gradient(x, coord, axis=axis, edge_order=edge_order) return npcompat.gradient(x, coord, axis=axis, edge_order=edge_order) def trapz(y, x, axis): if axis < 0: axis = y.ndim + axis - x_sl1 = (slice(1, None), ) + (None, ) * (y.ndim - axis - 1) - x_sl2 = (slice(None, -1), ) + (None, ) * (y.ndim - axis - 1) - slice1 = (slice(None),) * axis + (slice(1, None), ) - slice2 = (slice(None),) * axis + (slice(None, -1), ) - dx = (x[x_sl1] - x[x_sl2]) + x_sl1 = (slice(1, None),) + (None,) * (y.ndim - axis - 1) + x_sl2 = (slice(None, -1),) + (None,) * (y.ndim - axis - 1) + slice1 = (slice(None),) * axis + (slice(1, None),) + slice2 = (slice(None),) * axis + (slice(None, -1),) + dx = x[x_sl1] - x[x_sl2] integrand = dx * 0.5 * (y[tuple(slice1)] + y[tuple(slice2)]) return sum(integrand, axis=axis, skipna=False) masked_invalid = _dask_or_eager_func( - 'masked_invalid', eager_module=np.ma, - dask_module=getattr(dask_array, 'ma', None)) + "masked_invalid", eager_module=np.ma, dask_module=getattr(dask_array, "ma", None) +) def asarray(data): return ( - data if (isinstance(data, dask_array_type) - or hasattr(data, '__array_function__')) + data + if (isinstance(data, dask_array_type) or hasattr(data, "__array_function__")) else np.asarray(data) ) @@ -177,6 +185,7 @@ def as_like_arrays(*data): return data elif any(isinstance(d, sparse_array_type) for d in data): from sparse import COO + return tuple(COO(d) for d in data) else: return tuple(np.asarray(d) for d in data) @@ -188,8 +197,7 @@ def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8): arr1, arr2 = as_like_arrays(arr1, arr2) if arr1.shape != arr2.shape: return False - return bool( - isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all()) + return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all()) def array_equiv(arr1, arr2): @@ -200,10 +208,10 @@ def array_equiv(arr1, arr2): return False with warnings.catch_warnings(): - warnings.filterwarnings('ignore', "In the future, 'NAT == x'") + warnings.filterwarnings("ignore", "In the future, 'NAT == x'") - flag_array = (arr1 == arr2) - flag_array |= (isnull(arr1) & isnull(arr2)) + flag_array = arr1 == arr2 + flag_array |= isnull(arr1) & isnull(arr2) return bool(flag_array.all()) @@ -217,9 +225,9 @@ def array_notnull_equiv(arr1, arr2): return False with warnings.catch_warnings(): - warnings.filterwarnings('ignore', "In the future, 'NAT == x'") + warnings.filterwarnings("ignore", "In the future, 'NAT == x'") - flag_array = (arr1 == arr2) + flag_array = arr1 == arr2 flag_array |= isnull(arr1) flag_array |= isnull(arr2) @@ -261,7 +269,7 @@ def stack(arrays, axis=0): def _ignore_warnings_if(condition): if condition: with warnings.catch_warnings(): - warnings.simplefilter('ignore') + warnings.simplefilter("ignore") yield else: yield @@ -271,17 +279,17 @@ def _create_nan_agg_method(name, coerce_strings=False): from . 
import nanops def f(values, axis=None, skipna=None, **kwargs): - if kwargs.pop('out', None) is not None: - raise TypeError('`out` is not valid for {}'.format(name)) + if kwargs.pop("out", None) is not None: + raise TypeError("`out` is not valid for {}".format(name)) values = asarray(values) - if coerce_strings and values.dtype.kind in 'SU': + if coerce_strings and values.dtype.kind in "SU": values = values.astype(object) func = None - if skipna or (skipna is None and values.dtype.kind in 'cfO'): - nanname = 'nan' + name + if skipna or (skipna is None and values.dtype.kind in "cfO"): + nanname = "nan" + name func = getattr(nanops, nanname) else: func = _dask_or_eager_func(name) @@ -292,14 +300,15 @@ def f(values, axis=None, skipna=None, **kwargs): if isinstance(values, dask_array_type): try: # dask/dask#3133 dask sometimes needs dtype argument # if func does not accept dtype, then raises TypeError - return func(values, axis=axis, dtype=values.dtype, - **kwargs) + return func(values, axis=axis, dtype=values.dtype, **kwargs) except (AttributeError, TypeError): - msg = '%s is not yet implemented on dask arrays' % name + msg = "%s is not yet implemented on dask arrays" % name else: - msg = ('%s is not available with skipna=False with the ' - 'installed version of numpy; upgrade to numpy 1.12 ' - 'or newer to use skipna=True or skipna=None' % name) + msg = ( + "%s is not available with skipna=False with the " + "installed version of numpy; upgrade to numpy 1.12 " + "or newer to use skipna=True or skipna=None" % name + ) raise NotImplementedError(msg) f.__name__ = name @@ -308,29 +317,29 @@ def f(values, axis=None, skipna=None, **kwargs): # Attributes `numeric_only`, `available_min_count` is used for docs. # See ops.inject_reduce_methods -argmax = _create_nan_agg_method('argmax', coerce_strings=True) -argmin = _create_nan_agg_method('argmin', coerce_strings=True) -max = _create_nan_agg_method('max', coerce_strings=True) -min = _create_nan_agg_method('min', coerce_strings=True) -sum = _create_nan_agg_method('sum') +argmax = _create_nan_agg_method("argmax", coerce_strings=True) +argmin = _create_nan_agg_method("argmin", coerce_strings=True) +max = _create_nan_agg_method("max", coerce_strings=True) +min = _create_nan_agg_method("min", coerce_strings=True) +sum = _create_nan_agg_method("sum") sum.numeric_only = True sum.available_min_count = True -std = _create_nan_agg_method('std') +std = _create_nan_agg_method("std") std.numeric_only = True -var = _create_nan_agg_method('var') +var = _create_nan_agg_method("var") var.numeric_only = True -median = _create_nan_agg_method('median') +median = _create_nan_agg_method("median") median.numeric_only = True -prod = _create_nan_agg_method('prod') +prod = _create_nan_agg_method("prod") prod.numeric_only = True sum.available_min_count = True -cumprod_1d = _create_nan_agg_method('cumprod') +cumprod_1d = _create_nan_agg_method("cumprod") cumprod_1d.numeric_only = True -cumsum_1d = _create_nan_agg_method('cumsum') +cumsum_1d = _create_nan_agg_method("cumsum") cumsum_1d.numeric_only = True -_mean = _create_nan_agg_method('mean') +_mean = _create_nan_agg_method("mean") def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float): @@ -355,10 +364,10 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float): offset = array.min() array = array - offset - if not hasattr(array, 'dtype'): # scalar is converted to 0d-array + if not hasattr(array, "dtype"): # scalar is converted to 0d-array array = np.array(array) - if array.dtype.kind 
in 'O': + if array.dtype.kind in "O": # possibly convert object array containing datetime.timedelta array = np.asarray(pd.Series(array.ravel())).reshape(array.shape) @@ -366,12 +375,12 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float): array = array / np.timedelta64(1, datetime_unit) # convert np.NaT to np.nan - if array.dtype.kind in 'mM': + if array.dtype.kind in "mM": return np.where(isnull(array), np.nan, array.astype(dtype)) return array.astype(dtype) -def _to_pytimedelta(array, unit='us'): +def _to_pytimedelta(array, unit="us"): index = pd.TimedeltaIndex(array.ravel(), unit=unit) return index.to_pytimedelta().reshape(array.shape) @@ -382,23 +391,27 @@ def mean(array, axis=None, skipna=None, **kwargs): from .common import _contains_cftime_datetimes array = asarray(array) - if array.dtype.kind in 'Mm': + if array.dtype.kind in "Mm": offset = min(array) # xarray always uses np.datetime64[ns] for np.datetime64 data - dtype = 'timedelta64[ns]' - return _mean(datetime_to_numeric(array, offset), axis=axis, - skipna=skipna, **kwargs).astype(dtype) + offset + dtype = "timedelta64[ns]" + return ( + _mean( + datetime_to_numeric(array, offset), axis=axis, skipna=skipna, **kwargs + ).astype(dtype) + + offset + ) elif _contains_cftime_datetimes(array): if isinstance(array, dask_array_type): raise NotImplementedError( - 'Computing the mean of an array containing ' - 'cftime.datetime objects is not yet implemented on ' - 'dask arrays.') + "Computing the mean of an array containing " + "cftime.datetime objects is not yet implemented on " + "dask arrays." + ) offset = min(array) - timedeltas = datetime_to_numeric(array, offset, datetime_unit='us') - mean_timedeltas = _mean(timedeltas, axis=axis, skipna=skipna, - **kwargs) - return _to_pytimedelta(mean_timedeltas, unit='us') + offset + timedeltas = datetime_to_numeric(array, offset, datetime_unit="us") + mean_timedeltas = _mean(timedeltas, axis=axis, skipna=skipna, **kwargs) + return _to_pytimedelta(mean_timedeltas, unit="us") + offset else: return _mean(array, axis=axis, skipna=skipna, **kwargs) @@ -431,13 +444,14 @@ def cumsum(array, axis=None, **kwargs): _fail_on_dask_array_input_skipna = partial( fail_on_dask_array_input, - msg='%r with skipna=True is not yet implemented on dask arrays') + msg="%r with skipna=True is not yet implemented on dask arrays", +) def first(values, axis, skipna=None): """Return the first non-NA elements in this array along the given axis """ - if (skipna or skipna is None) and values.dtype.kind not in 'iSU': + if (skipna or skipna is None) and values.dtype.kind not in "iSU": # only bother for dtypes that can hold NaN _fail_on_dask_array_input_skipna(values) return nanfirst(values, axis) @@ -447,7 +461,7 @@ def first(values, axis, skipna=None): def last(values, axis, skipna=None): """Return the last non-NA elements in this array along the given axis """ - if (skipna or skipna is None) and values.dtype.kind not in 'iSU': + if (skipna or skipna is None) and values.dtype.kind not in "iSU": # only bother for dtypes that can hold NaN _fail_on_dask_array_input_skipna(values) return nanlast(values, axis) @@ -460,8 +474,6 @@ def rolling_window(array, axis, window, center, fill_value): The rolling dimension will be placed at the last dimension. 
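The `mean` hunk above keeps the existing datetime strategy: subtract an offset, average the resulting timedeltas, then add the offset back. Spelled out with plain numpy (a sketch of the idea, not the internal code path):

    times = np.array(["2019-01-01", "2019-01-03", "2019-01-05"], dtype="datetime64[ns]")
    offset = times.min()
    offset + (times - offset).mean()   # 2019-01-03: the mean computed via timedelta64 arithmetic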
""" if isinstance(array, dask_array_type): - return dask_array_ops.rolling_window( - array, axis, window, center, fill_value) + return dask_array_ops.rolling_window(array, axis, window, center, fill_value) else: # np.ndarray - return nputils.rolling_window( - array, axis, window, center, fill_value) + return nputils.rolling_window(array, axis, window, center, fill_value) diff --git a/xarray/core/extensions.py b/xarray/core/extensions.py index cb34e87f88d..302a7fb2ec6 100644 --- a/xarray/core/extensions.py +++ b/xarray/core/extensions.py @@ -25,7 +25,7 @@ def __get__(self, obj, cls): # __getattr__ on data object will swallow any AttributeErrors # raised when initializing the accessor, so we need to raise as # something else (GH933): - raise RuntimeError('error initializing %r accessor.' % self._name) + raise RuntimeError("error initializing %r accessor." % self._name) # Replace the property with the accessor object. Inspired by: # http://www.pydanny.com/cached-property.html # We need to use object.__setattr__ because we overwrite __setattr__ on @@ -38,13 +38,15 @@ def _register_accessor(name, cls): def decorator(accessor): if hasattr(cls, name): warnings.warn( - 'registration of accessor %r under name %r for type %r is ' - 'overriding a preexisting attribute with the same name.' + "registration of accessor %r under name %r for type %r is " + "overriding a preexisting attribute with the same name." % (accessor, name, cls), AccessorRegistrationWarning, - stacklevel=2) + stacklevel=2, + ) setattr(cls, name, _CachedAccessor(name, accessor)) return accessor + return decorator diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 3ddffec8e5e..6a5f1ede632 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -25,20 +25,20 @@ def pretty_print(x, numchars): ellipses as necessary """ s = maybe_truncate(x, numchars) - return s + ' ' * max(numchars - len(s), 0) + return s + " " * max(numchars - len(s), 0) def maybe_truncate(obj, maxlen=500): s = str(obj) if len(s) > maxlen: - s = s[:(maxlen - 3)] + '...' + s = s[: (maxlen - 3)] + "..." return s -def wrap_indent(text, start='', length=None): +def wrap_indent(text, start="", length=None): if length is None: length = len(start) - indent = '\n' + ' ' * length + indent = "\n" + " " * length return start + indent.join(x for x in text.splitlines()) @@ -47,9 +47,11 @@ def _get_indexer_at_least_n_items(shape, n_desired, from_end): cum_items = np.cumprod(shape[::-1]) n_steps = np.argmax(cum_items >= n_desired) stop = int(np.ceil(float(n_desired) / np.r_[1, cum_items][n_steps])) - indexer = (((-1 if from_end else 0),) * (len(shape) - 1 - n_steps) + - ((slice(-stop, None) if from_end else slice(stop)),) + - (slice(None),) * n_steps) + indexer = ( + ((-1 if from_end else 0),) * (len(shape) - 1 - n_steps) + + ((slice(-stop, None) if from_end else slice(stop)),) + + (slice(None),) * n_steps + ) return indexer @@ -60,15 +62,14 @@ def first_n_items(array, n_desired): # could be very expensive (e.g. if it's only available over DAP), so go out # of our way to get them in a single call to __getitem__ using only slices. 
if n_desired < 1: - raise ValueError('must request at least one item') + raise ValueError("must request at least one item") if array.size == 0: # work around for https://github.com/numpy/numpy/issues/5195 return [] if n_desired < array.size: - indexer = _get_indexer_at_least_n_items(array.shape, n_desired, - from_end=False) + indexer = _get_indexer_at_least_n_items(array.shape, n_desired, from_end=False) array = array[indexer] return np.asarray(array).flat[:n_desired] @@ -83,8 +84,7 @@ def last_n_items(array, n_desired): return [] if n_desired < array.size: - indexer = _get_indexer_at_least_n_items(array.shape, n_desired, - from_end=True) + indexer = _get_indexer_at_least_n_items(array.shape, n_desired, from_end=True) array = array[indexer] return np.asarray(array).flat[-n_desired:] @@ -113,24 +113,24 @@ def format_timestamp(t): # catch NaT and others that don't split nicely return datetime_str else: - if time_str == '00:00:00': + if time_str == "00:00:00": return date_str else: - return '{}T{}'.format(date_str, time_str) + return "{}T{}".format(date_str, time_str) def format_timedelta(t, timedelta_format=None): """Cast given object to a Timestamp and return a nicely formatted string""" timedelta_str = str(pd.Timedelta(t)) try: - days_str, time_str = timedelta_str.split(' days ') + days_str, time_str = timedelta_str.split(" days ") except ValueError: # catch NaT and others that don't split nicely return timedelta_str else: - if timedelta_format == 'date': - return days_str + ' days' - elif timedelta_format == 'time': + if timedelta_format == "date": + return days_str + " days" + elif timedelta_format == "time": return time_str else: return timedelta_str @@ -145,7 +145,7 @@ def format_item(x, timedelta_format=None, quote_strings=True): elif isinstance(x, (str, bytes)): return repr(x) if quote_strings else x elif isinstance(x, (float, np.float)): - return '{:.4}'.format(x) + return "{:.4}".format(x) else: return str(x) @@ -153,18 +153,16 @@ def format_item(x, timedelta_format=None, quote_strings=True): def format_items(x): """Returns a succinct summaries of all items in a sequence as strings""" x = np.asarray(x) - timedelta_format = 'datetime' + timedelta_format = "datetime" if np.issubdtype(x.dtype, np.timedelta64): - x = np.asarray(x, dtype='timedelta64[ns]') - day_part = (x[~pd.isnull(x)] - .astype('timedelta64[D]') - .astype('timedelta64[ns]')) + x = np.asarray(x, dtype="timedelta64[ns]") + day_part = x[~pd.isnull(x)].astype("timedelta64[D]").astype("timedelta64[ns]") time_needed = x[~pd.isnull(x)] != day_part - day_needed = day_part != np.timedelta64(0, 'ns') + day_needed = day_part != np.timedelta64(0, "ns") if np.logical_not(day_needed).all(): - timedelta_format = 'time' + timedelta_format = "time" elif np.logical_not(time_needed).all(): - timedelta_format = 'date' + timedelta_format = "date" formatted = [format_item(xi, timedelta_format) for xi in x] return formatted @@ -176,70 +174,78 @@ def format_array_flat(array, max_width): """ # every item will take up at least two characters, but we always want to # print at least first and last items - max_possibly_relevant = min(max(array.size, 1), - max(int(np.ceil(max_width / 2.)), 2)) + max_possibly_relevant = min( + max(array.size, 1), max(int(np.ceil(max_width / 2.0)), 2) + ) relevant_front_items = format_items( - first_n_items(array, (max_possibly_relevant + 1) // 2)) - relevant_back_items = format_items( - last_n_items(array, max_possibly_relevant // 2)) + first_n_items(array, (max_possibly_relevant + 1) // 2) + ) + 
relevant_back_items = format_items(last_n_items(array, max_possibly_relevant // 2)) # interleave relevant front and back items: # [a, b, c] and [y, z] -> [a, z, b, y, c] - relevant_items = sum(zip_longest(relevant_front_items, - reversed(relevant_back_items)), - ())[:max_possibly_relevant] + relevant_items = sum( + zip_longest(relevant_front_items, reversed(relevant_back_items)), () + )[:max_possibly_relevant] cum_len = np.cumsum([len(s) + 1 for s in relevant_items]) - 1 - if (array.size > 2) and ((max_possibly_relevant < array.size) or - (cum_len > max_width).any()): - padding = ' ... ' - count = min(array.size, - max(np.argmax(cum_len + len(padding) - 1 > max_width), 2)) + if (array.size > 2) and ( + (max_possibly_relevant < array.size) or (cum_len > max_width).any() + ): + padding = " ... " + count = min( + array.size, max(np.argmax(cum_len + len(padding) - 1 > max_width), 2) + ) else: count = array.size - padding = '' if (count <= 1) else ' ' + padding = "" if (count <= 1) else " " num_front = (count + 1) // 2 num_back = count - num_front # note that num_back is 0 <--> array.size is 0 or 1 # <--> relevant_back_items is [] - pprint_str = (' '.join(relevant_front_items[:num_front]) + - padding + - ' '.join(relevant_back_items[-num_back:])) + pprint_str = ( + " ".join(relevant_front_items[:num_front]) + + padding + + " ".join(relevant_back_items[-num_back:]) + ) return pprint_str -def summarize_variable(name, var, col_width, show_values=True, - marker=' ', max_width=None): +def summarize_variable( + name, var, col_width, show_values=True, marker=" ", max_width=None +): if max_width is None: - max_width = OPTIONS['display_width'] - first_col = pretty_print(' {} {} '.format(marker, name), col_width) + max_width = OPTIONS["display_width"] + first_col = pretty_print(" {} {} ".format(marker, name), col_width) if var.dims: - dims_str = '({}) '.format(', '.join(map(str, var.dims))) + dims_str = "({}) ".format(", ".join(map(str, var.dims))) else: - dims_str = '' - front_str = '{}{}{} '.format(first_col, dims_str, var.dtype) + dims_str = "" + front_str = "{}{}{} ".format(first_col, dims_str, var.dtype) if show_values: values_str = format_array_flat(var, max_width - len(front_str)) elif isinstance(var._data, dask_array_type): values_str = short_dask_repr(var, show_dtype=False) else: - values_str = '...' + values_str = "..." 
return front_str + values_str def _summarize_coord_multiindex(coord, col_width, marker): - first_col = pretty_print(' {} {} '.format( - marker, coord.name), col_width) - return '{}({}) MultiIndex'.format(first_col, str(coord.dims[0])) + first_col = pretty_print(" {} {} ".format(marker, coord.name), col_width) + return "{}({}) MultiIndex".format(first_col, str(coord.dims[0])) -def _summarize_coord_levels(coord, col_width, marker='-'): - return '\n'.join( - [summarize_variable(lname, - coord.get_level_variable(lname), - col_width, marker=marker) - for lname in coord.level_names]) +def _summarize_coord_levels(coord, col_width, marker="-"): + return "\n".join( + [ + summarize_variable( + lname, coord.get_level_variable(lname), col_width, marker=marker + ) + for lname in coord.level_names + ] + ) def summarize_datavar(name, var, col_width): @@ -250,31 +256,32 @@ def summarize_datavar(name, var, col_width): def summarize_coord(name, var, col_width): is_index = name in var.dims show_values = var._in_memory - marker = '*' if is_index else ' ' + marker = "*" if is_index else " " if is_index: coord = var.variable.to_index_variable() if coord.level_names is not None: - return '\n'.join( - [_summarize_coord_multiindex(coord, col_width, marker), - _summarize_coord_levels(coord, col_width)]) - return summarize_variable( - name, var.variable, col_width, show_values, marker) + return "\n".join( + [ + _summarize_coord_multiindex(coord, col_width, marker), + _summarize_coord_levels(coord, col_width), + ] + ) + return summarize_variable(name, var.variable, col_width, show_values, marker) def summarize_attr(key, value, col_width=None): """Summary for __repr__ - use ``X.attrs[key]`` for full value.""" # Indent key and add ':', then right-pad if col_width is not None - k_str = ' {}:'.format(key) + k_str = " {}:".format(key) if col_width is not None: k_str = pretty_print(k_str, col_width) # Replace tabs and newlines, so we print on one line in known width - v_str = str(value).replace('\t', '\\t').replace('\n', '\\n') + v_str = str(value).replace("\t", "\\t").replace("\n", "\\n") # Finally, truncate to the desired display width - return maybe_truncate('{} {}'.format(k_str, v_str), - OPTIONS['display_width']) + return maybe_truncate("{} {}".format(k_str, v_str), OPTIONS["display_width"]) -EMPTY_REPR = ' *empty*' +EMPTY_REPR = " *empty*" def _get_col_items(mapping): @@ -286,7 +293,7 @@ def _get_col_items(mapping): col_items = [] for k, v in mapping.items(): col_items.append(k) - var = getattr(v, 'variable', v) + var = getattr(v, "variable", v) if isinstance(var, IndexVariable): level_names = var.to_index_variable().level_names if level_names is not None: @@ -295,8 +302,7 @@ def _get_col_items(mapping): def _calculate_col_width(col_items): - max_name_length = (max(len(str(s)) for s in col_items) - if col_items else 0) + max_name_length = max(len(str(s)) for s in col_items) if col_items else 0 col_width = max(max_name_length, 7) + 6 return col_width @@ -304,46 +310,49 @@ def _calculate_col_width(col_items): def _mapping_repr(mapping, title, summarizer, col_width=None): if col_width is None: col_width = _calculate_col_width(mapping) - summary = ['{}:'.format(title)] + summary = ["{}:".format(title)] if mapping: summary += [summarizer(k, v, col_width) for k, v in mapping.items()] else: summary += [EMPTY_REPR] - return '\n'.join(summary) + return "\n".join(summary) -data_vars_repr = functools.partial(_mapping_repr, title='Data variables', - summarizer=summarize_datavar) +data_vars_repr = functools.partial( + 
_mapping_repr, title="Data variables", summarizer=summarize_datavar +) -attrs_repr = functools.partial(_mapping_repr, title='Attributes', - summarizer=summarize_attr) +attrs_repr = functools.partial( + _mapping_repr, title="Attributes", summarizer=summarize_attr +) def coords_repr(coords, col_width=None): if col_width is None: col_width = _calculate_col_width(_get_col_items(coords)) - return _mapping_repr(coords, title='Coordinates', - summarizer=summarize_coord, col_width=col_width) + return _mapping_repr( + coords, title="Coordinates", summarizer=summarize_coord, col_width=col_width + ) def indexes_repr(indexes): summary = [] for k, v in indexes.items(): - summary.append(wrap_indent(repr(v), '{}: '.format(k))) - return '\n'.join(summary) + summary.append(wrap_indent(repr(v), "{}: ".format(k))) + return "\n".join(summary) def dim_summary(obj): - elements = ['{}: {}'.format(k, v) for k, v in obj.sizes.items()] - return ', '.join(elements) + elements = ["{}: {}".format(k, v) for k, v in obj.sizes.items()] + return ", ".join(elements) def unindexed_dims_repr(dims, coords): unindexed_dims = [d for d in dims if d not in coords] if unindexed_dims: - dims_str = ', '.join('{}'.format(d) for d in unindexed_dims) - return 'Dimensions without coordinates: ' + dims_str + dims_str = ", ".join("{}".format(d) for d in unindexed_dims) + return "Dimensions without coordinates: " + dims_str else: return None @@ -358,23 +367,19 @@ def set_numpy_options(*args, **kwargs): def short_array_repr(array): - if not hasattr(array, '__array_function__'): + if not hasattr(array, "__array_function__"): array = np.asarray(array) # default to lower precision so a full (abbreviated) line can fit on # one line with the default display_width - options = { - 'precision': 6, - 'linewidth': OPTIONS['display_width'], - 'threshold': 200, - } + options = {"precision": 6, "linewidth": OPTIONS["display_width"], "threshold": 200} if array.ndim < 3: edgeitems = 3 elif array.ndim == 3: edgeitems = 2 else: edgeitems = 1 - options['edgeitems'] = edgeitems + options["edgeitems"] = edgeitems with set_numpy_options(**options): return repr(array) @@ -386,37 +391,35 @@ def short_dask_repr(array, show_dtype=True): """ chunksize = tuple(c[0] for c in array.chunks) if show_dtype: - return 'dask.array'.format( - array.shape, array.dtype, chunksize) + return "dask.array".format( + array.shape, array.dtype, chunksize + ) else: - return 'dask.array'.format( - array.shape, chunksize) + return "dask.array".format(array.shape, chunksize) def short_data_repr(array): - if isinstance(getattr(array, 'variable', array)._data, dask_array_type): + if isinstance(getattr(array, "variable", array)._data, dask_array_type): return short_dask_repr(array) elif array._in_memory or array.size < 1e5: return short_array_repr(array.data) else: - return '[{} values with dtype={}]'.format(array.size, array.dtype) + return "[{} values with dtype={}]".format(array.size, array.dtype) def array_repr(arr): # used for DataArray, Variable and IndexVariable - if hasattr(arr, 'name') and arr.name is not None: - name_str = '{!r} '.format(arr.name) + if hasattr(arr, "name") and arr.name is not None: + name_str = "{!r} ".format(arr.name) else: - name_str = '' + name_str = "" summary = [ - ''.format( - type(arr).__name__, name_str, dim_summary(arr) - ), - short_data_repr(arr) + "".format(type(arr).__name__, name_str, dim_summary(arr)), + short_data_repr(arr), ] - if hasattr(arr, 'coords'): + if hasattr(arr, "coords"): if arr.coords: summary.append(repr(arr.coords)) @@ -427,16 +430,16 
@@ def array_repr(arr): if arr.attrs: summary.append(attrs_repr(arr.attrs)) - return '\n'.join(summary) + return "\n".join(summary) def dataset_repr(ds): - summary = [''.format(type(ds).__name__)] + summary = ["".format(type(ds).__name__)] col_width = _calculate_col_width(_get_col_items(ds.variables)) - dims_start = pretty_print('Dimensions:', col_width) - summary.append('{}({})'.format(dims_start, dim_summary(ds))) + dims_start = pretty_print("Dimensions:", col_width) + summary.append("{}({})".format(dims_start, dim_summary(ds))) if ds.coords: summary.append(coords_repr(ds.coords, col_width=col_width)) @@ -450,20 +453,19 @@ def dataset_repr(ds): if ds.attrs: summary.append(attrs_repr(ds.attrs)) - return '\n'.join(summary) + return "\n".join(summary) def diff_dim_summary(a, b): if a.dims != b.dims: return "Differing dimensions:\n ({}) != ({})".format( - dim_summary(a), dim_summary(b)) + dim_summary(a), dim_summary(b) + ) else: return "" -def _diff_mapping_repr(a_mapping, b_mapping, compat, - title, summarizer, col_width=None): - +def _diff_mapping_repr(a_mapping, b_mapping, compat, title, summarizer, col_width=None): def extra_items_repr(extra_keys, mapping, ab_side): extra_repr = [summarizer(k, mapping[k], col_width) for k in extra_keys] if extra_repr: @@ -490,22 +492,25 @@ def extra_items_repr(extra_keys, mapping, ab_side): is_variable = False if not compatible: - temp = [summarizer(k, vars[k], col_width) - for vars in (a_mapping, b_mapping)] + temp = [ + summarizer(k, vars[k], col_width) for vars in (a_mapping, b_mapping) + ] - if compat == 'identical' and is_variable: + if compat == "identical" and is_variable: attrs_summary = [] for m in (a_mapping, b_mapping): - attr_s = "\n".join([summarize_attr(ak, av) - for ak, av in m[k].attrs.items()]) + attr_s = "\n".join( + [summarize_attr(ak, av) for ak, av in m[k].attrs.items()] + ) attrs_summary.append(attr_s) - temp = ["\n".join([var_s, attr_s]) if attr_s else var_s - for var_s, attr_s in zip(temp, attrs_summary)] + temp = [ + "\n".join([var_s, attr_s]) if attr_s else var_s + for var_s, attr_s in zip(temp, attrs_summary) + ] - diff_items += [ab_side + s[1:] - for ab_side, s in zip(('L', 'R'), temp)] + diff_items += [ab_side + s[1:] for ab_side, s in zip(("L", "R"), temp)] if diff_items: summary += ["Differing {}:".format(title.lower())] + diff_items @@ -516,19 +521,19 @@ def extra_items_repr(extra_keys, mapping, ab_side): return "\n".join(summary) -diff_coords_repr = functools.partial(_diff_mapping_repr, - title="Coordinates", - summarizer=summarize_coord) +diff_coords_repr = functools.partial( + _diff_mapping_repr, title="Coordinates", summarizer=summarize_coord +) -diff_data_vars_repr = functools.partial(_diff_mapping_repr, - title="Data variables", - summarizer=summarize_datavar) +diff_data_vars_repr = functools.partial( + _diff_mapping_repr, title="Data variables", summarizer=summarize_datavar +) -diff_attrs_repr = functools.partial(_diff_mapping_repr, - title="Attributes", - summarizer=summarize_attr) +diff_attrs_repr = functools.partial( + _diff_mapping_repr, title="Attributes", summarizer=summarize_attr +) def _compat_to_str(compat): @@ -540,43 +545,52 @@ def _compat_to_str(compat): def diff_array_repr(a, b, compat): # used for DataArray, Variable and IndexVariable - summary = ["Left and right {} objects are not {}" - .format(type(a).__name__, _compat_to_str(compat))] + summary = [ + "Left and right {} objects are not {}".format( + type(a).__name__, _compat_to_str(compat) + ) + ] summary.append(diff_dim_summary(a, b)) if not 
array_equiv(a.data, b.data): - temp = [wrap_indent(short_array_repr(obj), start=' ') - for obj in (a, b)] - diff_data_repr = [ab_side + "\n" + ab_data_repr - for ab_side, ab_data_repr in zip(('L', 'R'), temp)] + temp = [wrap_indent(short_array_repr(obj), start=" ") for obj in (a, b)] + diff_data_repr = [ + ab_side + "\n" + ab_data_repr + for ab_side, ab_data_repr in zip(("L", "R"), temp) + ] summary += ["Differing values:"] + diff_data_repr - if hasattr(a, 'coords'): + if hasattr(a, "coords"): col_width = _calculate_col_width(set(a.coords) | set(b.coords)) - summary.append(diff_coords_repr(a.coords, b.coords, compat, - col_width=col_width)) + summary.append( + diff_coords_repr(a.coords, b.coords, compat, col_width=col_width) + ) - if compat == 'identical': + if compat == "identical": summary.append(diff_attrs_repr(a.attrs, b.attrs, compat)) return "\n".join(summary) def diff_dataset_repr(a, b, compat): - summary = ["Left and right {} objects are not {}" - .format(type(a).__name__, _compat_to_str(compat))] + summary = [ + "Left and right {} objects are not {}".format( + type(a).__name__, _compat_to_str(compat) + ) + ] col_width = _calculate_col_width( - set(_get_col_items(a.variables) + _get_col_items(b.variables))) + set(_get_col_items(a.variables) + _get_col_items(b.variables)) + ) summary.append(diff_dim_summary(a, b)) - summary.append(diff_coords_repr(a.coords, b.coords, compat, - col_width=col_width)) - summary.append(diff_data_vars_repr(a.data_vars, b.data_vars, compat, - col_width=col_width)) + summary.append(diff_coords_repr(a.coords, b.coords, compat, col_width=col_width)) + summary.append( + diff_data_vars_repr(a.data_vars, b.data_vars, compat, col_width=col_width) + ) - if compat == 'identical': + if compat == "identical": summary.append(diff_attrs_repr(a.attrs, b.attrs, compat)) return "\n".join(summary) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 2be0857a4d3..a0d260c3f33 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -45,6 +45,7 @@ def unique_value_groups(ar, sort=True): def _dummy_copy(xarray_obj): from .dataset import Dataset from .dataarray import DataArray + if isinstance(xarray_obj, Dataset): res = Dataset( { @@ -56,7 +57,7 @@ def _dummy_copy(xarray_obj): for k, v in xarray_obj.coords.items() if k not in xarray_obj.dims }, - xarray_obj.attrs + xarray_obj.attrs, ) elif isinstance(xarray_obj, DataArray): res = DataArray( @@ -68,7 +69,7 @@ def _dummy_copy(xarray_obj): }, dims=[], name=xarray_obj.name, - attrs=xarray_obj.attrs + attrs=xarray_obj.attrs, ) else: # pragma: no cover raise AssertionError @@ -86,10 +87,13 @@ def _consolidate_slices(slices): last_slice = slice(None) for slice_ in slices: if not isinstance(slice_, slice): - raise ValueError('list element is not a slice: %r' % slice_) - if (result and last_slice.stop == slice_.start and - _is_one_or_none(last_slice.step) and - _is_one_or_none(slice_.step)): + raise ValueError("list element is not a slice: %r" % slice_) + if ( + result + and last_slice.stop == slice_.start + and _is_one_or_none(last_slice.step) + and _is_one_or_none(slice_.step) + ): last_slice = slice(last_slice.start, slice_.stop, slice_.step) result[-1] = last_slice else: @@ -142,7 +146,7 @@ def _ensure_1d(group, obj): if group.ndim != 1: # try to stack the dims of the group into a single dim orig_dims = group.dims - stacked_dim = 'stacked_' + '_'.join(orig_dims) + stacked_dim = "stacked_" + "_".join(orig_dims) # these dimensions get created by the stack operation inserted_dims = [dim for dim in group.dims 
if dim not in group.coords] # the copy is necessary here, otherwise read only array raises error @@ -206,8 +210,16 @@ class GroupBy(SupportsArithmetic): DataArray.groupby """ - def __init__(self, obj, group, squeeze=False, grouper=None, bins=None, - restore_coord_dims=None, cut_kwargs={}): + def __init__( + self, + obj, + group, + squeeze=False, + grouper=None, + bins=None, + restore_coord_dims=None, + cut_kwargs={}, + ): """Create a GroupBy object Parameters @@ -239,8 +251,10 @@ def __init__(self, obj, group, squeeze=False, grouper=None, bins=None, if not isinstance(group, (DataArray, IndexVariable)): if not hashable(group): - raise TypeError('`group` must be an xarray.DataArray or the ' - 'name of an xarray variable or dimension') + raise TypeError( + "`group` must be an xarray.DataArray or the " + "name of an xarray variable or dimension" + ) group = obj[group] if len(group) == 0: raise ValueError("{} must not be empty".format(group.name)) @@ -249,23 +263,25 @@ def __init__(self, obj, group, squeeze=False, grouper=None, bins=None, # DummyGroups should not appear on groupby results group = _DummyGroup(obj, group.name, group.coords) - if getattr(group, 'name', None) is None: - raise ValueError('`group` must have a name') + if getattr(group, "name", None) is None: + raise ValueError("`group` must have a name") group, obj, stacked_dim, inserted_dims = _ensure_1d(group, obj) group_dim, = group.dims expected_size = obj.sizes[group_dim] if group.size != expected_size: - raise ValueError('the group variable\'s length does not ' - 'match the length of this variable along its ' - 'dimension') + raise ValueError( + "the group variable's length does not " + "match the length of this variable along its " + "dimension" + ) full_index = None if bins is not None: binned = pd.cut(group.values, bins, **cut_kwargs) - new_dim_name = group.name + '_bins' + new_dim_name = group.name + "_bins" group = DataArray(binned, group.coords, name=new_dim_name) full_index = binned.categories @@ -273,13 +289,12 @@ def __init__(self, obj, group, squeeze=False, grouper=None, bins=None, index = safe_cast_to_index(group) if not index.is_monotonic: # TODO: sort instead of raising an error - raise ValueError('index must be monotonic for resampling') - full_index, first_items = self._get_index_and_items( - index, grouper) + raise ValueError("index must be monotonic for resampling") + full_index, first_items = self._get_index_and_items(index, grouper) sbins = first_items.values.astype(np.int64) - group_indices = ([slice(i, j) - for i, j in zip(sbins[:-1], sbins[1:])] + - [slice(sbins[-1], None)]) + group_indices = [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])] + [ + slice(sbins[-1], None) + ] unique_coord = IndexVariable(group.name, first_items.index) elif group.dims == (group.name,) and _unique_and_monotonic(group): # no need to factorize @@ -292,17 +307,23 @@ def __init__(self, obj, group, squeeze=False, grouper=None, bins=None, else: # look through group to find the unique values unique_values, group_indices = unique_value_groups( - safe_cast_to_index(group), sort=(bins is None)) + safe_cast_to_index(group), sort=(bins is None) + ) unique_coord = IndexVariable(group.name, unique_values) - if isinstance(obj, DataArray) \ - and restore_coord_dims is None \ - and any(obj[c].ndim > 1 for c in obj.coords): - warnings.warn('This DataArray contains multi-dimensional ' - 'coordinates. 
In the future, the dimension order ' - 'of these coordinates will be restored as well ' - 'unless you specify restore_coord_dims=False.', - FutureWarning, stacklevel=2) + if ( + isinstance(obj, DataArray) + and restore_coord_dims is None + and any(obj[c].ndim > 1 for c in obj.coords) + ): + warnings.warn( + "This DataArray contains multi-dimensional " + "coordinates. In the future, the dimension order " + "of these coordinates will be restored as well " + "unless you specify restore_coord_dims=False.", + FutureWarning, + stacklevel=2, + ) restore_coord_dims = False # specification for the groupby operation @@ -323,8 +344,7 @@ def __init__(self, obj, group, squeeze=False, grouper=None, bins=None, def groups(self): # provided to mimic pandas.groupby if self._groups is None: - self._groups = dict(zip(self._unique_coord.values, - self._group_indices)) + self._groups = dict(zip(self._unique_coord.values, self._group_indices)) return self._groups def __len__(self): @@ -335,6 +355,7 @@ def __iter__(self): def _get_index_and_items(self, index, grouper): from .resample_cftime import CFTimeGrouper + s = pd.Series(np.arange(index.size), index) if isinstance(grouper, CFTimeGrouper): first_items = grouper.first_items(index) @@ -371,6 +392,7 @@ def func(self, other): applied = self._yield_binary_applied(g, other) combined = self._combine(applied) return combined + return func def _yield_binary_applied(self, func, other): @@ -380,15 +402,18 @@ def _yield_binary_applied(self, func, other): try: other_sel = other.sel(**{self._group.name: group_value}) except AttributeError: - raise TypeError('GroupBy objects only support binary ops ' - 'when the other argument is a Dataset or ' - 'DataArray') + raise TypeError( + "GroupBy objects only support binary ops " + "when the other argument is a Dataset or " + "DataArray" + ) except (KeyError, ValueError): if self._group.name not in other.dims: - raise ValueError('incompatible dimensions for a grouped ' - 'binary operation: the group variable %r ' - 'is not a dimension on the other argument' - % self._group.name) + raise ValueError( + "incompatible dimensions for a grouped " + "binary operation: the group variable %r " + "is not a dimension on the other argument" % self._group.name + ) if dummy is None: dummy = _dummy_copy(other) other_sel = dummy @@ -400,8 +425,7 @@ def _maybe_restore_empty_groups(self, combined): """Our index contained empty groups (e.g., from a resampling). If we reduced on that dimension, we want to restore the full index. 
""" - if (self._full_index is not None and - self._group.name in combined.dims): + if self._full_index is not None and self._group.name in combined.dims: indexers = {self._group.name: self._full_index} combined = combined.reindex(**indexers) return combined @@ -469,8 +493,9 @@ def _first_or_last(self, op, skipna, keep_attrs): return self._obj if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) - return self.reduce(op, self._group_dim, skipna=skipna, - keep_attrs=keep_attrs, allow_lazy=True) + return self.reduce( + op, self._group_dim, skipna=skipna, keep_attrs=keep_attrs, allow_lazy=True + ) def first(self, skipna=None, keep_attrs=None): """Return the first element of each group along the group dimension @@ -535,8 +560,7 @@ def lookup_order(dimension): return axis new_order = sorted(stacked.dims, key=lookup_order) - return stacked.transpose( - *new_order, transpose_coords=self._restore_coord_dims) + return stacked.transpose(*new_order, transpose_coords=self._restore_coord_dims) def apply(self, func, shortcut=False, args=(), **kwargs): """Apply a function over each array in the group and concatenate them @@ -582,8 +606,7 @@ def apply(self, func, shortcut=False, args=(), **kwargs): grouped = self._iter_grouped_shortcut() else: grouped = self._iter_grouped() - applied = (maybe_wrap_array(arr, func(arr, *args, **kwargs)) - for arr in grouped) + applied = (maybe_wrap_array(arr, func(arr, *args, **kwargs)) for arr in grouped) return self._combine(applied, shortcut=shortcut) def _combine(self, applied, restore_coord_dims=False, shortcut=False): @@ -608,7 +631,7 @@ def _combine(self, applied, restore_coord_dims=False, shortcut=False): combined = self._maybe_unstack(combined) return combined - def quantile(self, q, dim=None, interpolation='linear', keep_attrs=None): + def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None): """Compute the qth quantile over each array in the groups and concatenate them together into a new array. @@ -656,18 +679,26 @@ def quantile(self, q, dim=None, interpolation='linear', keep_attrs=None): "grouped dimension in a future version of xarray. To " "silence this warning, pass dim=xarray.ALL_DIMS " "explicitly.", - FutureWarning, stacklevel=2) - - out = self.apply(self._obj.__class__.quantile, shortcut=False, - q=q, dim=dim, interpolation=interpolation, - keep_attrs=keep_attrs) + FutureWarning, + stacklevel=2, + ) + + out = self.apply( + self._obj.__class__.quantile, + shortcut=False, + q=q, + dim=dim, + interpolation=interpolation, + keep_attrs=keep_attrs, + ) if np.asarray(q, dtype=np.float64).ndim == 0: - out = out.drop('quantile') + out = out.drop("quantile") return out - def reduce(self, func, dim=None, axis=None, keep_attrs=None, - shortcut=True, **kwargs): + def reduce( + self, func, dim=None, axis=None, keep_attrs=None, shortcut=True, **kwargs + ): """Reduce the items in this group by applying `func` along some dimension(s). @@ -706,13 +737,16 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=None, "grouped dimension in a future version of xarray. 
To " "silence this warning, pass dim=xarray.ALL_DIMS " "explicitly.", - FutureWarning, stacklevel=2) + FutureWarning, + stacklevel=2, + ) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) def reduce_array(ar): return ar.reduce(func, dim, axis, keep_attrs=keep_attrs, **kwargs) + return self.apply(reduce_array, shortcut=shortcut) # TODO remove the following class method and DEFAULT_DIMS after the @@ -720,19 +754,42 @@ def reduce_array(ar): @classmethod def _reduce_method(cls, func, include_skipna, numeric_only): if include_skipna: - def wrapped_func(self, dim=DEFAULT_DIMS, axis=None, skipna=None, - keep_attrs=None, **kwargs): - return self.reduce(func, dim, axis, keep_attrs=keep_attrs, - skipna=skipna, allow_lazy=True, **kwargs) + + def wrapped_func( + self, + dim=DEFAULT_DIMS, + axis=None, + skipna=None, + keep_attrs=None, + **kwargs + ): + return self.reduce( + func, + dim, + axis, + keep_attrs=keep_attrs, + skipna=skipna, + allow_lazy=True, + **kwargs + ) + else: - def wrapped_func(self, dim=DEFAULT_DIMS, axis=None, # type: ignore - keep_attrs=None, **kwargs): - return self.reduce(func, dim, axis, keep_attrs=keep_attrs, - allow_lazy=True, **kwargs) + + def wrapped_func( + self, + dim=DEFAULT_DIMS, + axis=None, # type: ignore + keep_attrs=None, + **kwargs + ): + return self.reduce( + func, dim, axis, keep_attrs=keep_attrs, allow_lazy=True, **kwargs + ) + return wrapped_func -DEFAULT_DIMS = utils.ReprObject('') +DEFAULT_DIMS = utils.ReprObject("") ops.inject_reduce_methods(DataArrayGroupBy) ops.inject_binary_ops(DataArrayGroupBy) @@ -823,7 +880,9 @@ def reduce(self, func, dim=None, keep_attrs=None, **kwargs): "grouped dimension in a future version of xarray. To " "silence this warning, pass dim=xarray.ALL_DIMS " "explicitly.", - FutureWarning, stacklevel=2) + FutureWarning, + stacklevel=2, + ) elif dim is None: dim = self._group_dim @@ -832,6 +891,7 @@ def reduce(self, func, dim=None, keep_attrs=None, **kwargs): def reduce_dataset(ds): return ds.reduce(func, dim, keep_attrs, **kwargs) + return self.apply(reduce_dataset) # TODO remove the following class method and DEFAULT_DIMS after the @@ -839,17 +899,28 @@ def reduce_dataset(ds): @classmethod def _reduce_method(cls, func, include_skipna, numeric_only): if include_skipna: - def wrapped_func(self, dim=DEFAULT_DIMS, - skipna=None, **kwargs): - return self.reduce(func, dim, - skipna=skipna, numeric_only=numeric_only, - allow_lazy=True, **kwargs) + + def wrapped_func(self, dim=DEFAULT_DIMS, skipna=None, **kwargs): + return self.reduce( + func, + dim, + skipna=skipna, + numeric_only=numeric_only, + allow_lazy=True, + **kwargs + ) + else: - def wrapped_func(self, dim=DEFAULT_DIMS, # type: ignore - **kwargs): - return self.reduce(func, dim, - numeric_only=numeric_only, allow_lazy=True, - **kwargs) + + def wrapped_func( + self, + dim=DEFAULT_DIMS, # type: ignore + **kwargs + ): + return self.reduce( + func, dim, numeric_only=numeric_only, allow_lazy=True, **kwargs + ) + return wrapped_func def assign(self, **kwargs): diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 253dbd23164..5917f7c7a2d 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -10,6 +10,7 @@ class Indexes(collections.abc.Mapping): """Immutable proxy for Dataset or DataArrary indexes.""" + def __init__(self, indexes): """Not for public consumption. 
@@ -37,9 +38,8 @@ def __repr__(self): def default_indexes( - coords: Mapping[Any, Variable], - dims: Iterable, -) -> 'OrderedDict[Any, pd.Index]': + coords: Mapping[Any, Variable], dims: Iterable +) -> "OrderedDict[Any, pd.Index]": """Default indexes for a Dataset/DataArray. Parameters @@ -54,8 +54,7 @@ def default_indexes( Mapping from indexing keys (levels/dimension names) to indexes used for indexing along that dimension. """ - return OrderedDict((key, coords[key].to_index()) - for key in dims if key in coords) + return OrderedDict((key, coords[key].to_index()) for key in dims if key in coords) def isel_variable_and_index( @@ -71,8 +70,8 @@ def isel_variable_and_index( if len(variable.dims) > 1: raise NotImplementedError( - 'indexing multi-dimensional variable with indexes is not ' - 'supported yet') + "indexing multi-dimensional variable with indexes is not " "supported yet" + ) new_variable = variable.isel(indexers) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index a9ad55e2652..a6edb0a9562 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -39,7 +39,7 @@ def expanded_indexer(key, ndim): else: new_key.append(k) if len(new_key) > ndim: - raise IndexError('too many indices') + raise IndexError("too many indices") new_key.extend((ndim - len(new_key)) * [slice(None)]) return tuple(new_key) @@ -57,8 +57,10 @@ def _sanitize_slice_element(x): if isinstance(x, np.ndarray): if x.ndim != 0: - raise ValueError('cannot use non-scalar arrays in a slice for ' - 'xarray indexing: {}'.format(x)) + raise ValueError( + "cannot use non-scalar arrays in a slice for " + "xarray indexing: {}".format(x) + ) x = x[()] if isinstance(x, np.timedelta64): @@ -88,9 +90,9 @@ def _asarray_tuplesafe(values): def _is_nested_tuple(possible_tuple): - return (isinstance(possible_tuple, tuple) and - any(isinstance(value, (tuple, list, slice)) - for value in possible_tuple)) + return isinstance(possible_tuple, tuple) and any( + isinstance(value, (tuple, list, slice)) for value in possible_tuple + ) def _index_method_kwargs(method, tolerance): @@ -98,9 +100,9 @@ def _index_method_kwargs(method, tolerance): # (tolerance) kwargs = {} if method is not None: - kwargs['method'] = method + kwargs["method"] = method if tolerance is not None: - kwargs['tolerance'] = tolerance + kwargs["tolerance"] = tolerance return kwargs @@ -119,8 +121,7 @@ def get_indexer_nd(index, labels, method=None, tolerance=None): return indexer -def convert_label_indexer(index, label, index_name='', method=None, - tolerance=None): +def convert_label_indexer(index, label, index_name="", method=None, tolerance=None): """Given a pandas.Index and labels (e.g., from __getitem__) for one dimension, return an indexer suitable for indexing an ndarray along that dimension. 
If `index` is a pandas.MultiIndex and depending on `label`, @@ -131,37 +132,46 @@ def convert_label_indexer(index, label, index_name='', method=None, if isinstance(label, slice): if method is not None or tolerance is not None: raise NotImplementedError( - 'cannot use ``method`` argument if any indexers are ' - 'slice objects') - indexer = index.slice_indexer(_sanitize_slice_element(label.start), - _sanitize_slice_element(label.stop), - _sanitize_slice_element(label.step)) + "cannot use ``method`` argument if any indexers are " "slice objects" + ) + indexer = index.slice_indexer( + _sanitize_slice_element(label.start), + _sanitize_slice_element(label.stop), + _sanitize_slice_element(label.step), + ) if not isinstance(indexer, slice): # unlike pandas, in xarray we never want to silently convert a # slice indexer into an array indexer - raise KeyError('cannot represent labeled-based slice indexer for ' - 'dimension %r with a slice over integer positions; ' - 'the index is unsorted or non-unique' % index_name) + raise KeyError( + "cannot represent labeled-based slice indexer for " + "dimension %r with a slice over integer positions; " + "the index is unsorted or non-unique" % index_name + ) elif is_dict_like(label): is_nested_vals = _is_nested_tuple(tuple(label.values())) if not isinstance(index, pd.MultiIndex): - raise ValueError('cannot use a dict-like object for selection on ' - 'a dimension that does not have a MultiIndex') + raise ValueError( + "cannot use a dict-like object for selection on " + "a dimension that does not have a MultiIndex" + ) elif len(label) == index.nlevels and not is_nested_vals: indexer = index.get_loc(tuple(label[k] for k in index.names)) else: for k, v in label.items(): # index should be an item (i.e. Hashable) not an array-like if isinstance(v, Sequence) and not isinstance(v, str): - raise ValueError('Vectorized selection is not ' - 'available along level variable: ' + k) + raise ValueError( + "Vectorized selection is not " + "available along level variable: " + k + ) indexer, new_index = index.get_loc_level( - tuple(label.values()), level=tuple(label.keys())) + tuple(label.values()), level=tuple(label.keys()) + ) # GH2619. 
Raise a KeyError if nothing is chosen - if indexer.dtype.kind == 'b' and indexer.sum() == 0: - raise KeyError('{} not found'.format(label)) + if indexer.dtype.kind == "b" and indexer.sum() == 0: + raise KeyError("{} not found".format(label)) elif isinstance(label, tuple) and isinstance(index, pd.MultiIndex): if _is_nested_tuple(label): @@ -173,23 +183,27 @@ def convert_label_indexer(index, label, index_name='', method=None, label, level=list(range(len(label))) ) else: - label = (label if getattr(label, 'ndim', 1) > 1 # vectorized-indexing - else _asarray_tuplesafe(label)) + label = ( + label + if getattr(label, "ndim", 1) > 1 # vectorized-indexing + else _asarray_tuplesafe(label) + ) if label.ndim == 0: if isinstance(index, pd.MultiIndex): indexer, new_index = index.get_loc_level(label.item(), level=0) else: indexer = get_loc(index, label.item(), method, tolerance) - elif label.dtype.kind == 'b': + elif label.dtype.kind == "b": indexer = label else: if isinstance(index, pd.MultiIndex) and label.ndim > 1: - raise ValueError('Vectorized selection is not available along ' - 'MultiIndex variable: ' + index_name) + raise ValueError( + "Vectorized selection is not available along " + "MultiIndex variable: " + index_name + ) indexer = get_indexer_nd(index, label, method, tolerance) if np.any(indexer < 0): - raise KeyError('not all values found in index %r' - % index_name) + raise KeyError("not all values found in index %r" % index_name) return indexer, new_index @@ -201,11 +215,13 @@ def get_dim_indexers(data_obj, indexers): into a single, dictionary indexer for that dimension (Raise a ValueError if it is not possible). """ - invalid = [k for k in indexers - if k not in data_obj.dims and k not in data_obj._level_coords] + invalid = [ + k + for k in indexers + if k not in data_obj.dims and k not in data_obj._level_coords + ] if invalid: - raise ValueError("dimensions or multi-index levels %r do not exist" - % invalid) + raise ValueError("dimensions or multi-index levels %r do not exist" % invalid) level_indexers = defaultdict(dict) dim_indexers = {} @@ -219,8 +235,10 @@ def get_dim_indexers(data_obj, indexers): for dim, level_labels in level_indexers.items(): if dim_indexers.get(dim, False): - raise ValueError("cannot combine multi-index level indexers " - "with an indexer for dimension %s" % dim) + raise ValueError( + "cannot combine multi-index level indexers " + "with an indexer for dimension %s" % dim + ) dim_indexers[dim] = level_labels return dim_indexers @@ -232,7 +250,7 @@ def remap_label_indexers(data_obj, indexers, method=None, tolerance=None): pandas index objects (in case of multi-index level drop). """ if method is not None and not isinstance(method, str): - raise TypeError('``method`` must be a string') + raise TypeError("``method`` must be a string") pos_indexers = {} new_indexes = {} @@ -244,13 +262,14 @@ def remap_label_indexers(data_obj, indexers, method=None, tolerance=None): except KeyError: # no index for this dimension: reuse the provided labels if method is not None or tolerance is not None: - raise ValueError('cannot supply ``method`` or ``tolerance`` ' - 'when the indexed dimension does not have ' - 'an associated coordinate.') + raise ValueError( + "cannot supply ``method`` or ``tolerance`` " + "when the indexed dimension does not have " + "an associated coordinate." 
+ ) pos_indexers[dim] = label else: - idxr, new_idx = convert_label_indexer(index, label, - dim, method, tolerance) + idxr, new_idx = convert_label_indexer(index, label, dim, method, tolerance) pos_indexers[dim] = idxr if new_idx is not None: new_indexes[dim] = new_idx @@ -308,7 +327,7 @@ class ExplicitIndexer: def __init__(self, key): if type(self) is ExplicitIndexer: # noqa - raise TypeError('cannot instantiate base ExplicitIndexer objects') + raise TypeError("cannot instantiate base ExplicitIndexer objects") self._key = tuple(key) @property @@ -316,7 +335,7 @@ def tuple(self): return self._key def __repr__(self): - return '{}({})'.format(type(self).__name__, self.tuple) + return "{}({})".format(type(self).__name__, self.tuple) def as_integer_or_none(value): @@ -340,7 +359,7 @@ class BasicIndexer(ExplicitIndexer): def __init__(self, key): if not isinstance(key, tuple): - raise TypeError('key must be a tuple: {!r}'.format(key)) + raise TypeError("key must be a tuple: {!r}".format(key)) new_key = [] for k in key: @@ -349,8 +368,11 @@ def __init__(self, key): elif isinstance(k, slice): k = as_integer_slice(k) else: - raise TypeError('unexpected indexer type for {}: {!r}' - .format(type(self).__name__, k)) + raise TypeError( + "unexpected indexer type for {}: {!r}".format( + type(self).__name__, k + ) + ) new_key.append(k) super().__init__(new_key) @@ -367,7 +389,7 @@ class OuterIndexer(ExplicitIndexer): def __init__(self, key): if not isinstance(key, tuple): - raise TypeError('key must be a tuple: {!r}'.format(key)) + raise TypeError("key must be a tuple: {!r}".format(key)) new_key = [] for k in key: @@ -377,16 +399,22 @@ def __init__(self, key): k = as_integer_slice(k) elif isinstance(k, np.ndarray): if not np.issubdtype(k.dtype, np.integer): - raise TypeError('invalid indexer array, does not have ' - 'integer dtype: {!r}'.format(k)) + raise TypeError( + "invalid indexer array, does not have " + "integer dtype: {!r}".format(k) + ) if k.ndim != 1: - raise TypeError('invalid indexer array for {}, must have ' - 'exactly 1 dimension: ' - .format(type(self).__name__, k)) + raise TypeError( + "invalid indexer array for {}, must have " + "exactly 1 dimension: ".format(type(self).__name__, k) + ) k = np.asarray(k, dtype=np.int64) else: - raise TypeError('unexpected indexer type for {}: {!r}' - .format(type(self).__name__, k)) + raise TypeError( + "unexpected indexer type for {}: {!r}".format( + type(self).__name__, k + ) + ) new_key.append(k) super().__init__(new_key) @@ -404,7 +432,7 @@ class VectorizedIndexer(ExplicitIndexer): def __init__(self, key): if not isinstance(key, tuple): - raise TypeError('key must be a tuple: {!r}'.format(key)) + raise TypeError("key must be a tuple: {!r}".format(key)) new_key = [] ndim = None @@ -413,19 +441,25 @@ def __init__(self, key): k = as_integer_slice(k) elif isinstance(k, np.ndarray): if not np.issubdtype(k.dtype, np.integer): - raise TypeError('invalid indexer array, does not have ' - 'integer dtype: {!r}'.format(k)) + raise TypeError( + "invalid indexer array, does not have " + "integer dtype: {!r}".format(k) + ) if ndim is None: ndim = k.ndim elif ndim != k.ndim: ndims = [k.ndim for k in key if isinstance(k, np.ndarray)] - raise ValueError('invalid indexer key: ndarray arguments ' - 'have different numbers of dimensions: {}' - .format(ndims)) + raise ValueError( + "invalid indexer key: ndarray arguments " + "have different numbers of dimensions: {}".format(ndims) + ) k = np.asarray(k, dtype=np.int64) else: - raise TypeError('unexpected indexer type for 
{}: {!r}' - .format(type(self).__name__, k)) + raise TypeError( + "unexpected indexer type for {}: {!r}".format( + type(self).__name__, k + ) + ) new_key.append(k) super().__init__(new_key) @@ -436,7 +470,6 @@ class ExplicitlyIndexed: class ExplicitlyIndexedNDArrayMixin(utils.NDArrayMixin, ExplicitlyIndexed): - def __array__(self, dtype=None): key = BasicIndexer((slice(None),) * self.ndim) return np.asarray(self[key], dtype=dtype) @@ -498,7 +531,7 @@ def _updated_key(self, new_key): full_key.append(_index_indexer_1d(k, next(iter_new_key), size)) full_key = tuple(full_key) - if all(isinstance(k, integer_types + (slice, )) for k in full_key): + if all(isinstance(k, integer_types + (slice,)) for k in full_key): return BasicIndexer(full_key) return OuterIndexer(full_key) @@ -517,8 +550,7 @@ def __array__(self, dtype=None): return np.asarray(array[self.key], dtype=None) def transpose(self, order): - return LazilyVectorizedIndexedArray( - self.array, self.key).transpose(order) + return LazilyVectorizedIndexedArray(self.array, self.key).transpose(order) def __getitem__(self, indexer): if isinstance(indexer, VectorizedIndexer): @@ -529,14 +561,14 @@ def __getitem__(self, indexer): def __setitem__(self, key, value): if isinstance(key, VectorizedIndexer): raise NotImplementedError( - 'Lazy item assignment with the vectorized indexer is not yet ' - 'implemented. Load your data first by .load() or compute().') + "Lazy item assignment with the vectorized indexer is not yet " + "implemented. Load your data first by .load() or compute()." + ) full_key = self._updated_key(key) self.array[full_key] = value def __repr__(self): - return ('%s(array=%r, key=%r)' % - (type(self).__name__, self.array, self.key)) + return "%s(array=%r, key=%r)" % (type(self).__name__, self.array, self.key) class LazilyVectorizedIndexedArray(ExplicitlyIndexedNDArrayMixin): @@ -575,18 +607,17 @@ def __getitem__(self, indexer): return type(self)(self.array, self._updated_key(indexer)) def transpose(self, order): - key = VectorizedIndexer(tuple( - k.transpose(order) for k in self.key.tuple)) + key = VectorizedIndexer(tuple(k.transpose(order) for k in self.key.tuple)) return type(self)(self.array, key) def __setitem__(self, key, value): raise NotImplementedError( - 'Lazy item assignment with the vectorized indexer is not yet ' - 'implemented. Load your data first by .load() or compute().') + "Lazy item assignment with the vectorized indexer is not yet " + "implemented. Load your data first by .load() or compute()." 
+ ) def __repr__(self): - return ('%s(array=%r, key=%r)' % - (type(self).__name__, self.array, self.key)) + return "%s(array=%r, key=%r)" % (type(self).__name__, self.array, self.key) def _wrap_numpy_scalars(array): @@ -657,10 +688,10 @@ def as_indexable(array): return PandasIndexAdapter(array) if isinstance(array, dask_array_type): return DaskIndexingAdapter(array) - if hasattr(array, '__array_function__'): + if hasattr(array, "__array_function__"): return NdArrayLikeIndexingAdapter(array) - raise TypeError('Invalid array type: {}'.format(type(array))) + raise TypeError("Invalid array type: {}".format(type(array))) def _outer_to_vectorized_indexer(key, shape): @@ -691,9 +722,8 @@ def _outer_to_vectorized_indexer(key, shape): else: # np.ndarray or slice if isinstance(k, slice): k = np.arange(*k.indices(size)) - assert k.dtype.kind in {'i', 'u'} - shape = [(1,) * i_dim + (k.size, ) + - (1,) * (n_dim - i_dim - 1)] + assert k.dtype.kind in {"i", "u"} + shape = [(1,) * i_dim + (k.size,) + (1,) * (n_dim - i_dim - 1)] new_key.append(k.reshape(*shape)) i_dim += 1 return VectorizedIndexer(tuple(new_key)) @@ -746,23 +776,23 @@ def _combine_indexers(old_key, shape, new_key): else: new_key = _outer_to_vectorized_indexer(new_key, new_shape) - return VectorizedIndexer(tuple(o[new_key.tuple] for o in - np.broadcast_arrays(*old_key.tuple))) + return VectorizedIndexer( + tuple(o[new_key.tuple] for o in np.broadcast_arrays(*old_key.tuple)) + ) class IndexingSupport: # could inherit from enum.Enum on Python 3 # for backends that support only basic indexer - BASIC = 'BASIC' + BASIC = "BASIC" # for backends that support basic / outer indexer - OUTER = 'OUTER' + OUTER = "OUTER" # for backends that support outer indexer including at most 1 vector. - OUTER_1VECTOR = 'OUTER_1VECTOR' + OUTER_1VECTOR = "OUTER_1VECTOR" # for backends that support full vectorized indexer. - VECTORIZED = 'VECTORIZED' + VECTORIZED = "VECTORIZED" -def explicit_indexing_adapter( - key, shape, indexing_support, raw_indexing_method): +def explicit_indexing_adapter(key, shape, indexing_support, raw_indexing_method): """Support explicit indexing by delegating to a raw indexing method. Outer and/or vectorized indexers are supported by indexing a second time @@ -797,7 +827,7 @@ def decompose_indexer(indexer, shape, indexing_support): return _decompose_vectorized_indexer(indexer, shape, indexing_support) if isinstance(indexer, (BasicIndexer, OuterIndexer)): return _decompose_outer_indexer(indexer, shape, indexing_support) - raise TypeError('unexpected key type: {}'.format(indexer)) + raise TypeError("unexpected key type: {}".format(indexer)) def _decompose_slice(key, size): @@ -855,8 +885,10 @@ def _decompose_vectorized_indexer(indexer, shape, indexing_support): backend_indexer = [] np_indexer = [] # convert negative indices - indexer = [np.where(k < 0, k + s, k) if isinstance(k, np.ndarray) else k - for k, s in zip(indexer.tuple, shape)] + indexer = [ + np.where(k < 0, k + s, k) if isinstance(k, np.ndarray) else k + for k, s in zip(indexer.tuple, shape) + ] for k, s in zip(indexer, shape): if isinstance(k, slice): @@ -882,7 +914,8 @@ def _decompose_vectorized_indexer(indexer, shape, indexing_support): # If the backend does not support outer indexing, # backend_indexer (OuterIndexer) is also decomposed. 
backend_indexer, np_indexer1 = _decompose_outer_indexer( - backend_indexer, shape, indexing_support) + backend_indexer, shape, indexing_support + ) np_indexer = _combine_indexers(np_indexer1, shape, np_indexer) return backend_indexer, np_indexer @@ -938,8 +971,12 @@ def _decompose_outer_indexer(indexer, shape, indexing_support): if indexing_support is IndexingSupport.OUTER_1VECTOR: # some backends such as h5py supports only 1 vector in indexers # We choose the most efficient axis - gains = [(np.max(k) - np.min(k) + 1.0) / len(np.unique(k)) - if isinstance(k, np.ndarray) else 0 for k in indexer] + gains = [ + (np.max(k) - np.min(k) + 1.0) / len(np.unique(k)) + if isinstance(k, np.ndarray) + else 0 + for k in indexer + ] array_index = np.argmax(np.array(gains)) if len(gains) > 0 else None for i, (k, s) in enumerate(zip(indexer, shape)): @@ -960,8 +997,7 @@ def _decompose_outer_indexer(indexer, shape, indexing_support): backend_indexer.append(bk_slice) np_indexer.append(np_slice) - return (OuterIndexer(tuple(backend_indexer)), - OuterIndexer(tuple(np_indexer))) + return (OuterIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer))) if indexing_support == IndexingSupport.OUTER: for k, s in zip(indexer, shape): @@ -981,8 +1017,7 @@ def _decompose_outer_indexer(indexer, shape, indexing_support): backend_indexer.append(oind) np_indexer.append(vind.reshape(*k.shape)) - return (OuterIndexer(tuple(backend_indexer)), - OuterIndexer(tuple(np_indexer))) + return (OuterIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer))) # basic indexer assert indexing_support == IndexingSupport.BASIC @@ -1000,8 +1035,7 @@ def _decompose_outer_indexer(indexer, shape, indexing_support): backend_indexer.append(bk_slice) np_indexer.append(np_slice) - return (BasicIndexer(tuple(backend_indexer)), - OuterIndexer(tuple(np_indexer))) + return (BasicIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer))) def _arrayize_vectorized_indexer(indexer, shape): @@ -1016,10 +1050,9 @@ def _arrayize_vectorized_indexer(indexer, shape): new_key = [] for v, size in zip(indexer.tuple, shape): if isinstance(v, np.ndarray): - new_key.append(np.reshape(v, v.shape + (1, ) * len(slices))) + new_key.append(np.reshape(v, v.shape + (1,) * len(slices))) else: # slice - shape = ((1,) * (n_dim + i_dim) + (-1,) + - (1,) * (len(slices) - i_dim - 1)) + shape = (1,) * (n_dim + i_dim) + (-1,) + (1,) * (len(slices) - i_dim - 1) new_key.append(np.arange(*v.indices(size)).reshape(shape)) i_dim += 1 return VectorizedIndexer(tuple(new_key)) @@ -1028,8 +1061,9 @@ def _arrayize_vectorized_indexer(indexer, shape): def _dask_array_with_chunks_hint(array, chunks): """Create a dask array using the chunks hint for dimensions of size > 1.""" import dask.array as da + if len(chunks) < array.ndim: - raise ValueError('not enough chunks in hint') + raise ValueError("not enough chunks in hint") new_chunks = [] for chunk, size in zip(chunks, array.shape): new_chunks.append(chunk if size > 1 else (1,)) @@ -1043,9 +1077,12 @@ def _logical_any(args): def _masked_result_drop_slice(key, chunks_hint=None): key = (k for k in key if not isinstance(k, slice)) if chunks_hint is not None: - key = [_dask_array_with_chunks_hint(k, chunks_hint) - if isinstance(k, np.ndarray) else k - for k in key] + key = [ + _dask_array_with_chunks_hint(k, chunks_hint) + if isinstance(k, np.ndarray) + else k + for k in key + ] return _logical_any(k == -1 for k in key) @@ -1078,19 +1115,19 @@ def create_mask(indexer, shape, chunks_hint=None): elif isinstance(indexer, 
VectorizedIndexer): key = indexer.tuple base_mask = _masked_result_drop_slice(key, chunks_hint) - slice_shape = tuple(np.arange(*k.indices(size)).size - for k, size in zip(key, shape) - if isinstance(k, slice)) - expanded_mask = base_mask[ - (Ellipsis,) + (np.newaxis,) * len(slice_shape)] - mask = duck_array_ops.broadcast_to( - expanded_mask, base_mask.shape + slice_shape) + slice_shape = tuple( + np.arange(*k.indices(size)).size + for k, size in zip(key, shape) + if isinstance(k, slice) + ) + expanded_mask = base_mask[(Ellipsis,) + (np.newaxis,) * len(slice_shape)] + mask = duck_array_ops.broadcast_to(expanded_mask, base_mask.shape + slice_shape) elif isinstance(indexer, BasicIndexer): mask = any(k == -1 for k in indexer.tuple) else: - raise TypeError('unexpected key type: {}'.format(type(indexer))) + raise TypeError("unexpected key type: {}".format(type(indexer))) return mask @@ -1138,9 +1175,12 @@ def posify_mask_indexer(indexer): Same type of input, with all values in ndarray keys equal to -1 replaced by an adjacent non-masked element. """ - key = tuple(_posify_mask_subindexer(k.ravel()).reshape(k.shape) - if isinstance(k, np.ndarray) else k - for k in indexer.tuple) + key = tuple( + _posify_mask_subindexer(k.ravel()).reshape(k.shape) + if isinstance(k, np.ndarray) + else k + for k in indexer.tuple + ) return type(indexer)(key) @@ -1150,8 +1190,10 @@ class NumpyIndexingAdapter(ExplicitlyIndexedNDArrayMixin): def __init__(self, array): # In NumpyIndexingAdapter we only allow to store bare np.ndarray if not isinstance(array, np.ndarray): - raise TypeError('NumpyIndexingAdapter only wraps np.ndarray. ' - 'Trying to wrap {}'.format(type(array))) + raise TypeError( + "NumpyIndexingAdapter only wraps np.ndarray. " + "Trying to wrap {}".format(type(array)) + ) self.array = array def _indexing_array_and_key(self, key): @@ -1168,7 +1210,7 @@ def _indexing_array_and_key(self, key): # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#detailed-notes). # noqa key = key.tuple + (Ellipsis,) else: - raise TypeError('unexpected key type: {}'.format(type(key))) + raise TypeError("unexpected key type: {}".format(type(key))) return array, key @@ -1186,18 +1228,20 @@ def __setitem__(self, key, value): except ValueError: # More informative exception if read-only view if not array.flags.writeable and not array.flags.owndata: - raise ValueError("Assignment destination is a view. " - "Do you want to .copy() array first?") + raise ValueError( + "Assignment destination is a view. " + "Do you want to .copy() array first?" + ) else: raise class NdArrayLikeIndexingAdapter(NumpyIndexingAdapter): def __init__(self, array): - if not hasattr(array, '__array_function__'): + if not hasattr(array, "__array_function__"): raise TypeError( - 'NdArrayLikeIndexingAdapter must wrap an object that ' - 'implements the __array_function__ protocol' + "NdArrayLikeIndexingAdapter must wrap an object that " + "implements the __array_function__ protocol" ) self.array = array @@ -1230,11 +1274,13 @@ def __getitem__(self, key): return value def __setitem__(self, key, value): - raise TypeError("this variable's data is stored in a dask array, " - 'which does not support item assignment. To ' - 'assign to this variable, you must first load it ' - 'into memory explicitly using the .load() ' - 'method or accessing its .values attribute.') + raise TypeError( + "this variable's data is stored in a dask array, " + "which does not support item assignment. 
To " + "assign to this variable, you must first load it " + "into memory explicitly using the .load() " + "method or accessing its .values attribute." + ) def transpose(self, order): return self.array.transpose(order) @@ -1248,12 +1294,12 @@ def __init__(self, array: Any, dtype: DTypeLike = None): self.array = utils.safe_cast_to_index(array) if dtype is None: if isinstance(array, pd.PeriodIndex): - dtype = np.dtype('O') - elif hasattr(array, 'categories'): + dtype = np.dtype("O") + elif hasattr(array, "categories"): # category isn't a real numpy dtype dtype = array.categories.dtype elif not utils.is_valid_numpy_dtype(array.dtype): - dtype = np.dtype('O') + dtype = np.dtype("O") else: dtype = array.dtype else: @@ -1271,7 +1317,7 @@ def __array__(self, dtype: DTypeLike = None) -> np.ndarray: if isinstance(array, pd.PeriodIndex): with suppress(AttributeError): # this might not be public API - array = array.astype('object') + array = array.astype("object") return np.asarray(array.values, dtype=dtype) @property @@ -1280,20 +1326,15 @@ def shape(self) -> Tuple[int]: return (len(self.array),) def __getitem__( - self, indexer - ) -> Union[ - NumpyIndexingAdapter, - np.ndarray, - np.datetime64, - np.timedelta64, - ]: + self, indexer + ) -> Union[NumpyIndexingAdapter, np.ndarray, np.datetime64, np.timedelta64]: key = indexer.tuple if isinstance(key, tuple) and len(key) == 1: # unpack key so it can index a pandas.Index object (pandas.Index # objects don't like tuples) key, = key - if getattr(key, 'ndim', 0) > 1: # Return np-array if multidimensional + if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional return NumpyIndexingAdapter(self.array.values)[indexer] result = self.array[key] @@ -1307,9 +1348,9 @@ def __getitem__( # note: it probably would be better in general to return # pd.Timestamp rather np.than datetime64 but this is easier # (for now) - result = np.datetime64('NaT', 'ns') + result = np.datetime64("NaT", "ns") elif isinstance(result, timedelta): - result = np.timedelta64(getattr(result, 'value', result), 'ns') + result = np.timedelta64(getattr(result, "value", result), "ns") elif isinstance(result, pd.Timestamp): # Work around for GH: pydata/xarray#1932 and numpy/numpy#10668 # numpy fails to convert pd.Timestamp to np.datetime64[ns] @@ -1327,10 +1368,9 @@ def transpose(self, order) -> pd.Index: return self.array # self.array should be always one-dimensional def __repr__(self) -> str: - return ('%s(array=%r, dtype=%r)' - % (type(self).__name__, self.array, self.dtype)) + return "%s(array=%r, dtype=%r)" % (type(self).__name__, self.array, self.dtype) - def copy(self, deep: bool = True) -> 'PandasIndexAdapter': + def copy(self, deep: bool = True) -> "PandasIndexAdapter": # Not the same as just writing `self.array.copy(deep=deep)`, as # shallow copies of the underlying numpy.ndarrays become deep ones # upon pickling diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 289b70ed518..b8d9e1a795c 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -11,7 +11,7 @@ Set, Tuple, Union, - TYPE_CHECKING + TYPE_CHECKING, ) import pandas as pd @@ -19,41 +19,36 @@ from . 
import dtypes, pdcompat from .alignment import deep_align from .utils import Frozen -from .variable import ( - Variable, as_variable, assert_unique_multiindex_level_names) +from .variable import Variable, as_variable, assert_unique_multiindex_level_names if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset DatasetLikeValue = Union[ - DataArray, - Variable, - Tuple[Hashable, Any], - Tuple[Sequence[Hashable], Any], + DataArray, Variable, Tuple[Hashable, Any], Tuple[Sequence[Hashable], Any] ] DatasetLike = Union[Dataset, Mapping[Hashable, DatasetLikeValue]] """Any object type that can be used on the rhs of Dataset.update, Dataset.merge, etc. """ - MutableDatasetLike = Union[ - Dataset, - MutableMapping[Hashable, DatasetLikeValue], - ] + MutableDatasetLike = Union[Dataset, MutableMapping[Hashable, DatasetLikeValue]] PANDAS_TYPES = (pd.Series, pd.DataFrame, pdcompat.Panel) -_VALID_COMPAT = Frozen({'identical': 0, - 'equals': 1, - 'broadcast_equals': 2, - 'minimal': 3, - 'no_conflicts': 4}) +_VALID_COMPAT = Frozen( + { + "identical": 0, + "equals": 1, + "broadcast_equals": 2, + "minimal": 3, + "no_conflicts": 4, + } +) -def broadcast_dimension_size( - variables: List[Variable], -) -> 'OrderedDict[Any, int]': +def broadcast_dimension_size(variables: List[Variable],) -> "OrderedDict[Any, int]": """Extract dimension sizes from a dictionary of variables. Raises ValueError if any dimensions have different sizes. @@ -62,7 +57,7 @@ def broadcast_dimension_size( for var in variables: for dim, size in zip(var.dims, var.shape): if dim in dims and size != dims[dim]: - raise ValueError('index %r not aligned' % dim) + raise ValueError("index %r not aligned" % dim) dims[dim] = size return dims @@ -70,11 +65,12 @@ def broadcast_dimension_size( class MergeError(ValueError): """Error class for merge failures due to incompatible arguments. """ + # inherits from ValueError for backward compatibility # TODO: move this to an xarray.exceptions module? -def unique_variable(name, variables, compat='broadcast_equals'): +def unique_variable(name, variables, compat="broadcast_equals"): # type: (Any, List[Variable], str) -> Variable """Return the unique variable from a list of variables or raise MergeError. 
@@ -100,22 +96,23 @@ def unique_variable(name, variables, compat='broadcast_equals'): if len(variables) > 1: combine_method = None - if compat == 'minimal': - compat = 'broadcast_equals' + if compat == "minimal": + compat = "broadcast_equals" - if compat == 'broadcast_equals': + if compat == "broadcast_equals": dim_lengths = broadcast_dimension_size(variables) out = out.set_dims(dim_lengths) - if compat == 'no_conflicts': - combine_method = 'fillna' + if compat == "no_conflicts": + combine_method = "fillna" for var in variables[1:]: if not getattr(out, compat)(var): - raise MergeError('conflicting values for variable %r on ' - 'objects to be combined:\n' - 'first value: %r\nsecond value: %r' - % (name, out, var)) + raise MergeError( + "conflicting values for variable %r on " + "objects to be combined:\n" + "first value: %r\nsecond value: %r" % (name, out, var) + ) if combine_method: # TODO: add preservation of attrs into fillna out = getattr(out, combine_method)(var) @@ -126,8 +123,7 @@ def unique_variable(name, variables, compat='broadcast_equals'): def _assert_compat_valid(compat): if compat not in _VALID_COMPAT: - raise ValueError("compat=%r invalid: must be %s" - % (compat, set(_VALID_COMPAT))) + raise ValueError("compat=%r invalid: must be %s" % (compat, set(_VALID_COMPAT))) class OrderedDefaultDict(OrderedDict): @@ -143,10 +139,10 @@ def __missing__(self, key): def merge_variables( - list_of_variables_dicts: List[Mapping[Any, Variable]], - priority_vars: Mapping[Any, Variable] = None, - compat: str = 'minimal', -) -> 'OrderedDict[Any, Variable]': + list_of_variables_dicts: List[Mapping[Any, Variable]], + priority_vars: Mapping[Any, Variable] = None, + compat: str = "minimal", +) -> "OrderedDict[Any, Variable]": """Merge dicts of variables, while resolving conflicts appropriately. Parameters @@ -169,7 +165,7 @@ def merge_variables( priority_vars = {} _assert_compat_valid(compat) - dim_compat = min(compat, 'equals', key=_VALID_COMPAT.get) + dim_compat = min(compat, "equals", key=_VALID_COMPAT.get) lookup = OrderedDefaultDict(list) for variables in list_of_variables_dicts: @@ -196,7 +192,7 @@ def merge_variables( try: merged[name] = unique_variable(name, var_list, compat) except MergeError: - if compat != 'minimal': + if compat != "minimal": # we need more than "minimal" compatibility (for which # we drop conflicting coordinates) raise @@ -205,8 +201,8 @@ def merge_variables( def expand_variable_dicts( - list_of_variable_dicts: 'List[Union[Dataset, OrderedDict]]', -) -> 'List[Mapping[Any, Variable]]': + list_of_variable_dicts: "List[Union[Dataset, OrderedDict]]", +) -> "List[Mapping[Any, Variable]]": """Given a list of dicts with xarray object values, expand the values. Parameters @@ -255,7 +251,7 @@ def expand_variable_dicts( def determine_coords( - list_of_variable_dicts: Iterable['DatasetLike'] + list_of_variable_dicts: Iterable["DatasetLike"] ) -> Tuple[Set[Hashable], Set[Hashable]]: """Given a list of dicts with xarray object values, identify coordinates. @@ -292,9 +288,7 @@ def determine_coords( return coord_names, noncoord_names -def coerce_pandas_values( - objects: Iterable['DatasetLike'] -) -> List['DatasetLike']: +def coerce_pandas_values(objects: Iterable["DatasetLike"]) -> List["DatasetLike"]: """Convert pandas values found in a list of labeled objects. 
Parameters @@ -338,7 +332,7 @@ def merge_coords_for_inplace_math(objs, priority_vars=None): return variables -def _get_priority_vars(objects, priority_arg, compat='equals'): +def _get_priority_vars(objects, priority_arg, compat="equals"): """Extract the priority variable from a list of mappings. We need this method because in some cases the priority argument itself @@ -378,8 +372,14 @@ def expand_and_merge_variables(objs, priority_arg=None): return variables -def merge_coords(objs, compat='minimal', join='outer', priority_arg=None, - indexes=None, fill_value=dtypes.NA): +def merge_coords( + objs, + compat="minimal", + join="outer", + priority_arg=None, + indexes=None, + fill_value=dtypes.NA, +): """Merge coordinate variables. See merge_core below for argument descriptions. This works similarly to @@ -388,8 +388,9 @@ def merge_coords(objs, compat='minimal', join='outer', priority_arg=None, """ _assert_compat_valid(compat) coerced = coerce_pandas_values(objs) - aligned = deep_align(coerced, join=join, copy=False, indexes=indexes, - fill_value=fill_value) + aligned = deep_align( + coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value + ) expanded = expand_variable_dicts(aligned) priority_vars = _get_priority_vars(aligned, priority_arg, compat=compat) variables = merge_variables(expanded, priority_vars, compat=compat) @@ -398,14 +399,14 @@ def merge_coords(objs, compat='minimal', join='outer', priority_arg=None, return variables -def merge_data_and_coords(data, coords, compat='broadcast_equals', - join='outer'): +def merge_data_and_coords(data, coords, compat="broadcast_equals", join="outer"): """Used in Dataset.__init__.""" objs = [data, coords] explicit_coords = coords.keys() indexes = dict(extract_indexes(coords)) - return merge_core(objs, compat, join, explicit_coords=explicit_coords, - indexes=indexes) + return merge_core( + objs, compat, join, explicit_coords=explicit_coords, indexes=indexes + ) def extract_indexes(coords): @@ -425,22 +426,21 @@ def assert_valid_explicit_coords(variables, dims, explicit_coords): for coord_name in explicit_coords: if coord_name in dims and variables[coord_name].dims != (coord_name,): raise MergeError( - 'coordinate %s shares a name with a dataset dimension, but is ' - 'not a 1D variable along that dimension. This is disallowed ' - 'by the xarray data model.' % coord_name) + "coordinate %s shares a name with a dataset dimension, but is " + "not a 1D variable along that dimension. This is disallowed " + "by the xarray data model." % coord_name + ) def merge_core( objs, - compat='broadcast_equals', - join='outer', + compat="broadcast_equals", + join="outer", priority_arg=None, explicit_coords=None, indexes=None, - fill_value=dtypes.NA -) -> Tuple['OrderedDict[Hashable, Variable]', - Set[Hashable], - Dict[Hashable, int]]: + fill_value=dtypes.NA, +) -> Tuple["OrderedDict[Hashable, Variable]", Set[Hashable], Dict[Hashable, int]]: """Core logic for merging labeled objects. This is not public API. 
@@ -480,8 +480,9 @@ def merge_core( _assert_compat_valid(compat) coerced = coerce_pandas_values(objs) - aligned = deep_align(coerced, join=join, copy=False, indexes=indexes, - fill_value=fill_value) + aligned = deep_align( + coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value + ) expanded = expand_variable_dicts(aligned) coord_names, noncoord_names = determine_coords(coerced) @@ -502,14 +503,15 @@ def merge_core( ambiguous_coords = coord_names.intersection(noncoord_names) if ambiguous_coords: - raise MergeError('unable to determine if these variables should be ' - 'coordinates or not in the merged result: %s' - % ambiguous_coords) + raise MergeError( + "unable to determine if these variables should be " + "coordinates or not in the merged result: %s" % ambiguous_coords + ) return variables, coord_names, dims -def merge(objects, compat='no_conflicts', join='outer', fill_value=dtypes.NA): +def merge(objects, compat="no_conflicts", join="outer", fill_value=dtypes.NA): """Merge any number of xarray objects into a single Dataset as variables. Parameters @@ -576,31 +578,31 @@ def merge(objects, compat='no_conflicts', join='outer', fill_value=dtypes.NA): dict_like_objects = list() for obj in objects: if not (isinstance(obj, (DataArray, Dataset, dict))): - raise TypeError("objects must be an iterable containing only " - "Dataset(s), DataArray(s), and dictionaries.") + raise TypeError( + "objects must be an iterable containing only " + "Dataset(s), DataArray(s), and dictionaries." + ) obj = obj.to_dataset() if isinstance(obj, DataArray) else obj dict_like_objects.append(obj) - variables, coord_names, dims = merge_core(dict_like_objects, compat, join, - fill_value=fill_value) + variables, coord_names, dims = merge_core( + dict_like_objects, compat, join, fill_value=fill_value + ) # TODO: don't always recompute indexes - merged = Dataset._construct_direct( - variables, coord_names, dims, indexes=None) + merged = Dataset._construct_direct(variables, coord_names, dims, indexes=None) return merged def dataset_merge_method( - dataset: 'Dataset', - other: 'DatasetLike', + dataset: "Dataset", + other: "DatasetLike", overwrite_vars: Union[Hashable, Iterable[Hashable]], compat: str, join: str, - fill_value: Any -) -> Tuple['OrderedDict[Hashable, Variable]', - Set[Hashable], - Dict[Hashable, int]]: + fill_value: Any, +) -> Tuple["OrderedDict[Hashable, Variable]", Set[Hashable], Dict[Hashable, int]]: """Guts of the Dataset.merge method. """ @@ -608,8 +610,7 @@ def dataset_merge_method( # method due for backwards compatibility # TODO: consider deprecating it? - if isinstance(overwrite_vars, Iterable) and not isinstance( - overwrite_vars, str): + if isinstance(overwrite_vars, Iterable) and not isinstance(overwrite_vars, str): overwrite_vars = set(overwrite_vars) else: overwrite_vars = {overwrite_vars} @@ -631,16 +632,14 @@ def dataset_merge_method( objs = [dataset, other_no_overwrite, other_overwrite] priority_arg = 2 - return merge_core(objs, compat, join, priority_arg=priority_arg, - fill_value=fill_value) + return merge_core( + objs, compat, join, priority_arg=priority_arg, fill_value=fill_value + ) def dataset_update_method( - dataset: 'Dataset', - other: 'DatasetLike', -) -> Tuple['OrderedDict[Hashable, Variable]', - Set[Hashable], - Dict[Hashable, int]]: + dataset: "Dataset", other: "DatasetLike" +) -> Tuple["OrderedDict[Hashable, Variable]", Set[Hashable], Dict[Hashable, int]]: """Guts of the Dataset.update method. 
This drops a duplicated coordinates from `other` if `other` is not an @@ -655,10 +654,12 @@ def dataset_update_method( for key, value in other.items(): if isinstance(value, DataArray): # drop conflicting coordinates - coord_names = [c for c in value.coords - if c not in value.dims and c in dataset.coords] + coord_names = [ + c + for c in value.coords + if c not in value.dims and c in dataset.coords + ] if coord_names: other[key] = value.drop(coord_names) - return merge_core([dataset, other], priority_arg=1, - indexes=dataset.indexes) + return merge_core([dataset, other], priority_arg=1, indexes=dataset.indexes) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 14d74ef5b79..fdabdb156b6 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -16,6 +16,7 @@ class BaseInterpolator: """Generic interpolator class for normalizing interpolation methods """ + cons_kwargs = None # type: Dict[str, Any] call_kwargs = None # type: Dict[str, Any] f = None # type: Callable @@ -25,8 +26,9 @@ def __call__(self, x): return self.f(x, **self.call_kwargs) def __repr__(self): - return "{type}: method={method}".format(type=self.__class__.__name__, - method=self.method) + return "{type}: method={method}".format( + type=self.__class__.__name__, method=self.method + ) class NumpyInterpolator(BaseInterpolator): @@ -36,16 +38,16 @@ class NumpyInterpolator(BaseInterpolator): -------- numpy.interp """ - def __init__(self, xi, yi, method='linear', fill_value=None, period=None): - if method != 'linear': - raise ValueError( - 'only method `linear` is valid for the NumpyInterpolator') + def __init__(self, xi, yi, method="linear", fill_value=None, period=None): + + if method != "linear": + raise ValueError("only method `linear` is valid for the NumpyInterpolator") self.method = method self.f = np.interp self.cons_kwargs = {} - self.call_kwargs = {'period': period} + self.call_kwargs = {"period": period} self._xi = xi self._yi = yi @@ -60,11 +62,17 @@ def __init__(self, xi, yi, method='linear', fill_value=None, period=None): self._left = fill_value self._right = fill_value else: - raise ValueError('%s is not a valid fill_value' % fill_value) + raise ValueError("%s is not a valid fill_value" % fill_value) def __call__(self, x): - return self.f(x, self._xi, self._yi, left=self._left, - right=self._right, **self.call_kwargs) + return self.f( + x, + self._xi, + self._yi, + left=self._left, + right=self._right, + **self.call_kwargs + ) class ScipyInterpolator(BaseInterpolator): @@ -74,18 +82,30 @@ class ScipyInterpolator(BaseInterpolator): -------- scipy.interpolate.interp1d """ - def __init__(self, xi, yi, method=None, fill_value=None, - assume_sorted=True, copy=False, bounds_error=False, - order=None, **kwargs): + + def __init__( + self, + xi, + yi, + method=None, + fill_value=None, + assume_sorted=True, + copy=False, + bounds_error=False, + order=None, + **kwargs + ): from scipy.interpolate import interp1d if method is None: - raise ValueError('method is a required argument, please supply a ' - 'valid scipy.inter1d method (kind)') + raise ValueError( + "method is a required argument, please supply a " + "valid scipy.inter1d method (kind)" + ) - if method == 'polynomial': + if method == "polynomial": if order is None: - raise ValueError('order is required when method=polynomial') + raise ValueError("order is required when method=polynomial") method = order self.method = method @@ -93,14 +113,21 @@ def __init__(self, xi, yi, method=None, fill_value=None, self.cons_kwargs = kwargs self.call_kwargs = 
{} - if fill_value is None and method == 'linear': + if fill_value is None and method == "linear": fill_value = np.nan, np.nan elif fill_value is None: fill_value = np.nan - self.f = interp1d(xi, yi, kind=self.method, fill_value=fill_value, - bounds_error=False, assume_sorted=assume_sorted, - copy=copy, **self.cons_kwargs) + self.f = interp1d( + xi, + yi, + kind=self.method, + fill_value=fill_value, + bounds_error=False, + assume_sorted=assume_sorted, + copy=copy, + **self.cons_kwargs + ) class SplineInterpolator(BaseInterpolator): @@ -110,20 +137,29 @@ class SplineInterpolator(BaseInterpolator): -------- scipy.interpolate.UnivariateSpline """ - def __init__(self, xi, yi, method='spline', fill_value=None, order=3, - nu=0, ext=None, **kwargs): + + def __init__( + self, + xi, + yi, + method="spline", + fill_value=None, + order=3, + nu=0, + ext=None, + **kwargs + ): from scipy.interpolate import UnivariateSpline - if method != 'spline': - raise ValueError( - 'only method `spline` is valid for the SplineInterpolator') + if method != "spline": + raise ValueError("only method `spline` is valid for the SplineInterpolator") self.method = method self.cons_kwargs = kwargs - self.call_kwargs = {'nu': nu, 'ext': ext} + self.call_kwargs = {"nu": nu, "ext": ext} if fill_value is not None: - raise ValueError('SplineInterpolator does not support fill_value') + raise ValueError("SplineInterpolator does not support fill_value") self.f = UnivariateSpline(xi, yi, k=order, **self.cons_kwargs) @@ -143,7 +179,7 @@ def _apply_over_vars_with_dim(func, self, dim=None, **kwargs): def get_clean_interp_index(arr, dim, use_coordinate=True): - '''get index to use for x values in interpolation. + """get index to use for x values in interpolation. If use_coordinate is True, the coordinate that shares the name of the dimension along which interpolation is being performed will be used as the @@ -151,7 +187,7 @@ def get_clean_interp_index(arr, dim, use_coordinate=True): If use_coordinate is False, the x values are set as an equally spaced sequence. - ''' + """ if use_coordinate: if use_coordinate is True: index = arr.get_index(dim) @@ -159,8 +195,9 @@ def get_clean_interp_index(arr, dim, use_coordinate=True): index = arr.coords[use_coordinate] if index.ndim != 1: raise ValueError( - 'Coordinates used for interpolation must be 1D, ' - '%s is %dD.' % (use_coordinate, index.ndim)) + "Coordinates used for interpolation must be 1D, " + "%s is %dD." % (use_coordinate, index.ndim) + ) # raise if index cannot be cast to a float (e.g. MultiIndex) try: @@ -168,8 +205,10 @@ def get_clean_interp_index(arr, dim, use_coordinate=True): except (TypeError, ValueError): # pandas raises a TypeError # xarray/nuppy raise a ValueError - raise TypeError('Index must be castable to float64 to support' - 'interpolation, got: %s' % type(index)) + raise TypeError( + "Index must be castable to float64 to support" + "interpolation, got: %s" % type(index) + ) # check index sorting now so we can skip it later if not (np.diff(index) > 0).all(): raise ValueError("Index must be monotonicly increasing") @@ -180,12 +219,13 @@ def get_clean_interp_index(arr, dim, use_coordinate=True): return index -def interp_na(self, dim=None, use_coordinate=True, method='linear', limit=None, - **kwargs): +def interp_na( + self, dim=None, use_coordinate=True, method="linear", limit=None, **kwargs +): """Interpolate values according to different methods. 
""" if dim is None: - raise NotImplementedError('dim is a required argument') + raise NotImplementedError("dim is a required argument") if limit is not None: valids = _get_valid_fill_mask(self, dim, limit) @@ -196,15 +236,19 @@ def interp_na(self, dim=None, use_coordinate=True, method='linear', limit=None, interpolator = partial(func_interpolate_na, interp_class, **kwargs) with warnings.catch_warnings(): - warnings.filterwarnings('ignore', 'overflow', RuntimeWarning) - warnings.filterwarnings('ignore', 'invalid value', RuntimeWarning) - arr = apply_ufunc(interpolator, index, self, - input_core_dims=[[dim], [dim]], - output_core_dims=[[dim]], - output_dtypes=[self.dtype], - dask='parallelized', - vectorize=True, - keep_attrs=True).transpose(*self.dims) + warnings.filterwarnings("ignore", "overflow", RuntimeWarning) + warnings.filterwarnings("ignore", "invalid value", RuntimeWarning) + arr = apply_ufunc( + interpolator, + index, + self, + input_core_dims=[[dim], [dim]], + output_core_dims=[[dim]], + output_dtypes=[self.dtype], + dask="parallelized", + vectorize=True, + keep_attrs=True, + ).transpose(*self.dims) if limit is not None: arr = arr.where(valids) @@ -213,7 +257,7 @@ def interp_na(self, dim=None, use_coordinate=True, method='linear', limit=None, def func_interpolate_na(interpolator, x, y, **kwargs): - '''helper function to apply interpolation along 1 dimension''' + """helper function to apply interpolation along 1 dimension""" # it would be nice if this wasn't necessary, works around: # "ValueError: assignment destination is read-only" in assignment below out = y.copy() @@ -232,7 +276,7 @@ def func_interpolate_na(interpolator, x, y, **kwargs): def _bfill(arr, n=None, axis=-1): - '''inverse of ffill''' + """inverse of ffill""" import bottleneck as bn arr = np.flip(arr, axis=axis) @@ -245,7 +289,7 @@ def _bfill(arr, n=None, axis=-1): def ffill(arr, dim=None, limit=None): - '''forward fill missing values''' + """forward fill missing values""" import bottleneck as bn axis = arr.get_axis_num(dim) @@ -253,36 +297,54 @@ def ffill(arr, dim=None, limit=None): # work around for bottleneck 178 _limit = limit if limit is not None else arr.shape[axis] - return apply_ufunc(bn.push, arr, - dask='parallelized', - keep_attrs=True, - output_dtypes=[arr.dtype], - kwargs=dict(n=_limit, axis=axis)).transpose(*arr.dims) + return apply_ufunc( + bn.push, + arr, + dask="parallelized", + keep_attrs=True, + output_dtypes=[arr.dtype], + kwargs=dict(n=_limit, axis=axis), + ).transpose(*arr.dims) def bfill(arr, dim=None, limit=None): - '''backfill missing values''' + """backfill missing values""" axis = arr.get_axis_num(dim) # work around for bottleneck 178 _limit = limit if limit is not None else arr.shape[axis] - return apply_ufunc(_bfill, arr, - dask='parallelized', - keep_attrs=True, - output_dtypes=[arr.dtype], - kwargs=dict(n=_limit, axis=axis)).transpose(*arr.dims) + return apply_ufunc( + _bfill, + arr, + dask="parallelized", + keep_attrs=True, + output_dtypes=[arr.dtype], + kwargs=dict(n=_limit, axis=axis), + ).transpose(*arr.dims) def _get_interpolator(method, vectorizeable_only=False, **kwargs): - '''helper function to select the appropriate interpolator class + """helper function to select the appropriate interpolator class returns interpolator class and keyword arguments for the class - ''' - interp1d_methods = ['linear', 'nearest', 'zero', 'slinear', 'quadratic', - 'cubic', 'polynomial'] - valid_methods = interp1d_methods + ['barycentric', 'krog', 'pchip', - 'spline', 'akima'] + """ + 
interp1d_methods = [ + "linear", + "nearest", + "zero", + "slinear", + "quadratic", + "cubic", + "polynomial", + ] + valid_methods = interp1d_methods + [ + "barycentric", + "krog", + "pchip", + "spline", + "akima", + ] has_scipy = True try: @@ -291,83 +353,90 @@ def _get_interpolator(method, vectorizeable_only=False, **kwargs): has_scipy = False # prioritize scipy.interpolate - if (method == 'linear' and not - kwargs.get('fill_value', None) == 'extrapolate' and - not vectorizeable_only): + if ( + method == "linear" + and not kwargs.get("fill_value", None) == "extrapolate" + and not vectorizeable_only + ): kwargs.update(method=method) interp_class = NumpyInterpolator elif method in valid_methods: if not has_scipy: - raise ImportError( - 'Interpolation with method `%s` requires scipy' % method) + raise ImportError("Interpolation with method `%s` requires scipy" % method) if method in interp1d_methods: kwargs.update(method=method) interp_class = ScipyInterpolator elif vectorizeable_only: - raise ValueError('{} is not a vectorizeable interpolator. ' - 'Available methods are {}'.format( - method, interp1d_methods)) - elif method == 'barycentric': + raise ValueError( + "{} is not a vectorizeable interpolator. " + "Available methods are {}".format(method, interp1d_methods) + ) + elif method == "barycentric": interp_class = interpolate.BarycentricInterpolator - elif method == 'krog': + elif method == "krog": interp_class = interpolate.KroghInterpolator - elif method == 'pchip': + elif method == "pchip": interp_class = interpolate.PchipInterpolator - elif method == 'spline': + elif method == "spline": kwargs.update(method=method) interp_class = SplineInterpolator - elif method == 'akima': + elif method == "akima": interp_class = interpolate.Akima1DInterpolator else: - raise ValueError('%s is not a valid scipy interpolator' % method) + raise ValueError("%s is not a valid scipy interpolator" % method) else: - raise ValueError('%s is not a valid interpolator' % method) + raise ValueError("%s is not a valid interpolator" % method) return interp_class, kwargs def _get_interpolator_nd(method, **kwargs): - '''helper function to select the appropriate interpolator class + """helper function to select the appropriate interpolator class returns interpolator class and keyword arguments for the class - ''' - valid_methods = ['linear', 'nearest'] + """ + valid_methods = ["linear", "nearest"] try: from scipy import interpolate except ImportError: - raise ImportError( - 'Interpolation with method `%s` requires scipy' % method) + raise ImportError("Interpolation with method `%s` requires scipy" % method) if method in valid_methods: kwargs.update(method=method) interp_class = interpolate.interpn else: - raise ValueError('%s is not a valid interpolator for interpolating ' - 'over multiple dimensions.' % method) + raise ValueError( + "%s is not a valid interpolator for interpolating " + "over multiple dimensions." % method + ) return interp_class, kwargs def _get_valid_fill_mask(arr, dim, limit): - '''helper function to determine values that can be filled when limit is not - None''' + """helper function to determine values that can be filled when limit is not + None""" kw = {dim: limit + 1} # we explicitly use construct method to avoid copy. 
- new_dim = utils.get_temp_dimname(arr.dims, '_window') - return (arr.isnull().rolling(min_periods=1, **kw) - .construct(new_dim, fill_value=False) - .sum(new_dim, skipna=False)) <= limit + new_dim = utils.get_temp_dimname(arr.dims, "_window") + return ( + arr.isnull() + .rolling(min_periods=1, **kw) + .construct(new_dim, fill_value=False) + .sum(new_dim, skipna=False) + ) <= limit def _assert_single_chunk(var, axes): for axis in axes: if len(var.chunks[axis]) > 1 or var.chunks[axis][0] < var.shape[axis]: raise NotImplementedError( - 'Chunking along the dimension to be interpolated ' - '({}) is not yet supported.'.format(axis)) + "Chunking along the dimension to be interpolated " + "({}) is not yet supported.".format(axis) + ) def _localize(var, indexes_coords): @@ -377,8 +446,8 @@ def _localize(var, indexes_coords): indexes = {} for dim, [x, new_x] in indexes_coords.items(): index = x.to_index() - imin = index.get_loc(np.min(new_x.values), method='nearest') - imax = index.get_loc(np.max(new_x.values), method='nearest') + imin = index.get_loc(np.min(new_x.values), method="nearest") + imax = index.get_loc(np.max(new_x.values), method="nearest") indexes[dim] = slice(max(imin - 2, 0), imax + 2) indexes_coords[dim] = (x[indexes[dim]], new_x) @@ -435,11 +504,11 @@ def interp(var, indexes_coords, method, **kwargs): return var.copy() # simple speed up for the local interpolation - if method in ['linear', 'nearest']: + if method in ["linear", "nearest"]: var, indexes_coords = _localize(var, indexes_coords) # default behavior - kwargs['bounds_error'] = kwargs.get('bounds_error', False) + kwargs["bounds_error"] = kwargs.get("bounds_error", False) # target dimensions dims = list(indexes_coords) @@ -450,8 +519,9 @@ def interp(var, indexes_coords, method, **kwargs): broadcast_dims = [d for d in var.dims if d not in dims] original_dims = broadcast_dims + dims new_dims = broadcast_dims + list(destination[0].dims) - interped = interp_func(var.transpose(*original_dims).data, - x, destination, method, kwargs) + interped = interp_func( + var.transpose(*original_dims).data, x, destination, method, kwargs + ) result = Variable(new_dims, interped, attrs=var.attrs) @@ -502,8 +572,7 @@ def interp_func(var, x, new_x, method, kwargs): return var.copy() if len(x) == 1: - func, kwargs = _get_interpolator(method, vectorizeable_only=True, - **kwargs) + func, kwargs = _get_interpolator(method, vectorizeable_only=True, **kwargs) else: func, kwargs = _get_interpolator_nd(method, **kwargs) @@ -511,12 +580,21 @@ def interp_func(var, x, new_x, method, kwargs): import dask.array as da _assert_single_chunk(var, range(var.ndim - len(x), var.ndim)) - chunks = var.chunks[:-len(x)] + new_x[0].shape + chunks = var.chunks[: -len(x)] + new_x[0].shape drop_axis = range(var.ndim - len(x), var.ndim) new_axis = range(var.ndim - len(x), var.ndim - len(x) + new_x[0].ndim) - return da.map_blocks(_interpnd, var, x, new_x, func, kwargs, - dtype=var.dtype, chunks=chunks, - new_axis=new_axis, drop_axis=drop_axis) + return da.map_blocks( + _interpnd, + var, + x, + new_x, + func, + kwargs, + dtype=var.dtype, + chunks=chunks, + new_axis=new_axis, + drop_axis=drop_axis, + ) return _interpnd(var, x, new_x, func, kwargs) diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 62479b55c99..9ba4eae29ae 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -1,8 +1,7 @@ import numpy as np from . 
import dtypes, nputils, utils -from .duck_array_ops import ( - _dask_or_eager_func, count, fillna, isnull, where_method) +from .duck_array_ops import _dask_or_eager_func, count, fillna, isnull, where_method from .pycompat import dask_array_type try: @@ -24,18 +23,19 @@ def _maybe_null_out(result, axis, mask, min_count=1): """ xarray version of pandas.core.nanops._maybe_null_out """ - if hasattr(axis, '__len__'): # if tuple or list - raise ValueError('min_count is not available for reduction ' - 'with more than one dimensions.') + if hasattr(axis, "__len__"): # if tuple or list + raise ValueError( + "min_count is not available for reduction " "with more than one dimensions." + ) - if axis is not None and getattr(result, 'ndim', False): + if axis is not None and getattr(result, "ndim", False): null_mask = (mask.shape[axis] - mask.sum(axis) - min_count) < 0 if null_mask.any(): dtype, fill_value = dtypes.maybe_promote(result.dtype) result = result.astype(dtype) result[null_mask] = fill_value - elif getattr(result, 'dtype', None) not in dtypes.NAT_TYPES: + elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES: null_mask = mask.size - mask.sum() if null_mask < min_count: result = np.nan @@ -53,7 +53,7 @@ def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs): # TODO This will evaluate dask arrays and might be costly. if (valid_count == 0).any(): - raise ValueError('All-NaN slice encountered') + raise ValueError("All-NaN slice encountered") return data @@ -63,7 +63,7 @@ def _nan_minmax_object(func, fill_value, value, axis=None, **kwargs): valid_count = count(value, axis=axis) filled_value = fillna(value, fill_value) data = getattr(np, func)(filled_value, axis=axis, **kwargs) - if not hasattr(data, 'dtype'): # scalar case + if not hasattr(data, "dtype"): # scalar case data = fill_value if valid_count == 0 else data # we've computed a single min, max value of type object. 
# don't let np.array turn a tuple back into an array @@ -72,18 +72,16 @@ def _nan_minmax_object(func, fill_value, value, axis=None, **kwargs): def nanmin(a, axis=None, out=None): - if a.dtype.kind == 'O': - return _nan_minmax_object( - 'min', dtypes.get_pos_infinity(a.dtype), a, axis) + if a.dtype.kind == "O": + return _nan_minmax_object("min", dtypes.get_pos_infinity(a.dtype), a, axis) module = dask_array if isinstance(a, dask_array_type) else nputils return module.nanmin(a, axis=axis) def nanmax(a, axis=None, out=None): - if a.dtype.kind == 'O': - return _nan_minmax_object( - 'max', dtypes.get_neg_infinity(a.dtype), a, axis) + if a.dtype.kind == "O": + return _nan_minmax_object("max", dtypes.get_neg_infinity(a.dtype), a, axis) module = dask_array if isinstance(a, dask_array_type) else nputils return module.nanmax(a, axis=axis) @@ -91,8 +89,8 @@ def nanmax(a, axis=None, out=None): def nanargmin(a, axis=None): fill_value = dtypes.get_pos_infinity(a.dtype) - if a.dtype.kind == 'O': - return _nan_argminmax_object('argmin', fill_value, a, axis=axis) + if a.dtype.kind == "O": + return _nan_argminmax_object("argmin", fill_value, a, axis=axis) a, mask = _replace_nan(a, fill_value) if isinstance(a, dask_array_type): res = dask_array.argmin(a, axis=axis) @@ -108,8 +106,8 @@ def nanargmin(a, axis=None): def nanargmax(a, axis=None): fill_value = dtypes.get_neg_infinity(a.dtype) - if a.dtype.kind == 'O': - return _nan_argminmax_object('argmax', fill_value, a, axis=axis) + if a.dtype.kind == "O": + return _nan_argminmax_object("argmax", fill_value, a, axis=axis) a, mask = _replace_nan(a, fill_value) if isinstance(a, dask_array_type): @@ -126,7 +124,7 @@ def nanargmax(a, axis=None): def nansum(a, axis=None, dtype=None, out=None, min_count=None): a, mask = _replace_nan(a, 0) - result = _dask_or_eager_func('sum')(a, axis=axis, dtype=dtype) + result = _dask_or_eager_func("sum")(a, axis=axis, dtype=dtype) if min_count is not None: return _maybe_null_out(result, axis, mask, min_count) else: @@ -135,23 +133,22 @@ def nansum(a, axis=None, dtype=None, out=None, min_count=None): def _nanmean_ddof_object(ddof, value, axis=None, dtype=None, **kwargs): """ In house nanmean. 
ddof argument will be used in _nanvar method """ - from .duck_array_ops import (count, fillna, _dask_or_eager_func, - where_method) + from .duck_array_ops import count, fillna, _dask_or_eager_func, where_method valid_count = count(value, axis=axis) value = fillna(value, 0) # As dtype inference is impossible for object dtype, we assume float # https://github.com/dask/dask/issues/3162 - if dtype is None and value.dtype.kind == 'O': - dtype = value.dtype if value.dtype.kind in ['cf'] else float + if dtype is None and value.dtype.kind == "O": + dtype = value.dtype if value.dtype.kind in ["cf"] else float - data = _dask_or_eager_func('sum')(value, axis=axis, dtype=dtype, **kwargs) + data = _dask_or_eager_func("sum")(value, axis=axis, dtype=dtype, **kwargs) data = data / (valid_count - ddof) return where_method(data, valid_count != 0) def nanmean(a, axis=None, dtype=None, out=None): - if a.dtype.kind == 'O': + if a.dtype.kind == "O": return _nanmean_ddof_object(0, a, axis=axis, dtype=dtype) if isinstance(a, dask_array_type): @@ -161,33 +158,35 @@ def nanmean(a, axis=None, dtype=None, out=None): def nanmedian(a, axis=None, out=None): - return _dask_or_eager_func('nanmedian', eager_module=nputils)(a, axis=axis) + return _dask_or_eager_func("nanmedian", eager_module=nputils)(a, axis=axis) def _nanvar_object(value, axis=None, ddof=0, keepdims=False, **kwargs): - value_mean = _nanmean_ddof_object(ddof=0, value=value, axis=axis, - keepdims=True, **kwargs) - squared = (value.astype(value_mean.dtype) - value_mean)**2 - return _nanmean_ddof_object(ddof, squared, axis=axis, - keepdims=keepdims, **kwargs) + value_mean = _nanmean_ddof_object( + ddof=0, value=value, axis=axis, keepdims=True, **kwargs + ) + squared = (value.astype(value_mean.dtype) - value_mean) ** 2 + return _nanmean_ddof_object(ddof, squared, axis=axis, keepdims=keepdims, **kwargs) def nanvar(a, axis=None, dtype=None, out=None, ddof=0): - if a.dtype.kind == 'O': + if a.dtype.kind == "O": return _nanvar_object(a, axis=axis, dtype=dtype, ddof=ddof) - return _dask_or_eager_func('nanvar', eager_module=nputils)( - a, axis=axis, dtype=dtype, ddof=ddof) + return _dask_or_eager_func("nanvar", eager_module=nputils)( + a, axis=axis, dtype=dtype, ddof=ddof + ) def nanstd(a, axis=None, dtype=None, out=None, ddof=0): - return _dask_or_eager_func('nanstd', eager_module=nputils)( - a, axis=axis, dtype=dtype, ddof=ddof) + return _dask_or_eager_func("nanstd", eager_module=nputils)( + a, axis=axis, dtype=dtype, ddof=ddof + ) def nanprod(a, axis=None, dtype=None, out=None, min_count=None): a, mask = _replace_nan(a, 1) - result = _dask_or_eager_func('nanprod')(a, axis=axis, dtype=dtype, out=out) + result = _dask_or_eager_func("nanprod")(a, axis=axis, dtype=dtype, out=out) if min_count is not None: return _maybe_null_out(result, axis, mask, min_count) else: @@ -195,10 +194,12 @@ def nanprod(a, axis=None, dtype=None, out=None, min_count=None): def nancumsum(a, axis=None, dtype=None, out=None): - return _dask_or_eager_func('nancumsum', eager_module=nputils)( - a, axis=axis, dtype=dtype) + return _dask_or_eager_func("nancumsum", eager_module=nputils)( + a, axis=axis, dtype=dtype + ) def nancumprod(a, axis=None, dtype=None, out=None): - return _dask_or_eager_func('nancumprod', eager_module=nputils)( - a, axis=axis, dtype=dtype) + return _dask_or_eager_func("nancumprod", eager_module=nputils)( + a, axis=axis, dtype=dtype + ) diff --git a/xarray/core/npcompat.py b/xarray/core/npcompat.py index afef9a5e083..ecaadae726e 100644 --- a/xarray/core/npcompat.py +++ 
b/xarray/core/npcompat.py @@ -128,16 +128,18 @@ def isin(element, test_elements, assume_unique=False, invert=False): [ True, False]]) """ element = np.asarray(element) - return np.in1d(element, test_elements, assume_unique=assume_unique, - invert=invert).reshape(element.shape) + return np.in1d( + element, test_elements, assume_unique=assume_unique, invert=invert + ).reshape(element.shape) -if LooseVersion(np.__version__) >= LooseVersion('1.13'): +if LooseVersion(np.__version__) >= LooseVersion("1.13"): gradient = np.gradient else: + def normalize_axis_tuple(axes, N): if isinstance(axes, int): - axes = (axes, ) + axes = (axes,) return tuple([N + a if a < 0 else a for a in axes]) def gradient(f, *varargs, axis=None, edge_order=1): @@ -169,8 +171,10 @@ def gradient(f, *varargs, axis=None, edge_order=1): elif np.ndim(distances) != 1: raise ValueError("distances must be either scalars or 1d") if len(distances) != f.shape[axes[i]]: - raise ValueError("when 1d, distances must match the " - "length of the corresponding dimension") + raise ValueError( + "when 1d, distances must match the " + "length of the corresponding dimension" + ) diffx = np.diff(distances) # if distances are constant reduce to the scalar case # since it brings a consistent speedup @@ -195,15 +199,15 @@ def gradient(f, *varargs, axis=None, edge_order=1): slice4 = [slice(None)] * N otype = f.dtype.char - if otype not in ['f', 'd', 'F', 'D', 'm', 'M']: - otype = 'd' + if otype not in ["f", "d", "F", "D", "m", "M"]: + otype = "d" # Difference of datetime64 elements results in timedelta64 - if otype == 'M': + if otype == "M": # Need to use the full dtype name because it contains unit # information - otype = f.dtype.name.replace('datetime', 'timedelta') - elif otype == 'm': + otype = f.dtype.name.replace("datetime", "timedelta") + elif otype == "m": # Needs to keep the specific units, can't be a general unit otype = f.dtype @@ -211,7 +215,7 @@ def gradient(f, *varargs, axis=None, edge_order=1): # that is a view of ints if the data is datetime64, otherwise # just set y equal to the array `f`. if f.dtype.char in ["M", "m"]: - y = f.view('int64') + y = f.view("int64") else: y = f @@ -220,7 +224,8 @@ def gradient(f, *varargs, axis=None, edge_order=1): raise ValueError( "Shape of array too small to calculate a numerical " "gradient, at least (edge_order + 1) elements are " - "required.") + "required." + ) # result allocation out = np.empty_like(y, dtype=otype) @@ -233,7 +238,7 @@ def gradient(f, *varargs, axis=None, edge_order=1): slice4[axis] = slice(2, None) if uniform_spacing: - out[slice1] = (f[slice4] - f[slice2]) / (2. * dx[i]) + out[slice1] = (f[slice4] - f[slice2]) / (2.0 * dx[i]) else: dx1 = dx[i][0:-1] dx2 = dx[i][1:] @@ -272,14 +277,14 @@ def gradient(f, *varargs, axis=None, edge_order=1): slice4[axis] = 2 if uniform_spacing: a = -1.5 / dx[i] - b = 2. / dx[i] + b = 2.0 / dx[i] c = -0.5 / dx[i] else: dx1 = dx[i][0] dx2 = dx[i][1] - a = -(2. * dx1 + dx2) / (dx1 * (dx1 + dx2)) + a = -(2.0 * dx1 + dx2) / (dx1 * (dx1 + dx2)) b = (dx1 + dx2) / (dx1 * dx2) - c = - dx1 / (dx2 * (dx1 + dx2)) + c = -dx1 / (dx2 * (dx1 + dx2)) # 1D equivalent -- out[0] = a * y[0] + b * y[1] + c * y[2] out[slice1] = a * y[slice2] + b * y[slice3] + c * y[slice4] @@ -289,14 +294,14 @@ def gradient(f, *varargs, axis=None, edge_order=1): slice4[axis] = -1 if uniform_spacing: a = 0.5 / dx[i] - b = -2. / dx[i] + b = -2.0 / dx[i] c = 1.5 / dx[i] else: dx1 = dx[i][-2] dx2 = dx[i][-1] a = (dx2) / (dx1 * (dx1 + dx2)) - b = - (dx2 + dx1) / (dx1 * dx2) - c = (2. 
* dx2 + dx1) / (dx2 * (dx1 + dx2)) + b = -(dx2 + dx1) / (dx1 * dx2) + c = (2.0 * dx2 + dx1) / (dx2 * (dx1 + dx2)) # 1D equivalent -- out[-1] = a * f[-3] + b * f[-2] + c * f[-1] out[slice1] = a * y[slice2] + b * y[slice3] + c * y[slice4] @@ -323,10 +328,9 @@ def _validate_axis(axis, ndim, argname): axis = list(axis) axis = [a + ndim if a < 0 else a for a in axis] if not builtins.all(0 <= a < ndim for a in axis): - raise ValueError('invalid axis for this array in `%s` argument' % - argname) + raise ValueError("invalid axis for this array in `%s` argument" % argname) if len(set(axis)) != len(axis): - raise ValueError('repeated axis in `%s` argument' % argname) + raise ValueError("repeated axis in `%s` argument" % argname) return axis @@ -338,11 +342,13 @@ def moveaxis(a, source, destination): a = np.asarray(a) transpose = a.transpose - source = _validate_axis(source, a.ndim, 'source') - destination = _validate_axis(destination, a.ndim, 'destination') + source = _validate_axis(source, a.ndim, "source") + destination = _validate_axis(destination, a.ndim, "destination") if len(source) != len(destination): - raise ValueError('`source` and `destination` arguments must have ' - 'the same number of elements') + raise ValueError( + "`source` and `destination` arguments must have " + "the same number of elements" + ) order = [n for n in range(a.ndim) if n not in source] diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index 9b1f5721a08..a9971e7125a 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -5,6 +5,7 @@ try: import bottleneck as bn + _USE_BOTTLENECK = True except ImportError: # use numpy methods instead @@ -15,8 +16,7 @@ def _validate_axis(data, axis): ndim = data.ndim if not -ndim <= axis < ndim: - raise IndexError('axis %r out of bounds [-%r, %r)' - % (axis, ndim, ndim)) + raise IndexError("axis %r out of bounds [-%r, %r)" % (axis, ndim, ndim)) if axis < 0: axis += ndim return axis @@ -77,13 +77,13 @@ def _ensure_bool_is_ndarray(result, *args): def array_eq(self, other): with warnings.catch_warnings(): - warnings.filterwarnings('ignore', r'elementwise comparison failed') + warnings.filterwarnings("ignore", r"elementwise comparison failed") return _ensure_bool_is_ndarray(self == other, self, other) def array_ne(self, other): with warnings.catch_warnings(): - warnings.filterwarnings('ignore', r'elementwise comparison failed') + warnings.filterwarnings("ignore", r"elementwise comparison failed") return _ensure_bool_is_ndarray(self != other, self, other) @@ -102,11 +102,11 @@ def _advanced_indexer_subspaces(key): """ if not isinstance(key, tuple): key = (key,) - advanced_index_positions = [i for i, k in enumerate(key) - if not isinstance(k, slice)] + advanced_index_positions = [ + i for i, k in enumerate(key) if not isinstance(k, slice) + ] - if (not advanced_index_positions or - not _is_contiguous(advanced_index_positions)): + if not advanced_index_positions or not _is_contiguous(advanced_index_positions): # Nothing to reorder: dimensions on the indexing result are already # ordered like vindex. 
See NumPy's rule for "Combining advanced and # basic indexing": @@ -137,8 +137,7 @@ def __getitem__(self, key): def __setitem__(self, key, value): """Value must have dimensionality matching the key.""" mixed_positions, vindex_positions = _advanced_indexer_subspaces(key) - self._array[key] = np.moveaxis(value, vindex_positions, - mixed_positions) + self._array[key] = np.moveaxis(value, vindex_positions, mixed_positions) def rolling_window(a, axis, window, center, fill_value): @@ -150,7 +149,7 @@ def rolling_window(a, axis, window, center, fill_value): pads[axis] = (start, end) else: pads[axis] = (window - 1, 0) - a = np.pad(a, pads, mode='constant', constant_values=fill_value) + a = np.pad(a, pads, mode="constant", constant_values=fill_value) return _rolling_window(a, window, axis) @@ -191,30 +190,33 @@ def _rolling_window(a, window, axis=-1): a = np.swapaxes(a, axis, -1) if window < 1: - raise ValueError( - "`window` must be at least 1. Given : {}".format(window)) + raise ValueError("`window` must be at least 1. Given : {}".format(window)) if window > a.shape[-1]: raise ValueError("`window` is too long. Given : {}".format(window)) shape = a.shape[:-1] + (a.shape[-1] - window + 1, window) strides = a.strides + (a.strides[-1],) - rolling = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides, - writeable=False) + rolling = np.lib.stride_tricks.as_strided( + a, shape=shape, strides=strides, writeable=False + ) return np.swapaxes(rolling, -2, axis) def _create_bottleneck_method(name, npmodule=np): def f(values, axis=None, **kwargs): - dtype = kwargs.get('dtype', None) + dtype = kwargs.get("dtype", None) bn_func = getattr(bn, name, None) - if (_USE_BOTTLENECK and bn_func is not None and - not isinstance(axis, tuple) and - values.dtype.kind in 'uifc' and - values.dtype.isnative and - (dtype is None or np.dtype(dtype) == values.dtype)): + if ( + _USE_BOTTLENECK + and bn_func is not None + and not isinstance(axis, tuple) + and values.dtype.kind in "uifc" + and values.dtype.isnative + and (dtype is None or np.dtype(dtype) == values.dtype) + ): # bottleneck does not take care dtype, min_count - kwargs.pop('dtype', None) + kwargs.pop("dtype", None) result = bn_func(values, axis=axis, **kwargs) else: result = getattr(npmodule, name)(values, axis=axis, **kwargs) @@ -225,12 +227,12 @@ def f(values, axis=None, **kwargs): return f -nanmin = _create_bottleneck_method('nanmin') -nanmax = _create_bottleneck_method('nanmax') -nanmean = _create_bottleneck_method('nanmean') -nanmedian = _create_bottleneck_method('nanmedian') -nanvar = _create_bottleneck_method('nanvar') -nanstd = _create_bottleneck_method('nanstd') -nanprod = _create_bottleneck_method('nanprod') -nancumsum = _create_bottleneck_method('nancumsum') -nancumprod = _create_bottleneck_method('nancumprod') +nanmin = _create_bottleneck_method("nanmin") +nanmax = _create_bottleneck_method("nanmax") +nanmean = _create_bottleneck_method("nanmean") +nanmedian = _create_bottleneck_method("nanmedian") +nanvar = _create_bottleneck_method("nanvar") +nanstd = _create_bottleneck_method("nanstd") +nanprod = _create_bottleneck_method("nanprod") +nancumsum = _create_bottleneck_method("nancumsum") +nancumprod = _create_bottleneck_method("nancumprod") diff --git a/xarray/core/ops.py b/xarray/core/ops.py index 0c0fc1e50a8..78c4466faed 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -14,6 +14,7 @@ try: import bottleneck as bn + has_bottleneck = True except ImportError: # use numpy methods instead @@ -21,23 +22,43 @@ has_bottleneck = False 
-UNARY_OPS = ['neg', 'pos', 'abs', 'invert'] -CMP_BINARY_OPS = ['lt', 'le', 'ge', 'gt'] -NUM_BINARY_OPS = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', - 'pow', 'and', 'xor', 'or'] +UNARY_OPS = ["neg", "pos", "abs", "invert"] +CMP_BINARY_OPS = ["lt", "le", "ge", "gt"] +NUM_BINARY_OPS = [ + "add", + "sub", + "mul", + "truediv", + "floordiv", + "mod", + "pow", + "and", + "xor", + "or", +] # methods which pass on the numpy return value unchanged # be careful not to list methods that we would want to wrap later -NUMPY_SAME_METHODS = ['item', 'searchsorted'] +NUMPY_SAME_METHODS = ["item", "searchsorted"] # methods which don't modify the data shape, so the result should still be # wrapped in an Variable/DataArray -NUMPY_UNARY_METHODS = ['astype', 'argsort', 'clip', 'conj', 'conjugate'] -PANDAS_UNARY_FUNCTIONS = ['isnull', 'notnull'] +NUMPY_UNARY_METHODS = ["astype", "argsort", "clip", "conj", "conjugate"] +PANDAS_UNARY_FUNCTIONS = ["isnull", "notnull"] # methods which remove an axis -REDUCE_METHODS = ['all', 'any'] -NAN_REDUCE_METHODS = ['argmax', 'argmin', 'max', 'min', 'mean', 'prod', 'sum', - 'std', 'var', 'median'] -NAN_CUM_METHODS = ['cumsum', 'cumprod'] +REDUCE_METHODS = ["all", "any"] +NAN_REDUCE_METHODS = [ + "argmax", + "argmin", + "max", + "min", + "mean", + "prod", + "sum", + "std", + "var", + "median", +] +NAN_CUM_METHODS = ["cumsum", "cumprod"] # TODO: wrap take, dot, sort @@ -138,12 +159,16 @@ def fillna(data, other, join="left", dataset_join="left"): """ from .computation import apply_ufunc - return apply_ufunc(duck_array_ops.fillna, data, other, - join=join, - dask="allowed", - dataset_join=dataset_join, - dataset_fill_value=np.nan, - keep_attrs=True) + return apply_ufunc( + duck_array_ops.fillna, + data, + other, + join=join, + dask="allowed", + dataset_join=dataset_join, + dataset_fill_value=np.nan, + keep_attrs=True, + ) def where_method(self, cond, other=dtypes.NA): @@ -162,14 +187,19 @@ def where_method(self, cond, other=dtypes.NA): Same type as caller. 
""" from .computation import apply_ufunc + # alignment for three arguments is complicated, so don't support it yet - join = 'inner' if other is dtypes.NA else 'exact' - return apply_ufunc(duck_array_ops.where_method, - self, cond, other, - join=join, - dataset_join=join, - dask='allowed', - keep_attrs=True) + join = "inner" if other is dtypes.NA else "exact" + return apply_ufunc( + duck_array_ops.where_method, + self, + cond, + other, + join=join, + dataset_join=join, + dask="allowed", + keep_attrs=True, + ) def _call_possibly_missing_method(arg, name, args, kwargs): @@ -177,7 +207,7 @@ def _call_possibly_missing_method(arg, name, args, kwargs): method = getattr(arg, name) except AttributeError: duck_array_ops.fail_on_dask_array_input(arg, func_name=name) - if hasattr(arg, 'data'): + if hasattr(arg, "data"): duck_array_ops.fail_on_dask_array_input(arg.data, func_name=name) raise else: @@ -187,6 +217,7 @@ def _call_possibly_missing_method(arg, name, args, kwargs): def _values_method_wrapper(name): def func(self, *args, **kwargs): return _call_possibly_missing_method(self.data, name, args, kwargs) + func.__name__ = name func.__doc__ = getattr(np.ndarray, name).__doc__ return func @@ -195,6 +226,7 @@ def func(self, *args, **kwargs): def _method_wrapper(name): def func(self, *args, **kwargs): return _call_possibly_missing_method(self, name, args, kwargs) + func.__name__ = name func.__doc__ = getattr(np.ndarray, name).__doc__ return func @@ -212,56 +244,60 @@ def func(self, *args, **kwargs): return getattr(self, name)(*args, **kwargs) except AttributeError: return f(self, *args, **kwargs) + func.__name__ = name func.__doc__ = f.__doc__ return func def inject_reduce_methods(cls): - methods = ([(name, getattr(duck_array_ops, 'array_%s' % name), False) - for name in REDUCE_METHODS] + - [(name, getattr(duck_array_ops, name), True) - for name in NAN_REDUCE_METHODS] + - [('count', duck_array_ops.count, False)]) + methods = ( + [ + (name, getattr(duck_array_ops, "array_%s" % name), False) + for name in REDUCE_METHODS + ] + + [(name, getattr(duck_array_ops, name), True) for name in NAN_REDUCE_METHODS] + + [("count", duck_array_ops.count, False)] + ) for name, f, include_skipna in methods: - numeric_only = getattr(f, 'numeric_only', False) - available_min_count = getattr(f, 'available_min_count', False) - min_count_docs = _MINCOUNT_DOCSTRING if available_min_count else '' + numeric_only = getattr(f, "numeric_only", False) + available_min_count = getattr(f, "available_min_count", False) + min_count_docs = _MINCOUNT_DOCSTRING if available_min_count else "" func = cls._reduce_method(f, include_skipna, numeric_only) func.__name__ = name func.__doc__ = _REDUCE_DOCSTRING_TEMPLATE.format( - name=name, cls=cls.__name__, + name=name, + cls=cls.__name__, extra_args=cls._reduce_extra_args_docstring.format(name=name), - min_count_docs=min_count_docs) + min_count_docs=min_count_docs, + ) setattr(cls, name, func) def inject_cum_methods(cls): - methods = ([(name, getattr(duck_array_ops, name), True) - for name in NAN_CUM_METHODS]) + methods = [(name, getattr(duck_array_ops, name), True) for name in NAN_CUM_METHODS] for name, f, include_skipna in methods: - numeric_only = getattr(f, 'numeric_only', False) + numeric_only = getattr(f, "numeric_only", False) func = cls._reduce_method(f, include_skipna, numeric_only) func.__name__ = name func.__doc__ = _CUM_DOCSTRING_TEMPLATE.format( - name=name, cls=cls.__name__, - extra_args=cls._cum_extra_args_docstring.format(name=name)) + name=name, + cls=cls.__name__, + 
extra_args=cls._cum_extra_args_docstring.format(name=name), + ) setattr(cls, name, func) def op_str(name): - return '__%s__' % name + return "__%s__" % name def get_op(name): return getattr(operator, op_str(name)) -NON_INPLACE_OP = { - get_op('i' + name): get_op(name) - for name in NUM_BINARY_OPS -} +NON_INPLACE_OP = {get_op("i" + name): get_op(name) for name in NUM_BINARY_OPS} def inplace_to_noninplace_op(f): @@ -272,16 +308,14 @@ def inject_binary_ops(cls, inplace=False): for name in CMP_BINARY_OPS + NUM_BINARY_OPS: setattr(cls, op_str(name), cls._binary_op(get_op(name))) - for name, f in [('eq', array_eq), ('ne', array_ne)]: + for name, f in [("eq", array_eq), ("ne", array_ne)]: setattr(cls, op_str(name), cls._binary_op(f)) for name in NUM_BINARY_OPS: # only numeric operations have in-place and reflexive variants - setattr(cls, op_str('r' + name), - cls._binary_op(get_op(name), reflexive=True)) + setattr(cls, op_str("r" + name), cls._binary_op(get_op(name), reflexive=True)) if inplace: - setattr(cls, op_str('i' + name), - cls._inplace_binary_op(get_op('i' + name))) + setattr(cls, op_str("i" + name), cls._inplace_binary_op(get_op("i" + name))) def inject_all_ops_and_reduce_methods(cls, priority=50, array_only=True): @@ -299,12 +333,11 @@ def inject_all_ops_and_reduce_methods(cls, priority=50, array_only=True): setattr(cls, name, cls._unary_op(_method_wrapper(name))) for name in PANDAS_UNARY_FUNCTIONS: - f = _func_slash_method_wrapper( - getattr(duck_array_ops, name), name=name) + f = _func_slash_method_wrapper(getattr(duck_array_ops, name), name=name) setattr(cls, name, cls._unary_op(f)) - f = _func_slash_method_wrapper(duck_array_ops.around, name='round') - setattr(cls, 'round', cls._unary_op(f)) + f = _func_slash_method_wrapper(duck_array_ops.around, name="round") + setattr(cls, "round", cls._unary_op(f)) if array_only: # these methods don't return arrays of the same shape as the input, so @@ -318,11 +351,9 @@ def inject_all_ops_and_reduce_methods(cls, priority=50, array_only=True): def inject_coarsen_methods(cls): # standard numpy reduce methods - methods = [(name, getattr(duck_array_ops, name)) - for name in NAN_REDUCE_METHODS] + methods = [(name, getattr(duck_array_ops, name)) for name in NAN_REDUCE_METHODS] for name, f in methods: func = cls._reduce_method(f) func.__name__ = name - func.__doc__ = _COARSEN_REDUCE_DOCSTRING_TEMPLATE.format( - name=func.__name__) + func.__doc__ = _COARSEN_REDUCE_DOCSTRING_TEMPLATE.format(name=func.__name__) setattr(cls, name, func) diff --git a/xarray/core/options.py b/xarray/core/options.py index 532d86a8f38..c5086268f48 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -1,27 +1,27 @@ import warnings -DISPLAY_WIDTH = 'display_width' -ARITHMETIC_JOIN = 'arithmetic_join' -ENABLE_CFTIMEINDEX = 'enable_cftimeindex' -FILE_CACHE_MAXSIZE = 'file_cache_maxsize' -WARN_FOR_UNCLOSED_FILES = 'warn_for_unclosed_files' -CMAP_SEQUENTIAL = 'cmap_sequential' -CMAP_DIVERGENT = 'cmap_divergent' -KEEP_ATTRS = 'keep_attrs' +DISPLAY_WIDTH = "display_width" +ARITHMETIC_JOIN = "arithmetic_join" +ENABLE_CFTIMEINDEX = "enable_cftimeindex" +FILE_CACHE_MAXSIZE = "file_cache_maxsize" +WARN_FOR_UNCLOSED_FILES = "warn_for_unclosed_files" +CMAP_SEQUENTIAL = "cmap_sequential" +CMAP_DIVERGENT = "cmap_divergent" +KEEP_ATTRS = "keep_attrs" OPTIONS = { DISPLAY_WIDTH: 80, - ARITHMETIC_JOIN: 'inner', + ARITHMETIC_JOIN: "inner", ENABLE_CFTIMEINDEX: True, FILE_CACHE_MAXSIZE: 128, WARN_FOR_UNCLOSED_FILES: False, - CMAP_SEQUENTIAL: 'viridis', - CMAP_DIVERGENT: 'RdBu_r', - 
KEEP_ATTRS: 'default' + CMAP_SEQUENTIAL: "viridis", + CMAP_DIVERGENT: "RdBu_r", + KEEP_ATTRS: "default", } -_JOIN_OPTIONS = frozenset(['inner', 'outer', 'left', 'right', 'exact']) +_JOIN_OPTIONS = frozenset(["inner", "outer", "left", "right", "exact"]) def _positive_integer(value): @@ -34,39 +34,41 @@ def _positive_integer(value): ENABLE_CFTIMEINDEX: lambda value: isinstance(value, bool), FILE_CACHE_MAXSIZE: _positive_integer, WARN_FOR_UNCLOSED_FILES: lambda value: isinstance(value, bool), - KEEP_ATTRS: lambda choice: choice in [True, False, 'default'] + KEEP_ATTRS: lambda choice: choice in [True, False, "default"], } def _set_file_cache_maxsize(value): from ..backends.file_manager import FILE_CACHE + FILE_CACHE.maxsize = value def _warn_on_setting_enable_cftimeindex(enable_cftimeindex): warnings.warn( - 'The enable_cftimeindex option is now a no-op ' - 'and will be removed in a future version of xarray.', - FutureWarning) + "The enable_cftimeindex option is now a no-op " + "and will be removed in a future version of xarray.", + FutureWarning, + ) _SETTERS = { FILE_CACHE_MAXSIZE: _set_file_cache_maxsize, - ENABLE_CFTIMEINDEX: _warn_on_setting_enable_cftimeindex + ENABLE_CFTIMEINDEX: _warn_on_setting_enable_cftimeindex, } def _get_keep_attrs(default): - global_choice = OPTIONS['keep_attrs'] + global_choice = OPTIONS["keep_attrs"] - if global_choice == 'default': + if global_choice == "default": return default elif global_choice in [True, False]: return global_choice else: raise ValueError( - "The global option keep_attrs must be one of" - " True, False or 'default'.") + "The global option keep_attrs must be one of" " True, False or 'default'." + ) class set_options: @@ -119,11 +121,11 @@ def __init__(self, **kwargs): for k, v in kwargs.items(): if k not in OPTIONS: raise ValueError( - 'argument name %r is not in the set of valid options %r' - % (k, set(OPTIONS))) + "argument name %r is not in the set of valid options %r" + % (k, set(OPTIONS)) + ) if k in _VALIDATORS and not _VALIDATORS[k](v): - raise ValueError( - 'option %r given an invalid value: %r' % (k, v)) + raise ValueError("option %r given an invalid value: %r" % (k, v)) self.old[k] = OPTIONS[k] self._apply_update(kwargs) diff --git a/xarray/core/pdcompat.py b/xarray/core/pdcompat.py index 5b7d60a5329..654a43b505e 100644 --- a/xarray/core/pdcompat.py +++ b/xarray/core/pdcompat.py @@ -43,9 +43,10 @@ # allow ourselves to type checks for Panel even after it's removed -if LooseVersion(pd.__version__) < '0.25.0': +if LooseVersion(pd.__version__) < "0.25.0": Panel = pd.Panel else: + class Panel: # type: ignore pass diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index 259f44f2862..aaf52b9f295 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -1,10 +1,11 @@ import numpy as np -integer_types = (int, np.integer, ) +integer_types = (int, np.integer) try: # solely for isinstance checks import dask.array + dask_array_type = (dask.array.Array,) except ImportError: # pragma: no cover dask_array_type = () @@ -12,6 +13,7 @@ try: # solely for isinstance checks import sparse + sparse_array_type = (sparse.SparseArray,) except ImportError: # pragma: no cover sparse_array_type = () diff --git a/xarray/core/resample.py b/xarray/core/resample.py index b02bd11fe1a..de70ebb6950 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -1,7 +1,7 @@ from . 
import ops from .groupby import DEFAULT_DIMS, DataArrayGroupBy, DatasetGroupBy -RESAMPLE_DIM = '__resample_dim__' +RESAMPLE_DIM = "__resample_dim__" class Resample: @@ -49,27 +49,28 @@ def _upsample(self, method, *args, **kwargs): if self._dim in v.dims: self._obj = self._obj.drop(k) - if method == 'asfreq': + if method == "asfreq": return self.mean(self._dim) - elif method in ['pad', 'ffill', 'backfill', 'bfill', 'nearest']: + elif method in ["pad", "ffill", "backfill", "bfill", "nearest"]: kwargs = kwargs.copy() kwargs.update(**{self._dim: upsampled_index}) return self._obj.reindex(method=method, *args, **kwargs) - elif method == 'interpolate': + elif method == "interpolate": return self._interpolate(*args, **kwargs) else: - raise ValueError('Specified method was "{}" but must be one of' - '"asfreq", "ffill", "bfill", or "interpolate"' - .format(method)) + raise ValueError( + 'Specified method was "{}" but must be one of' + '"asfreq", "ffill", "bfill", or "interpolate"'.format(method) + ) def asfreq(self): """Return values of original object at the new up-sampling frequency; essentially a re-index with new times set to NaN. """ - return self._upsample('asfreq') + return self._upsample("asfreq") def pad(self, tolerance=None): """Forward fill new values at up-sampled frequency. @@ -84,7 +85,8 @@ def pad(self, tolerance=None): new values. Data with indices that are outside the given tolerance are filled with ``NaN`` s """ - return self._upsample('pad', tolerance=tolerance) + return self._upsample("pad", tolerance=tolerance) + ffill = pad def backfill(self, tolerance=None): @@ -100,7 +102,8 @@ def backfill(self, tolerance=None): new values. Data with indices that are outside the given tolerance are filled with ``NaN`` s """ - return self._upsample('backfill', tolerance=tolerance) + return self._upsample("backfill", tolerance=tolerance) + bfill = backfill def nearest(self, tolerance=None): @@ -117,9 +120,9 @@ def nearest(self, tolerance=None): new values. Data with indices that are outside the given tolerance are filled with ``NaN`` s """ - return self._upsample('nearest', tolerance=tolerance) + return self._upsample("nearest", tolerance=tolerance) - def interpolate(self, kind='linear'): + def interpolate(self, kind="linear"): """Interpolate up-sampled data using the original data as knots. @@ -136,7 +139,7 @@ def interpolate(self, kind='linear'): """ return self._interpolate(kind=kind) - def _interpolate(self, kind='linear'): + def _interpolate(self, kind="linear"): """Apply scipy.interpolate.interp1d along resampling dimension.""" # drop any existing non-dimension coordinates along the resampling # dimension @@ -144,9 +147,12 @@ def _interpolate(self, kind='linear'): for k, v in self._obj.coords.items(): if k != self._dim and self._dim in v.dims: dummy = dummy.drop(k) - return dummy.interp(assume_sorted=True, method=kind, - kwargs={'bounds_error': False}, - **{self._dim: self._full_index}) + return dummy.interp( + assume_sorted=True, + method=kind, + kwargs={"bounds_error": False}, + **{self._dim: self._full_index} + ) class DataArrayResample(DataArrayGroupBy, Resample): @@ -157,9 +163,11 @@ class DataArrayResample(DataArrayGroupBy, Resample): def __init__(self, *args, dim=None, resample_dim=None, **kwargs): if dim == resample_dim: - raise ValueError("Proxy resampling dimension ('{}') " - "cannot have the same name as actual dimension " - "('{}')! 
".format(resample_dim, dim)) + raise ValueError( + "Proxy resampling dimension ('{}') " + "cannot have the same name as actual dimension " + "('{}')! ".format(resample_dim, dim) + ) self._dim = dim self._resample_dim = resample_dim @@ -204,8 +212,7 @@ def apply(self, func, shortcut=False, args=(), **kwargs): applied : DataArray or DataArray The result of splitting, applying and combining this array. """ - combined = super().apply( - func, shortcut=shortcut, args=args, **kwargs) + combined = super().apply(func, shortcut=shortcut, args=args, **kwargs) # If the aggregation function didn't drop the original resampling # dimension, then we need to do so before we can rename the proxy @@ -230,9 +237,11 @@ class DatasetResample(DatasetGroupBy, Resample): def __init__(self, *args, dim=None, resample_dim=None, **kwargs): if dim == resample_dim: - raise ValueError("Proxy resampling dimension ('{}') " - "cannot have the same name as actual dimension " - "('{}')! ".format(resample_dim, dim)) + raise ValueError( + "Proxy resampling dimension ('{}') " + "cannot have the same name as actual dimension " + "('{}')! ".format(resample_dim, dim) + ) self._dim = dim self._resample_dim = resample_dim @@ -301,8 +310,7 @@ def reduce(self, func, dim=None, keep_attrs=None, **kwargs): if dim == DEFAULT_DIMS: dim = None - return super().reduce( - func, dim, keep_attrs, **kwargs) + return super().reduce(func, dim, keep_attrs, **kwargs) ops.inject_reduce_methods(DatasetResample) diff --git a/xarray/core/resample_cftime.py b/xarray/core/resample_cftime.py index cac78aabe98..cfac224363d 100644 --- a/xarray/core/resample_cftime.py +++ b/xarray/core/resample_cftime.py @@ -42,8 +42,15 @@ import pandas as pd from ..coding.cftime_offsets import ( - CFTIME_TICKS, Day, MonthEnd, QuarterEnd, YearEnd, cftime_range, - normalize_date, to_offset) + CFTIME_TICKS, + Day, + MonthEnd, + QuarterEnd, + YearEnd, + cftime_range, + normalize_date, + to_offset, +) from ..coding.cftimeindex import CFTimeIndex @@ -61,14 +68,14 @@ def __init__(self, freq, closed=None, label=None, base=0, loffset=None): if isinstance(self.freq, (MonthEnd, QuarterEnd, YearEnd)): if self.closed is None: - self.closed = 'right' + self.closed = "right" if self.label is None: - self.label = 'right' + self.label = "right" else: if self.closed is None: - self.closed = 'left' + self.closed = "left" if self.label is None: - self.label = 'left' + self.label = "left" def first_items(self, index): """Meant to reproduce the results of the following @@ -80,8 +87,9 @@ def first_items(self, index): with index being a CFTimeIndex instead of a DatetimeIndex. 
""" - datetime_bins, labels = _get_time_bins(index, self.freq, self.closed, - self.label, self.base) + datetime_bins, labels = _get_time_bins( + index, self.freq, self.closed, self.label, self.base + ) if self.loffset is not None: if isinstance(self.loffset, datetime.timedelta): labels = labels + self.loffset @@ -94,12 +102,11 @@ def first_items(self, index): if index[-1] > datetime_bins[-1]: raise ValueError("Value falls after last bin") - integer_bins = np.searchsorted( - index, datetime_bins, side=self.closed)[:-1] + integer_bins = np.searchsorted(index, datetime_bins, side=self.closed)[:-1] first_items = pd.Series(integer_bins, labels) # Mask duplicate values with NaNs, preserving the last values - non_duplicate = ~first_items.duplicated('last') + non_duplicate = ~first_items.duplicated("last") return first_items.where(non_duplicate) @@ -137,24 +144,26 @@ def _get_time_bins(index, freq, closed, label, base): """ if not isinstance(index, CFTimeIndex): - raise TypeError('index must be a CFTimeIndex, but got ' - 'an instance of %r' % type(index).__name__) + raise TypeError( + "index must be a CFTimeIndex, but got " + "an instance of %r" % type(index).__name__ + ) if len(index) == 0: datetime_bins = labels = CFTimeIndex(data=[], name=index.name) return datetime_bins, labels - first, last = _get_range_edges(index.min(), index.max(), freq, - closed=closed, - base=base) - datetime_bins = labels = cftime_range(freq=freq, - start=first, - end=last, - name=index.name) + first, last = _get_range_edges( + index.min(), index.max(), freq, closed=closed, base=base + ) + datetime_bins = labels = cftime_range( + freq=freq, start=first, end=last, name=index.name + ) - datetime_bins, labels = _adjust_bin_edges(datetime_bins, freq, closed, - index, labels) + datetime_bins, labels = _adjust_bin_edges( + datetime_bins, freq, closed, index, labels + ) - if label == 'right': + if label == "right": labels = labels[1:] else: labels = labels[:-1] @@ -201,12 +210,12 @@ def _adjust_bin_edges(datetime_bins, offset, closed, index, labels): This is also required for daily frequencies longer than one day and year-end frequencies. """ - is_super_daily = (isinstance(offset, (MonthEnd, QuarterEnd, YearEnd)) or - (isinstance(offset, Day) and offset.n > 1)) + is_super_daily = isinstance(offset, (MonthEnd, QuarterEnd, YearEnd)) or ( + isinstance(offset, Day) and offset.n > 1 + ) if is_super_daily: - if closed == 'right': - datetime_bins = datetime_bins + datetime.timedelta(days=1, - microseconds=-1) + if closed == "right": + datetime_bins = datetime_bins + datetime.timedelta(days=1, microseconds=-1) if datetime_bins[-2] > index.max(): datetime_bins = datetime_bins[:-1] labels = labels[:-1] @@ -214,7 +223,7 @@ def _adjust_bin_edges(datetime_bins, offset, closed, index, labels): return datetime_bins, labels -def _get_range_edges(first, last, offset, closed='left', base=0): +def _get_range_edges(first, last, offset, closed="left", base=0): """ Get the correct starting and ending datetimes for the resampled CFTimeIndex range. @@ -245,14 +254,15 @@ def _get_range_edges(first, last, offset, closed='left', base=0): Corrected ending datetime object for resampled CFTimeIndex range. 
""" if isinstance(offset, CFTIME_TICKS): - first, last = _adjust_dates_anchored(first, last, offset, - closed=closed, base=base) + first, last = _adjust_dates_anchored( + first, last, offset, closed=closed, base=base + ) return first, last else: first = normalize_date(first) last = normalize_date(last) - if closed == 'left': + if closed == "left": first = offset.rollback(first) else: first = first - offset @@ -261,7 +271,7 @@ def _get_range_edges(first, last, offset, closed='left', base=0): return first, last -def _adjust_dates_anchored(first, last, offset, closed='right', base=0): +def _adjust_dates_anchored(first, last, offset, closed="right", base=0): """ First and last offsets should be calculated from the start day to fix an error cause by resampling across multiple days when a one day period is not a multiple of the frequency. @@ -298,11 +308,9 @@ def _adjust_dates_anchored(first, last, offset, closed='right', base=0): start_day = normalize_date(first) base_td = type(offset)(n=base).as_timedelta() start_day += base_td - foffset = exact_cftime_datetime_difference( - start_day, first) % offset.as_timedelta() - loffset = exact_cftime_datetime_difference( - start_day, last) % offset.as_timedelta() - if closed == 'right': + foffset = exact_cftime_datetime_difference(start_day, first) % offset.as_timedelta() + loffset = exact_cftime_datetime_difference(start_day, last) % offset.as_timedelta() + if closed == "right": if foffset.total_seconds() > 0: fresult = first - foffset else: diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 3bfe1eaf49e..592cae9007e 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -43,7 +43,7 @@ class Rolling: DataArray.rolling """ - _attributes = ['window', 'min_periods', 'center', 'dim'] + _attributes = ["window", "min_periods", "center", "dim"] def __init__(self, obj, windows, min_periods=None, center=False): """ @@ -71,20 +71,23 @@ def __init__(self, obj, windows, min_periods=None, center=False): rolling : type of input argument """ - if (bottleneck is not None and - (LooseVersion(bottleneck.__version__) < LooseVersion('1.0'))): - warnings.warn('xarray requires bottleneck version of 1.0 or ' - 'greater for rolling operations. Rolling ' - 'aggregation methods will use numpy instead' - 'of bottleneck.') + if bottleneck is not None and ( + LooseVersion(bottleneck.__version__) < LooseVersion("1.0") + ): + warnings.warn( + "xarray requires bottleneck version of 1.0 or " + "greater for rolling operations. Rolling " + "aggregation methods will use numpy instead" + "of bottleneck." 
+ ) if len(windows) != 1: - raise ValueError('exactly one dim/window should be provided') + raise ValueError("exactly one dim/window should be provided") dim, window = next(iter(windows.items())) if window <= 0: - raise ValueError('window must be > 0') + raise ValueError("window must be > 0") self.obj = obj @@ -95,8 +98,7 @@ def __init__(self, obj, windows, min_periods=None, center=False): self._min_periods = window else: if min_periods <= 0: - raise ValueError( - 'min_periods must be greater than zero or None') + raise ValueError("min_periods must be greater than zero or None") self._min_periods = min_periods self.center = center @@ -105,42 +107,48 @@ def __init__(self, obj, windows, min_periods=None, center=False): def __repr__(self): """provide a nice str repr of our rolling object""" - attrs = ["{k}->{v}".format(k=k, v=getattr(self, k)) - for k in self._attributes - if getattr(self, k, None) is not None] - return "{klass} [{attrs}]".format(klass=self.__class__.__name__, - attrs=','.join(attrs)) + attrs = [ + "{k}->{v}".format(k=k, v=getattr(self, k)) + for k in self._attributes + if getattr(self, k, None) is not None + ] + return "{klass} [{attrs}]".format( + klass=self.__class__.__name__, attrs=",".join(attrs) + ) def __len__(self): return self.obj.sizes[self.dim] def _reduce_method(name): array_agg_func = getattr(duck_array_ops, name) - bottleneck_move_func = getattr(bottleneck, 'move_' + name, None) + bottleneck_move_func = getattr(bottleneck, "move_" + name, None) def method(self, **kwargs): return self._numpy_or_bottleneck_reduce( - array_agg_func, bottleneck_move_func, **kwargs) + array_agg_func, bottleneck_move_func, **kwargs + ) + method.__name__ = name method.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name=name) return method - argmax = _reduce_method('argmax') - argmin = _reduce_method('argmin') - max = _reduce_method('max') - min = _reduce_method('min') - mean = _reduce_method('mean') - prod = _reduce_method('prod') - sum = _reduce_method('sum') - std = _reduce_method('std') - var = _reduce_method('var') - median = _reduce_method('median') + argmax = _reduce_method("argmax") + argmin = _reduce_method("argmin") + max = _reduce_method("max") + min = _reduce_method("min") + mean = _reduce_method("mean") + prod = _reduce_method("prod") + sum = _reduce_method("sum") + std = _reduce_method("std") + var = _reduce_method("var") + median = _reduce_method("median") def count(self): rolling_count = self._counts() enough_periods = rolling_count >= self._min_periods return rolling_count.where(enough_periods) - count.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name='count') + + count.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name="count") class DataArrayRolling(Rolling): @@ -178,15 +186,14 @@ def __init__(self, obj, windows, min_periods=None, center=False): Dataset.rolling Dataset.groupby """ - super().__init__( - obj, windows, min_periods=min_periods, center=center) + super().__init__(obj, windows, min_periods=min_periods, center=center) self.window_labels = self.obj[self.dim] def __iter__(self): stops = np.arange(1, len(self.window_labels) + 1) starts = stops - int(self.window) - starts[:int(self.window)] = 0 + starts[: int(self.window)] = 0 for (label, start, stop) in zip(self.window_labels, starts, stops): window = self.obj.isel(**{self.dim: slice(start, stop)}) @@ -235,11 +242,12 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA): from .dataarray import DataArray - window = self.obj.variable.rolling_window(self.dim, self.window, - 
window_dim, self.center, - fill_value=fill_value) - result = DataArray(window, dims=self.obj.dims + (window_dim,), - coords=self.obj.coords) + window = self.obj.variable.rolling_window( + self.dim, self.window, window_dim, self.center, fill_value=fill_value + ) + result = DataArray( + window, dims=self.obj.dims + (window_dim,), coords=self.obj.coords + ) return result.isel(**{self.dim: slice(None, None, stride)}) def reduce(self, func, **kwargs): @@ -283,7 +291,7 @@ def reduce(self, func, **kwargs): array([[ 0., 1., 3., 6.], [ 4., 9., 15., 18.]]) """ - rolling_dim = utils.get_temp_dimname(self.obj.dims, '_rolling_dim') + rolling_dim = utils.get_temp_dimname(self.obj.dims, "_rolling_dim") windows = self.construct(rolling_dim) result = windows.reduce(func, dim=rolling_dim, **kwargs) @@ -294,15 +302,17 @@ def reduce(self, func, **kwargs): def _counts(self): """ Number of non-nan entries in each rolling window. """ - rolling_dim = utils.get_temp_dimname(self.obj.dims, '_rolling_dim') + rolling_dim = utils.get_temp_dimname(self.obj.dims, "_rolling_dim") # We use False as the fill_value instead of np.nan, since boolean # array is faster to be reduced than object array. # The use of skipna==False is also faster since it does not need to # copy the strided array. - counts = (self.obj.notnull() - .rolling(center=self.center, **{self.dim: self.window}) - .construct(rolling_dim, fill_value=False) - .sum(dim=rolling_dim, skipna=False)) + counts = ( + self.obj.notnull() + .rolling(center=self.center, **{self.dim: self.window}) + .construct(rolling_dim, fill_value=False) + .sum(dim=rolling_dim, skipna=False) + ) return counts def _bottleneck_reduce(self, func, **kwargs): @@ -319,8 +329,10 @@ def _bottleneck_reduce(self, func, **kwargs): padded = self.obj.variable if self.center: - if (LooseVersion(np.__version__) < LooseVersion('1.13') and - self.obj.dtype.kind == 'b'): + if ( + LooseVersion(np.__version__) < LooseVersion("1.13") + and self.obj.dtype.kind == "b" + ): # with numpy < 1.13 bottleneck cannot handle np.nan-Boolean # mixed array correctly. We cast boolean array to float. 
padded = padded.astype(float) @@ -328,24 +340,25 @@ def _bottleneck_reduce(self, func, **kwargs): if isinstance(padded.data, dask_array_type): # Workaround to make the padded chunk size is larger than # self.window-1 - shift = - (self.window + 1) // 2 + shift = -(self.window + 1) // 2 offset = (self.window - 1) // 2 - valid = (slice(None), ) * axis + ( - slice(offset, offset + self.obj.shape[axis]), ) + valid = (slice(None),) * axis + ( + slice(offset, offset + self.obj.shape[axis]), + ) else: shift = (-self.window // 2) + 1 - valid = (slice(None), ) * axis + (slice(-shift, None), ) + valid = (slice(None),) * axis + (slice(-shift, None),) padded = padded.pad_with_fill_value({self.dim: (0, -shift)}) if isinstance(padded.data, dask_array_type): - raise AssertionError('should not be reachable') - values = dask_rolling_wrapper(func, padded.data, - window=self.window, - min_count=min_count, - axis=axis) + raise AssertionError("should not be reachable") + values = dask_rolling_wrapper( + func, padded.data, window=self.window, min_count=min_count, axis=axis + ) else: - values = func(padded.data, window=self.window, - min_count=min_count, axis=axis) + values = func( + padded.data, window=self.window, min_count=min_count, axis=axis + ) if self.center: values = values[valid] @@ -356,8 +369,9 @@ def _bottleneck_reduce(self, func, **kwargs): def _numpy_or_bottleneck_reduce( self, array_agg_func, bottleneck_move_func, **kwargs ): - if (bottleneck_move_func is not None and - not isinstance(self.obj.data, dask_array_type)): + if bottleneck_move_func is not None and not isinstance( + self.obj.data, dask_array_type + ): # TODO: renable bottleneck with dask after the issues # underlying https://github.com/pydata/xarray/issues/2940 are # fixed. @@ -409,11 +423,11 @@ def __init__(self, obj, windows, min_periods=None, center=False): for key, da in self.obj.data_vars.items(): # keeps rollings only for the dataset depending on slf.dim if self.dim in da.dims: - self.rollings[key] = DataArrayRolling( - da, windows, min_periods, center) + self.rollings[key] = DataArrayRolling(da, windows, min_periods, center) def _dataset_implementation(self, func, **kwargs): from .dataset import Dataset + reduced = OrderedDict() for key, da in self.obj.data_vars.items(): if self.dim in da.dims: @@ -441,7 +455,8 @@ def reduce(self, func, **kwargs): Array with summarized data. 
""" return self._dataset_implementation( - functools.partial(DataArrayRolling.reduce, func=func), **kwargs) + functools.partial(DataArrayRolling.reduce, func=func), **kwargs + ) def _counts(self): return self._dataset_implementation(DataArrayRolling._counts) @@ -450,10 +465,13 @@ def _numpy_or_bottleneck_reduce( self, array_agg_func, bottleneck_move_func, **kwargs ): return self._dataset_implementation( - functools.partial(DataArrayRolling._numpy_or_bottleneck_reduce, - array_agg_func=array_agg_func, - bottleneck_move_func=bottleneck_move_func), - **kwargs) + functools.partial( + DataArrayRolling._numpy_or_bottleneck_reduce, + array_agg_func=array_agg_func, + bottleneck_move_func=bottleneck_move_func, + ), + **kwargs + ) def construct(self, window_dim, stride=1, fill_value=dtypes.NA): """ @@ -480,11 +498,13 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA): for key, da in self.obj.data_vars.items(): if self.dim in da.dims: dataset[key] = self.rollings[key].construct( - window_dim, fill_value=fill_value) + window_dim, fill_value=fill_value + ) else: dataset[key] = da return Dataset(dataset, coords=self.obj.coords).isel( - **{self.dim: slice(None, None, stride)}) + **{self.dim: slice(None, None, stride)} + ) class Coarsen: @@ -496,7 +516,7 @@ class Coarsen: DataArray.coarsen """ - _attributes = ['windows', 'side', 'trim_excess'] + _attributes = ["windows", "side", "trim_excess"] def __init__(self, obj, windows, boundary, side, coord_func): """ @@ -538,11 +558,14 @@ def __init__(self, obj, windows, boundary, side, coord_func): def __repr__(self): """provide a nice str repr of our coarsen object""" - attrs = ["{k}->{v}".format(k=k, v=getattr(self, k)) - for k in self._attributes - if getattr(self, k, None) is not None] - return "{klass} [{attrs}]".format(klass=self.__class__.__name__, - attrs=','.join(attrs)) + attrs = [ + "{k}->{v}".format(k=k, v=getattr(self, k)) + for k in self._attributes + if getattr(self, k, None) is not None + ] + return "{klass} [{attrs}]".format( + klass=self.__class__.__name__, attrs=",".join(attrs) + ) class DataArrayCoarsen(Coarsen): @@ -552,11 +575,13 @@ def _reduce_method(cls, func): Return a wrapped function for injecting numpy methods. see ops.inject_coarsen_methods """ + def wrapped_func(self, **kwargs): from .dataarray import DataArray reduced = self.obj.variable.coarsen( - self.windows, func, self.boundary, self.side) + self.windows, func, self.boundary, self.side + ) coords = {} for c, v in self.obj.coords.items(): if c == self.obj.name: @@ -564,8 +589,8 @@ def wrapped_func(self, **kwargs): else: if any(d in self.windows for d in v.dims): coords[c] = v.variable.coarsen( - self.windows, self.coord_func[c], - self.boundary, self.side) + self.windows, self.coord_func[c], self.boundary, self.side + ) else: coords[c] = v return DataArray(reduced, dims=self.obj.dims, coords=coords) @@ -580,20 +605,22 @@ def _reduce_method(cls, func): Return a wrapped function for injecting numpy methods. 
see ops.inject_coarsen_methods """ + def wrapped_func(self, **kwargs): from .dataset import Dataset reduced = OrderedDict() for key, da in self.obj.data_vars.items(): reduced[key] = da.variable.coarsen( - self.windows, func, self.boundary, self.side) + self.windows, func, self.boundary, self.side + ) coords = {} for c, v in self.obj.coords.items(): if any(d in self.windows for d in v.dims): coords[c] = v.variable.coarsen( - self.windows, self.coord_func[c], - self.boundary, self.side) + self.windows, self.coord_func[c], self.boundary, self.side + ) else: coords[c] = v.variable return Dataset(reduced, coords=coords) diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index ff6baef5c3a..057884fef85 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -15,11 +15,11 @@ def move_exp_nanmean(array, *, axis, alpha): if isinstance(array, dask_array_type): raise TypeError("rolling_exp is not currently support for dask arrays") import numbagg + if axis == (): return array.astype(np.float64) else: - return numbagg.move_exp_nanmean( - array, axis=axis, alpha=alpha) + return numbagg.move_exp_nanmean(array, axis=axis, alpha=alpha) def _get_center_of_mass(comass, span, halflife, alpha): @@ -29,10 +29,10 @@ def _get_center_of_mass(comass, span, halflife, alpha): See licenses/PANDAS_LICENSE for the function's license """ from pandas.core import common as com + valid_count = com.count_not_none(comass, span, halflife, alpha) if valid_count > 1: - raise ValueError("comass, span, halflife, and alpha " - "are mutually exclusive") + raise ValueError("comass, span, halflife, and alpha " "are mutually exclusive") # Convert to center of mass; domain checks ensure 0 < alpha <= 1 if comass is not None: @@ -41,7 +41,7 @@ def _get_center_of_mass(comass, span, halflife, alpha): elif span is not None: if span < 1: raise ValueError("span must satisfy: span >= 1") - comass = (span - 1) / 2. 
+ comass = (span - 1) / 2.0 elif halflife is not None: if halflife <= 0: raise ValueError("halflife must satisfy: halflife > 0") @@ -83,7 +83,7 @@ class RollingExp: RollingExp : type of input argument """ # noqa - def __init__(self, obj, windows, window_type='span'): + def __init__(self, obj, windows, window_type="span"): self.obj = obj dim, window = next(iter(windows.items())) self.dim = dim @@ -102,5 +102,4 @@ def mean(self): Dimensions without coordinates: x """ - return self.obj.reduce( - move_exp_nanmean, dim=self.dim, alpha=self.alpha) + return self.obj.reduce(move_exp_nanmean, dim=self.dim, alpha=self.alpha) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index b3e19aebcbf..4541049d7e1 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -8,9 +8,23 @@ import warnings from collections import OrderedDict from typing import ( - AbstractSet, Any, Callable, Container, Dict, Hashable, Iterable, Iterator, - Mapping, MutableMapping, MutableSet, Optional, Sequence, Tuple, TypeVar, - cast) + AbstractSet, + Any, + Callable, + Container, + Dict, + Hashable, + Iterable, + Iterator, + Mapping, + MutableMapping, + MutableSet, + Optional, + Sequence, + Tuple, + TypeVar, + cast, +) import numpy as np import pandas as pd @@ -18,29 +32,33 @@ from .pycompat import dask_array_type -K = TypeVar('K') -V = TypeVar('V') -T = TypeVar('T') +K = TypeVar("K") +V = TypeVar("V") +T = TypeVar("T") def _check_inplace(inplace: Optional[bool], default: bool = False) -> bool: if inplace is None: inplace = default else: - warnings.warn('The inplace argument has been deprecated and will be ' - 'removed in a future version of xarray.', - FutureWarning, stacklevel=3) + warnings.warn( + "The inplace argument has been deprecated and will be " + "removed in a future version of xarray.", + FutureWarning, + stacklevel=3, + ) return inplace def alias_message(old_name: str, new_name: str) -> str: - return '%s has been deprecated. Use %s instead.' % (old_name, new_name) + return "%s has been deprecated. Use %s instead." 
% (old_name, new_name) def alias_warning(old_name: str, new_name: str, stacklevel: int = 3) -> None: - warnings.warn(alias_message(old_name, new_name), FutureWarning, - stacklevel=stacklevel) + warnings.warn( + alias_message(old_name, new_name), FutureWarning, stacklevel=stacklevel + ) def alias(obj: Callable[..., T], old_name: str) -> Callable[..., T]: @@ -50,6 +68,7 @@ def alias(obj: Callable[..., T], old_name: str) -> Callable[..., T]: def wrapper(*args, **kwargs): alias_warning(old_name, obj.__name__) return obj(*args, **kwargs) + wrapper.__doc__ = alias_message(old_name, obj.__name__) return wrapper @@ -57,7 +76,7 @@ def wrapper(*args, **kwargs): def _maybe_cast_to_cftimeindex(index: pd.Index) -> pd.Index: from ..coding.cftimeindex import CFTimeIndex - if len(index) > 0 and index.dtype == 'O': + if len(index) > 0 and index.dtype == "O": try: return CFTimeIndex(index) except (ImportError, TypeError): @@ -77,19 +96,19 @@ def safe_cast_to_index(array: Any) -> pd.Index: """ if isinstance(array, pd.Index): index = array - elif hasattr(array, 'to_index'): + elif hasattr(array, "to_index"): index = array.to_index() else: kwargs = {} - if hasattr(array, 'dtype') and array.dtype.kind == 'O': - kwargs['dtype'] = object + if hasattr(array, "dtype") and array.dtype.kind == "O": + kwargs["dtype"] = object index = pd.Index(np.asarray(array), **kwargs) return _maybe_cast_to_cftimeindex(index) -def multiindex_from_product_levels(levels: Sequence[pd.Index], - names: Sequence[str] = None - ) -> pd.MultiIndex: +def multiindex_from_product_levels( + levels: Sequence[pd.Index], names: Sequence[str] = None +) -> pd.MultiIndex: """Creating a MultiIndex from a product without refactorizing levels. Keeping levels the same gives back the original labels when we unstack. @@ -106,10 +125,10 @@ def multiindex_from_product_levels(levels: Sequence[pd.Index], pandas.MultiIndex """ if any(not isinstance(lev, pd.Index) for lev in levels): - raise TypeError('levels must be a list of pd.Index objects') + raise TypeError("levels must be a list of pd.Index objects") split_labels, levels = zip(*[lev.factorize() for lev in levels]) - labels_mesh = np.meshgrid(*split_labels, indexing='ij') + labels_mesh = np.meshgrid(*split_labels, indexing="ij") labels = [x.ravel() for x in labels_mesh] return pd.MultiIndex(levels, labels, sortorder=0, names=names) @@ -134,14 +153,17 @@ def equivalent(first: T, second: T) -> bool: """ # TODO: refactor to avoid circular import from . import duck_array_ops + if isinstance(first, np.ndarray) or isinstance(second, np.ndarray): return duck_array_ops.array_equiv(first, second) elif isinstance(first, list) or isinstance(second, list): return list_equiv(first, second) else: - return ((first is second) or - (first == second) or - (pd.isnull(first) and pd.isnull(second))) + return ( + (first is second) + or (first == second) + or (pd.isnull(first) and pd.isnull(second)) + ) def list_equiv(first, second): @@ -163,9 +185,11 @@ def peek_at(iterable: Iterable[T]) -> Tuple[T, Iterator[T]]: return peek, itertools.chain([peek], gen) -def update_safety_check(first_dict: MutableMapping[K, V], - second_dict: Mapping[K, V], - compat: Callable[[V, V], bool] = equivalent) -> None: +def update_safety_check( + first_dict: MutableMapping[K, V], + second_dict: Mapping[K, V], + compat: Callable[[V, V], bool] = equivalent, +) -> None: """Check the safety of updating one dictionary with another. 
Raises ValueError if dictionaries have non-compatible values for any key, @@ -183,14 +207,17 @@ def update_safety_check(first_dict: MutableMapping[K, V], """ for k, v in second_dict.items(): if k in first_dict and not compat(v, first_dict[k]): - raise ValueError('unsafe to merge dictionaries without ' - 'overriding values; conflicting key %r' % k) + raise ValueError( + "unsafe to merge dictionaries without " + "overriding values; conflicting key %r" % k + ) -def remove_incompatible_items(first_dict: MutableMapping[K, V], - second_dict: Mapping[K, V], - compat: Callable[[V, V], bool] = equivalent - ) -> None: +def remove_incompatible_items( + first_dict: MutableMapping[K, V], + second_dict: Mapping[K, V], + compat: Callable[[V, V], bool] = equivalent, +) -> None: """Remove incompatible items from the first dictionary in-place. Items are retained if their keys are found in both dictionaries and the @@ -210,24 +237,28 @@ def remove_incompatible_items(first_dict: MutableMapping[K, V], def is_dict_like(value: Any) -> bool: - return hasattr(value, 'keys') and hasattr(value, '__getitem__') + return hasattr(value, "keys") and hasattr(value, "__getitem__") def is_full_slice(value: Any) -> bool: return isinstance(value, slice) and value == slice(None) -def either_dict_or_kwargs(pos_kwargs: Optional[Mapping[Hashable, T]], - kw_kwargs: Mapping[str, T], - func_name: str - ) -> Mapping[Hashable, T]: +def either_dict_or_kwargs( + pos_kwargs: Optional[Mapping[Hashable, T]], + kw_kwargs: Mapping[str, T], + func_name: str, +) -> Mapping[Hashable, T]: if pos_kwargs is not None: if not is_dict_like(pos_kwargs): - raise ValueError('the first argument to .%s must be a dictionary' - % func_name) + raise ValueError( + "the first argument to .%s must be a dictionary" % func_name + ) if kw_kwargs: - raise ValueError('cannot specify both keyword and positional ' - 'arguments to .%s' % func_name) + raise ValueError( + "cannot specify both keyword and positional " + "arguments to .%s" % func_name + ) return pos_kwargs else: # Need an explicit cast to appease mypy due to invariance; see @@ -241,10 +272,12 @@ def is_scalar(value: Any) -> bool: Any non-iterable, string, or 0-D array """ return ( - getattr(value, 'ndim', None) == 0 or - isinstance(value, (str, bytes)) or not - (isinstance(value, (Iterable, ) + dask_array_type) or - hasattr(value, '__array_function__')) + getattr(value, "ndim", None) == 0 + or isinstance(value, (str, bytes)) + or not ( + isinstance(value, (Iterable,) + dask_array_type) + or hasattr(value, "__array_function__") + ) ) @@ -268,15 +301,17 @@ def to_0d_object_array(value: Any) -> np.ndarray: def to_0d_array(value: Any) -> np.ndarray: """Given a value, wrap it in a 0-D numpy.ndarray. """ - if np.isscalar(value) or (isinstance(value, np.ndarray) and - value.ndim == 0): + if np.isscalar(value) or (isinstance(value, np.ndarray) and value.ndim == 0): return np.array(value) else: return to_0d_object_array(value) -def dict_equiv(first: Mapping[K, V], second: Mapping[K, V], - compat: Callable[[V, V], bool] = equivalent) -> bool: +def dict_equiv( + first: Mapping[K, V], + second: Mapping[K, V], + compat: Callable[[V, V], bool] = equivalent, +) -> bool: """Test equivalence of two dict-like objects. If any of the values are numpy arrays, compare them correctly. 
@@ -302,10 +337,11 @@ def dict_equiv(first: Mapping[K, V], second: Mapping[K, V], return True -def ordered_dict_intersection(first_dict: Mapping[K, V], - second_dict: Mapping[K, V], - compat: Callable[[V, V], bool] = equivalent - ) -> MutableMapping[K, V]: +def ordered_dict_intersection( + first_dict: Mapping[K, V], + second_dict: Mapping[K, V], + compat: Callable[[V, V], bool] = equivalent, +) -> MutableMapping[K, V]: """Return the intersection of two dictionaries as a new OrderedDict. Items are retained if their keys are found in both dictionaries and the @@ -334,7 +370,8 @@ class Frozen(Mapping[K, V]): immutable. If you really want to modify the mapping, the mutable version is saved under the `mapping` attribute. """ - __slots__ = ['mapping'] + + __slots__ = ["mapping"] def __init__(self, mapping: Mapping[K, V]): self.mapping = mapping @@ -352,7 +389,7 @@ def __contains__(self, key: object) -> bool: return key in self.mapping def __repr__(self) -> str: - return '%s(%r)' % (type(self).__name__, self.mapping) + return "%s(%r)" % (type(self).__name__, self.mapping) def FrozenOrderedDict(*args, **kwargs) -> Frozen: @@ -364,7 +401,8 @@ class SortedKeysDict(MutableMapping[K, V]): items in sorted order by key but is otherwise equivalent to the underlying mapping. """ - __slots__ = ['mapping'] + + __slots__ = ["mapping"] def __init__(self, mapping: MutableMapping[K, V] = None): self.mapping = {} if mapping is None else mapping @@ -388,7 +426,7 @@ def __contains__(self, key: object) -> bool: return key in self.mapping def __repr__(self) -> str: - return '%s(%r)' % (type(self).__name__, self.mapping) + return "%s(%r)" % (type(self).__name__, self.mapping) class OrderedSet(MutableSet[T]): @@ -397,6 +435,7 @@ class OrderedSet(MutableSet[T]): The API matches the builtin set, but it preserves insertion order of elements, like an OrderedDict. """ + def __init__(self, values: AbstractSet[T] = None): self._ordered_dict = OrderedDict() # type: MutableMapping[T, None] if values is not None: @@ -429,13 +468,14 @@ def update(self, values: AbstractSet[T]) -> None: self |= values # type: ignore def __repr__(self) -> str: - return '%s(%r)' % (type(self).__name__, list(self)) + return "%s(%r)" % (type(self).__name__, list(self)) class NdimSizeLenMixin: """Mixin class that extends a class that defines a ``shape`` property to one that also defines ``ndim``, ``size`` and ``__len__``. """ + @property def ndim(self: Any) -> int: return len(self.shape) @@ -449,7 +489,7 @@ def __len__(self: Any) -> int: try: return self.shape[0] except IndexError: - raise TypeError('len() of unsized object') + raise TypeError("len() of unsized object") class NDArrayMixin(NdimSizeLenMixin): @@ -459,6 +499,7 @@ class NDArrayMixin(NdimSizeLenMixin): A subclass should set the `array` property and override one or more of `dtype`, `shape` and `__getitem__`. """ + @property def dtype(self: Any) -> np.dtype: return self.array.dtype @@ -471,13 +512,14 @@ def __getitem__(self: Any, key): return self.array[key] def __repr__(self: Any) -> str: - return '%s(array=%r)' % (type(self).__name__, self.array) + return "%s(array=%r)" % (type(self).__name__, self.array) class ReprObject: """Object that prints as the given value, for use with sentinel values. 
""" - __slots__ = ('_value', ) + + __slots__ = ("_value",) def __init__(self, value: str): self._value = value @@ -507,12 +549,12 @@ def close_on_error(f): def is_remote_uri(path: str) -> bool: - return bool(re.search(r'^https?\://', path)) + return bool(re.search(r"^https?\://", path)) def is_grib_path(path: str) -> bool: _, ext = os.path.splitext(path) - return ext in ['.grib', '.grb', '.grib2', '.grb2'] + return ext in [".grib", ".grb", ".grib2", ".grb2"] def is_uniform_spaced(arr, **kwargs) -> bool: @@ -561,15 +603,16 @@ def ensure_us_time_resolution(val): """Convert val out of numpy time, for use in to_dict. Needed because of numpy bug GH#7619""" if np.issubdtype(val.dtype, np.datetime64): - val = val.astype('datetime64[us]') + val = val.astype("datetime64[us]") elif np.issubdtype(val.dtype, np.timedelta64): - val = val.astype('timedelta64[us]') + val = val.astype("timedelta64[us]") return val class HiddenKeyDict(MutableMapping[K, V]): """Acts like a normal dictionary, but hides certain keys. """ + # ``__init__`` method required to create instance from class. def __init__(self, data: MutableMapping[K, V], hidden_keys: Iterable[K]): @@ -578,7 +621,7 @@ def __init__(self, data: MutableMapping[K, V], hidden_keys: Iterable[K]): def _raise_if_hidden(self, key: K) -> None: if key in self._hidden_keys: - raise KeyError('Key `%r` is hidden.' % key) + raise KeyError("Key `%r` is hidden." % key) # The next five methods are requirements of the ABC. def __setitem__(self, key: K, value: V) -> None: @@ -617,5 +660,5 @@ def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable: -> ['__rolling'] """ while new_dim in dims: - new_dim = '_' + str(new_dim) + new_dim = "_" + str(new_dim) return new_dim diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 85f26d85cd4..41f2a64ed55 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -10,17 +10,23 @@ import xarray as xr # only for Dataset and DataArray -from . import ( - arithmetic, common, dtypes, duck_array_ops, indexing, nputils, ops, utils) +from . import arithmetic, common, dtypes, duck_array_ops, indexing, nputils, ops, utils from .indexing import ( - BasicIndexer, OuterIndexer, PandasIndexAdapter, VectorizedIndexer, - as_indexable) + BasicIndexer, + OuterIndexer, + PandasIndexAdapter, + VectorizedIndexer, + as_indexable, +) from .options import _get_keep_attrs from .pycompat import dask_array_type, integer_types from .npcompat import IS_NEP18_ACTIVE from .utils import ( - OrderedSet, decode_numpy_dict_values, either_dict_or_kwargs, - ensure_us_time_resolution) + OrderedSet, + decode_numpy_dict_values, + either_dict_or_kwargs, + ensure_us_time_resolution, +) try: import dask.array as da @@ -29,7 +35,9 @@ NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( - indexing.ExplicitlyIndexed, pd.Index) + dask_array_type + indexing.ExplicitlyIndexed, + pd.Index, +) + dask_array_type # https://github.com/python/mypy/issues/224 BASIC_INDEXING_TYPES = integer_types + (slice,) # type: ignore @@ -37,11 +45,12 @@ class MissingDimensionsError(ValueError): """Error class used when we can't safely guess a dimension name. """ + # inherits from ValueError for backward compatibility # TODO: move this to an xarray.exceptions module? -def as_variable(obj, name=None) -> 'Union[Variable, IndexVariable]': +def as_variable(obj, name=None) -> "Union[Variable, IndexVariable]": """Convert an object into a Variable. 
Parameters @@ -82,36 +91,41 @@ def as_variable(obj, name=None) -> 'Union[Variable, IndexVariable]': obj = Variable(*obj) except (TypeError, ValueError) as error: # use .format() instead of % because it handles tuples consistently - raise error.__class__('Could not convert tuple of form ' - '(dims, data[, attrs, encoding]): ' - '{} to Variable.'.format(obj)) + raise error.__class__( + "Could not convert tuple of form " + "(dims, data[, attrs, encoding]): " + "{} to Variable.".format(obj) + ) elif utils.is_scalar(obj): obj = Variable([], obj) elif isinstance(obj, (pd.Index, IndexVariable)) and obj.name is not None: obj = Variable(obj.name, obj) elif isinstance(obj, (set, dict)): - raise TypeError( - "variable %r has invalid type %r" % (name, type(obj))) + raise TypeError("variable %r has invalid type %r" % (name, type(obj))) elif name is not None: data = as_compatible_data(obj) if data.ndim != 1: raise MissingDimensionsError( - 'cannot set variable %r with %r-dimensional data ' - 'without explicit dimension names. Pass a tuple of ' - '(dims, data) instead.' % (name, data.ndim)) + "cannot set variable %r with %r-dimensional data " + "without explicit dimension names. Pass a tuple of " + "(dims, data) instead." % (name, data.ndim) + ) obj = Variable(name, data, fastpath=True) else: - raise TypeError('unable to convert object into a variable without an ' - 'explicit list of dimensions: %r' % obj) + raise TypeError( + "unable to convert object into a variable without an " + "explicit list of dimensions: %r" % obj + ) if name is not None and name in obj.dims: # convert the Variable into an Index if obj.ndim != 1: raise MissingDimensionsError( - '%r has more than 1-dimension and the same name as one of its ' - 'dimensions %r. xarray disallows such variables because they ' - 'conflict with the coordinates used to label ' - 'dimensions.' % (name, obj.dims)) + "%r has more than 1-dimension and the same name as one of its " + "dimensions %r. xarray disallows such variables because they " + "conflict with the coordinates used to label " + "dimensions." % (name, obj.dims) + ) obj = obj.to_index_variable() return obj @@ -148,7 +162,7 @@ def as_compatible_data(data, fastpath=False): Finally, wrap it up with an adapter if necessary. """ - if fastpath and getattr(data, 'ndim', 0) > 0: + if fastpath and getattr(data, "ndim", 0) > 0: # can't use fastpath (yet) for scalars return _maybe_wrap_data(data) @@ -163,13 +177,13 @@ def as_compatible_data(data, fastpath=False): if isinstance(data, pd.Timestamp): # TODO: convert, handle datetime objects, too - data = np.datetime64(data.value, 'ns') + data = np.datetime64(data.value, "ns") if isinstance(data, timedelta): - data = np.timedelta64(getattr(data, 'value', data), 'ns') + data = np.timedelta64(getattr(data, "value", data), "ns") # we don't want nested self-described arrays - data = getattr(data, 'values', data) + data = getattr(data, "values", data) if isinstance(data, np.ma.MaskedArray): mask = np.ma.getmaskarray(data) @@ -181,27 +195,28 @@ def as_compatible_data(data, fastpath=False): data = np.asarray(data) if not isinstance(data, np.ndarray): - if hasattr(data, '__array_function__'): + if hasattr(data, "__array_function__"): if IS_NEP18_ACTIVE: return data else: raise TypeError( - 'Got an NumPy-like array type providing the ' - '__array_function__ protocol but NEP18 is not enabled. ' - 'Check that numpy >= v1.16 and that the environment ' + "Got an NumPy-like array type providing the " + "__array_function__ protocol but NEP18 is not enabled. 
" + "Check that numpy >= v1.16 and that the environment " 'variable "NUMPY_EXPERIMENTAL_ARRAY_FUNCTION" is set to ' - '"1"') + '"1"' + ) # validate whether the data is valid data types data = np.asarray(data) if isinstance(data, np.ndarray): - if data.dtype.kind == 'O': + if data.dtype.kind == "O": data = _possibly_convert_objects(data) - elif data.dtype.kind == 'M': - data = np.asarray(data, 'datetime64[ns]') - elif data.dtype.kind == 'm': - data = np.asarray(data, 'timedelta64[ns]') + elif data.dtype.kind == "M": + data = np.asarray(data, "datetime64[ns]") + elif data.dtype.kind == "m": + data = np.asarray(data, "timedelta64[ns]") return _maybe_wrap_data(data) @@ -222,15 +237,16 @@ def _as_array_or_item(data): """ data = np.asarray(data) if data.ndim == 0: - if data.dtype.kind == 'M': - data = np.datetime64(data, 'ns') - elif data.dtype.kind == 'm': - data = np.timedelta64(data, 'ns') + if data.dtype.kind == "M": + data = np.datetime64(data, "ns") + elif data.dtype.kind == "m": + data = np.timedelta64(data, "ns") return data -class Variable(common.AbstractArray, arithmetic.SupportsArithmetic, - utils.NdimSizeLenMixin): +class Variable( + common.AbstractArray, arithmetic.SupportsArithmetic, utils.NdimSizeLenMixin +): """A netcdf-like variable consisting of dimensions, data and attributes which describe a single Array. A single Variable object is not fully described outside the context of its parent Dataset (if you want such a @@ -294,14 +310,14 @@ def nbytes(self): @property def _in_memory(self): - return (isinstance(self._data, (np.ndarray, np.number, - PandasIndexAdapter)) or - (isinstance(self._data, indexing.MemoryCachedArray) and - isinstance(self._data.array, indexing.NumpyIndexingAdapter))) + return isinstance(self._data, (np.ndarray, np.number, PandasIndexAdapter)) or ( + isinstance(self._data, indexing.MemoryCachedArray) + and isinstance(self._data.array, indexing.NumpyIndexingAdapter) + ) @property def data(self): - if hasattr(self._data, '__array_function__'): + if hasattr(self._data, "__array_function__"): return self._data else: return self.values @@ -310,8 +326,7 @@ def data(self): def data(self, data): data = as_compatible_data(data) if data.shape != self.shape: - raise ValueError( - "replacement data must match the Variable's shape") + raise ValueError("replacement data must match the Variable's shape") self._data = data def load(self, **kwargs): @@ -333,7 +348,7 @@ def load(self, **kwargs): """ if isinstance(self._data, dask_array_type): self._data = as_compatible_data(self._data.compute(**kwargs)) - elif not hasattr(self._data, '__array_function__'): + elif not hasattr(self._data, "__array_function__"): self._data = np.asarray(self._data) return self @@ -380,13 +395,17 @@ def __dask_scheduler__(self): def __dask_postcompute__(self): array_func, array_args = self._data.__dask_postcompute__() - return self._dask_finalize, (array_func, array_args, self._dims, - self._attrs, self._encoding) + return ( + self._dask_finalize, + (array_func, array_args, self._dims, self._attrs, self._encoding), + ) def __dask_postpersist__(self): array_func, array_args = self._data.__dask_postpersist__() - return self._dask_finalize, (array_func, array_args, self._dims, - self._attrs, self._encoding) + return ( + self._dask_finalize, + (array_func, array_args, self._dims, self._attrs, self._encoding), + ) @staticmethod def _dask_finalize(results, array_func, array_args, dims, attrs, encoding): @@ -407,17 +426,19 @@ def values(self, values): def to_base_variable(self): """Return this variable as 
a base xarray.Variable""" - return Variable(self.dims, self._data, self._attrs, - encoding=self._encoding, fastpath=True) + return Variable( + self.dims, self._data, self._attrs, encoding=self._encoding, fastpath=True + ) - to_variable = utils.alias(to_base_variable, 'to_variable') + to_variable = utils.alias(to_base_variable, "to_variable") def to_index_variable(self): """Return this variable as an xarray.IndexVariable""" - return IndexVariable(self.dims, self._data, self._attrs, - encoding=self._encoding, fastpath=True) + return IndexVariable( + self.dims, self._data, self._attrs, encoding=self._encoding, fastpath=True + ) - to_coord = utils.alias(to_index_variable, 'to_coord') + to_coord = utils.alias(to_index_variable, "to_coord") def to_index(self): """Convert this variable to a pandas.Index""" @@ -425,12 +446,11 @@ def to_index(self): def to_dict(self, data=True): """Dictionary representation of variable.""" - item = {'dims': self.dims, - 'attrs': decode_numpy_dict_values(self.attrs)} + item = {"dims": self.dims, "attrs": decode_numpy_dict_values(self.attrs)} if data: - item['data'] = ensure_us_time_resolution(self.values).tolist() + item["data"] = ensure_us_time_resolution(self.values).tolist() else: - item.update({'dtype': str(self.dtype), 'shape': self.shape}) + item.update({"dtype": str(self.dtype), "shape": self.shape}) return item @property @@ -448,9 +468,10 @@ def _parse_dimensions(self, dims): dims = (dims,) dims = tuple(dims) if len(dims) != self.ndim: - raise ValueError('dimensions %s must have the same length as the ' - 'number of data dimensions, ndim=%s' - % (dims, self.ndim)) + raise ValueError( + "dimensions %s must have the same length as the " + "number of data dimensions, ndim=%s" % (dims, self.ndim) + ) return dims def _item_key_to_tuple(self, key): @@ -485,12 +506,12 @@ def _broadcast_indexes(self, key): key = indexing.expanded_indexer(key, self.ndim) # Convert a scalar Variable to an integer key = tuple( - k.data.item() if isinstance(k, Variable) and k.ndim == 0 else k - for k in key) + k.data.item() if isinstance(k, Variable) and k.ndim == 0 else k for k in key + ) # Convert a 0d-array to an integer key = tuple( - k.item() if isinstance(k, np.ndarray) and k.ndim == 0 else k - for k in key) + k.item() if isinstance(k, np.ndarray) and k.ndim == 0 else k for k in key + ) if all(isinstance(k, BASIC_INDEXING_TYPES) for k in key): return self._broadcast_indexes_basic(key) @@ -518,8 +539,9 @@ def _broadcast_indexes(self, key): return self._broadcast_indexes_vectorized(key) def _broadcast_indexes_basic(self, key): - dims = tuple(dim for k, dim in zip(key, self.dims) - if not isinstance(k, integer_types)) + dims = tuple( + dim for k, dim in zip(key, self.dims) if not isinstance(k, integer_types) + ) return dims, BasicIndexer(key), None def _validate_indexers(self, key): @@ -533,29 +555,34 @@ def _validate_indexers(self, key): if k.ndim > 1: raise IndexError( "Unlabeled multi-dimensional array cannot be " - "used for indexing: {}".format(k)) - if k.dtype.kind == 'b': + "used for indexing: {}".format(k) + ) + if k.dtype.kind == "b": if self.shape[self.get_axis_num(dim)] != len(k): raise IndexError( "Boolean array size {:d} is used to index array " - "with shape {:s}." - .format(len(k), str(self.shape)) + "with shape {:s}.".format(len(k), str(self.shape)) ) if k.ndim > 1: - raise IndexError("{}-dimensional boolean indexing is " - "not supported. 
".format(k.ndim)) - if getattr(k, 'dims', (dim, )) != (dim, ): + raise IndexError( + "{}-dimensional boolean indexing is " + "not supported. ".format(k.ndim) + ) + if getattr(k, "dims", (dim,)) != (dim,): raise IndexError( "Boolean indexer should be unlabeled or on the " "same dimension to the indexed array. Indexer is " - "on {:s} but the target dimension is {:s}." - .format(str(k.dims), dim) + "on {:s} but the target dimension is {:s}.".format( + str(k.dims), dim + ) ) def _broadcast_indexes_outer(self, key): - dims = tuple(k.dims[0] if isinstance(k, Variable) else dim - for k, dim in zip(key, self.dims) - if not isinstance(k, integer_types)) + dims = tuple( + k.dims[0] if isinstance(k, Variable) else dim + for k, dim in zip(key, self.dims) + if not isinstance(k, integer_types) + ) new_key = [] for k in key: @@ -563,7 +590,7 @@ def _broadcast_indexes_outer(self, key): k = k.data if not isinstance(k, BASIC_INDEXING_TYPES): k = np.asarray(k) - if k.dtype.kind == 'b': + if k.dtype.kind == "b": (k,) = np.nonzero(k) new_key.append(k) @@ -574,8 +601,7 @@ def _nonzero(self): # TODO we should replace dask's native nonzero # after https://github.com/dask/dask/issues/1076 is implemented. nonzeros = np.nonzero(self.data) - return tuple(Variable((dim), nz) for nz, dim - in zip(nonzeros, self.dims)) + return tuple(Variable((dim), nz) for nz, dim in zip(nonzeros, self.dims)) def _broadcast_indexes_vectorized(self, key): variables = [] @@ -584,9 +610,12 @@ def _broadcast_indexes_vectorized(self, key): if isinstance(value, slice): out_dims_set.add(dim) else: - variable = (value if isinstance(value, Variable) else - as_variable(value, name=dim)) - if variable.dtype.kind == 'b': # boolean indexing case + variable = ( + value + if isinstance(value, Variable) + else as_variable(value, name=dim) + ) + if variable.dtype.kind == "b": # boolean indexing case (variable,) = variable._nonzero() variables.append(variable) @@ -624,8 +653,7 @@ def _broadcast_indexes_vectorized(self, key): slice_positions.add(new_position) if slice_positions: - new_order = [i for i in range(len(out_dims)) - if i not in slice_positions] + new_order = [i for i in range(len(out_dims)) if i not in slice_positions] else: new_order = None @@ -647,15 +675,13 @@ def __getitem__(self, key): dims, indexer, new_order = self._broadcast_indexes(key) data = as_indexable(self._data)[indexer] if new_order: - data = duck_array_ops.moveaxis( - data, range(len(new_order)), new_order) + data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order) return self._finalize_indexing_result(dims, data) def _finalize_indexing_result(self, dims, data): """Used by IndexVariable to return IndexVariable objects when possible. """ - return type(self)(dims, data, self._attrs, self._encoding, - fastpath=True) + return type(self)(dims, data, self._attrs, self._encoding, fastpath=True) def _getitem_with_mask(self, key, fill_value=dtypes.NA): """Index this Variable with -1 remapped to fill_value.""" @@ -682,18 +708,17 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA): actual_indexer = indexer data = as_indexable(self._data)[actual_indexer] - chunks_hint = getattr(data, 'chunks', None) + chunks_hint = getattr(data, "chunks", None) mask = indexing.create_mask(indexer, self.shape, chunks_hint) data = duck_array_ops.where(mask, fill_value, data) else: # array cannot be indexed along dimensions of size 0, so just # build the mask directly instead. 
mask = indexing.create_mask(indexer, self.shape) - data = np.broadcast_to(fill_value, getattr(mask, 'shape', ())) + data = np.broadcast_to(fill_value, getattr(mask, "shape", ())) if new_order: - data = duck_array_ops.moveaxis( - data, range(len(new_order)), new_order) + data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order) return self._finalize_indexing_result(dims, data) def __setitem__(self, key, value): @@ -708,28 +733,27 @@ def __setitem__(self, key, value): value = as_compatible_data(value) if value.ndim > len(dims): raise ValueError( - 'shape mismatch: value array of shape %s could not be ' - 'broadcast to indexing result with %s dimensions' - % (value.shape, len(dims))) + "shape mismatch: value array of shape %s could not be " + "broadcast to indexing result with %s dimensions" + % (value.shape, len(dims)) + ) if value.ndim == 0: value = Variable((), value) else: - value = Variable(dims[-value.ndim:], value) + value = Variable(dims[-value.ndim :], value) # broadcast to become assignable value = value.set_dims(dims).data if new_order: value = duck_array_ops.asarray(value) - value = value[(len(dims) - value.ndim) * (np.newaxis,) - + (Ellipsis,)] - value = duck_array_ops.moveaxis( - value, new_order, range(len(new_order))) + value = value[(len(dims) - value.ndim) * (np.newaxis,) + (Ellipsis,)] + value = duck_array_ops.moveaxis(value, new_order, range(len(new_order))) indexable = as_indexable(self._data) indexable[index_tuple] = value @property - def attrs(self) -> 'OrderedDict[Any, Any]': + def attrs(self) -> "OrderedDict[Any, Any]": """Dictionary of local attributes on this variable. """ if self._attrs is None: @@ -753,7 +777,7 @@ def encoding(self, value): try: self._encoding = dict(value) except ValueError: - raise ValueError('encoding must be castable to a dictionary') + raise ValueError("encoding must be castable to a dictionary") def copy(self, deep=True, data=None): """Returns a copy of this object. @@ -820,8 +844,9 @@ def copy(self, deep=True, data=None): data = indexing.MemoryCachedArray(data.array) if deep: - if (hasattr(data, '__array_function__') - or isinstance(data, dask_array_type)): + if hasattr(data, "__array_function__") or isinstance( + data, dask_array_type + ): data = data.copy() elif not isinstance(data, PandasIndexAdapter): # pandas.Index is immutable @@ -829,14 +854,16 @@ def copy(self, deep=True, data=None): else: data = as_compatible_data(data) if self.shape != data.shape: - raise ValueError("Data shape {} must match shape of object {}" - .format(data.shape, self.shape)) + raise ValueError( + "Data shape {} must match shape of object {}".format( + data.shape, self.shape + ) + ) # note: # dims is already an immutable tuple # attributes and encoding will be copied when the new Array is created - return type(self)(self.dims, data, self._attrs, self._encoding, - fastpath=True) + return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True) def __copy__(self): return self.copy(deep=False) @@ -855,7 +882,7 @@ def chunks(self): """Block dimensions for this array's data or None if it's not a dask array. 
""" - return getattr(self._data, 'chunks', None) + return getattr(self._data, "chunks", None) _array_counter = itertools.count() @@ -890,10 +917,7 @@ def chunk(self, chunks=None, name=None, lock=False): import dask.array as da if utils.is_dict_like(chunks): - chunks = { - self.get_axis_num(dim): chunk - for dim, chunk in chunks.items() - } + chunks = {self.get_axis_num(dim): chunk for dim, chunk in chunks.items()} if chunks is None: chunks = self.chunks or self.shape @@ -903,28 +927,26 @@ def chunk(self, chunks=None, name=None, lock=False): data = data.rechunk(chunks) else: if utils.is_dict_like(chunks): - chunks = tuple(chunks.get(n, s) - for n, s in enumerate(self.shape)) + chunks = tuple(chunks.get(n, s) for n, s in enumerate(self.shape)) # da.from_array works by using lazily indexing with a tuple of # slices. Using OuterIndexer is a pragmatic choice: dask does not # yet handle different indexing types in an explicit way: # https://github.com/dask/dask/issues/2883 data = indexing.ImplicitToExplicitIndexingAdapter( - data, indexing.OuterIndexer) + data, indexing.OuterIndexer + ) # For now, assume that all arrays that we wrap with dask (including # our lazily loaded backend array classes) should use NumPy array # operations. - if LooseVersion(dask.__version__) > '1.2.2': + if LooseVersion(dask.__version__) > "1.2.2": kwargs = dict(meta=np.ndarray) else: kwargs = dict() - data = da.from_array( - data, chunks, name=name, lock=lock, **kwargs) + data = da.from_array(data, chunks, name=name, lock=lock, **kwargs) - return type(self)(self.dims, data, self._attrs, self._encoding, - fastpath=True) + return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True) def isel(self, indexers=None, drop=False, **indexers_kwargs): """Return a new array indexed along the specified dimension(s). @@ -943,7 +965,7 @@ def isel(self, indexers=None, drop=False, **indexers_kwargs): unless numpy fancy indexing was triggered by using an array indexer, in which case the data will be a copy. """ - indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'isel') + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel") invalid = [k for k in indexers if k not in self.dims] if invalid: @@ -1043,14 +1065,15 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): shifted : Variable Variable with the same dimensions and attributes but shifted data. """ - shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'shift') + shifts = either_dict_or_kwargs(shifts, shifts_kwargs, "shift") result = self for dim, count in shifts.items(): result = result._shift_one_dim(dim, count, fill_value=fill_value) return result - def pad_with_fill_value(self, pad_widths=None, fill_value=dtypes.NA, - **pad_widths_kwargs): + def pad_with_fill_value( + self, pad_widths=None, fill_value=dtypes.NA, **pad_widths_kwargs + ): """ Return a new Variable with paddings. 
@@ -1061,8 +1084,7 @@ def pad_with_fill_value(self, pad_widths=None, fill_value=dtypes.NA, **pad_widths_kwargs: Keyword argument for pad_widths """ - pad_widths = either_dict_or_kwargs(pad_widths, pad_widths_kwargs, - 'pad') + pad_widths = either_dict_or_kwargs(pad_widths, pad_widths_kwargs, "pad") if fill_value is dtypes.NA: dtype, fill_value = dtypes.maybe_promote(self.dtype) @@ -1079,27 +1101,36 @@ def pad_with_fill_value(self, pad_widths=None, fill_value=dtypes.NA, before_shape = list(array.shape) before_shape[axis] = pad[0] before_chunks = list(array.chunks) - before_chunks[axis] = (pad[0], ) + before_chunks[axis] = (pad[0],) after_shape = list(array.shape) after_shape[axis] = pad[1] after_chunks = list(array.chunks) - after_chunks[axis] = (pad[1], ) + after_chunks[axis] = (pad[1],) arrays = [] if pad[0] > 0: - arrays.append(da.full(before_shape, fill_value, - dtype=dtype, chunks=before_chunks)) + arrays.append( + da.full( + before_shape, fill_value, dtype=dtype, chunks=before_chunks + ) + ) arrays.append(array) if pad[1] > 0: - arrays.append(da.full(after_shape, fill_value, - dtype=dtype, chunks=after_chunks)) + arrays.append( + da.full( + after_shape, fill_value, dtype=dtype, chunks=after_chunks + ) + ) if len(arrays) > 1: array = da.concatenate(arrays, axis=axis) else: - pads = [(0, 0) if d not in pad_widths else pad_widths[d] - for d in self.dims] - array = np.pad(self.data.astype(dtype, copy=False), pads, - mode='constant', constant_values=fill_value) + pads = [(0, 0) if d not in pad_widths else pad_widths[d] for d in self.dims] + array = np.pad( + self.data.astype(dtype, copy=False), + pads, + mode="constant", + constant_values=fill_value, + ) return type(self)(self.dims, array) def _roll_one_dim(self, dim, count): @@ -1111,8 +1142,7 @@ def _roll_one_dim(self, dim, count): else: indices = [slice(None)] - arrays = [self[(slice(None),) * axis + (idx,)].data - for idx in indices] + arrays = [self[(slice(None),) * axis + (idx,)].data for idx in indices] data = duck_array_ops.concatenate(arrays, axis) @@ -1143,14 +1173,14 @@ def roll(self, shifts=None, **shifts_kwargs): shifted : Variable Variable with the same dimensions and attributes but rolled data. """ - shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'roll') + shifts = either_dict_or_kwargs(shifts, shifts_kwargs, "roll") result = self for dim, count in shifts.items(): result = result._roll_one_dim(dim, count) return result - def transpose(self, *dims) -> 'Variable': + def transpose(self, *dims) -> "Variable": """Return a new Variable object with transposed dimensions. 
Parameters @@ -1181,18 +1211,20 @@ def transpose(self, *dims) -> 'Variable': return self.copy(deep=False) data = as_indexable(self._data).transpose(axes) - return type(self)(dims, data, self._attrs, self._encoding, - fastpath=True) + return type(self)(dims, data, self._attrs, self._encoding, fastpath=True) @property - def T(self) -> 'Variable': + def T(self) -> "Variable": return self.transpose() def expand_dims(self, *args): import warnings - warnings.warn('Variable.expand_dims is deprecated: use ' - 'Variable.set_dims instead', DeprecationWarning, - stacklevel=2) + + warnings.warn( + "Variable.expand_dims is deprecated: use " "Variable.set_dims instead", + DeprecationWarning, + stacklevel=2, + ) return self.expand_dims(*args) def set_dims(self, dims, shape=None): @@ -1220,12 +1252,13 @@ def set_dims(self, dims, shape=None): missing_dims = set(self.dims) - set(dims) if missing_dims: - raise ValueError('new dimensions %r must be a superset of ' - 'existing dimensions %r' % (dims, self.dims)) + raise ValueError( + "new dimensions %r must be a superset of " + "existing dimensions %r" % (dims, self.dims) + ) self_dims = set(self.dims) - expanded_dims = tuple( - d for d in dims if d not in self_dims) + self.dims + expanded_dims = tuple(d for d in dims if d not in self_dims) + self.dims if self.dims == expanded_dims: # don't use broadcast_to unless necessary so the result remains @@ -1236,20 +1269,22 @@ def set_dims(self, dims, shape=None): tmp_shape = tuple(dims_map[d] for d in expanded_dims) expanded_data = duck_array_ops.broadcast_to(self.data, tmp_shape) else: - expanded_data = self.data[ - (None,) * (len(expanded_dims) - self.ndim)] + expanded_data = self.data[(None,) * (len(expanded_dims) - self.ndim)] - expanded_var = Variable(expanded_dims, expanded_data, self._attrs, - self._encoding, fastpath=True) + expanded_var = Variable( + expanded_dims, expanded_data, self._attrs, self._encoding, fastpath=True + ) return expanded_var.transpose(*dims) def _stack_once(self, dims, new_dim): if not set(dims) <= set(self.dims): - raise ValueError('invalid existing dimensions: %s' % dims) + raise ValueError("invalid existing dimensions: %s" % dims) if new_dim in self.dims: - raise ValueError('cannot create a new dimension with the same ' - 'name as an existing dimension') + raise ValueError( + "cannot create a new dimension with the same " + "name as an existing dimension" + ) if len(dims) == 0: # don't stack @@ -1259,12 +1294,11 @@ def _stack_once(self, dims, new_dim): dim_order = other_dims + list(dims) reordered = self.transpose(*dim_order) - new_shape = reordered.shape[:len(other_dims)] + (-1,) + new_shape = reordered.shape[: len(other_dims)] + (-1,) new_data = reordered.data.reshape(new_shape) - new_dims = reordered.dims[:len(other_dims)] + (new_dim,) + new_dims = reordered.dims[: len(other_dims)] + (new_dim,) - return Variable(new_dims, new_data, self._attrs, self._encoding, - fastpath=True) + return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True) def stack(self, dimensions=None, **dimensions_kwargs): """ @@ -1291,8 +1325,7 @@ def stack(self, dimensions=None, **dimensions_kwargs): -------- Variable.unstack """ - dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, - 'stack') + dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, "stack") result = self for new_dim, dims in dimensions.items(): result = result._stack_once(dims, new_dim) @@ -1303,26 +1336,29 @@ def _unstack_once(self, dims, old_dim): new_dim_sizes = tuple(dims.values()) if old_dim 
not in self.dims: - raise ValueError('invalid existing dimension: %s' % old_dim) + raise ValueError("invalid existing dimension: %s" % old_dim) if set(new_dim_names).intersection(self.dims): - raise ValueError('cannot create a new dimension with the same ' - 'name as an existing dimension') + raise ValueError( + "cannot create a new dimension with the same " + "name as an existing dimension" + ) if np.prod(new_dim_sizes) != self.sizes[old_dim]: - raise ValueError('the product of the new dimension sizes must ' - 'equal the size of the old dimension') + raise ValueError( + "the product of the new dimension sizes must " + "equal the size of the old dimension" + ) other_dims = [d for d in self.dims if d != old_dim] dim_order = other_dims + [old_dim] reordered = self.transpose(*dim_order) - new_shape = reordered.shape[:len(other_dims)] + new_dim_sizes + new_shape = reordered.shape[: len(other_dims)] + new_dim_sizes new_data = reordered.data.reshape(new_shape) - new_dims = reordered.dims[:len(other_dims)] + new_dim_names + new_dims = reordered.dims[: len(other_dims)] + new_dim_names - return Variable(new_dims, new_data, self._attrs, self._encoding, - fastpath=True) + return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True) def unstack(self, dimensions=None, **dimensions_kwargs): """ @@ -1349,8 +1385,7 @@ def unstack(self, dimensions=None, **dimensions_kwargs): -------- Variable.stack """ - dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, - 'unstack') + dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, "unstack") result = self for old_dim, dims in dimensions.items(): result = result._unstack_once(dims, old_dim) @@ -1362,8 +1397,16 @@ def fillna(self, value): def where(self, cond, other=dtypes.NA): return ops.where_method(self, cond, other) - def reduce(self, func, dim=None, axis=None, - keep_attrs=None, keepdims=False, allow_lazy=False, **kwargs): + def reduce( + self, + func, + dim=None, + axis=None, + keep_attrs=None, + keepdims=False, + allow_lazy=False, + **kwargs + ): """Reduce this array by applying `func` along some dimension(s). 
Parameters @@ -1408,24 +1451,28 @@ def reduce(self, func, dim=None, axis=None, else: data = func(input_data, **kwargs) - if getattr(data, 'shape', ()) == self.shape: + if getattr(data, "shape", ()) == self.shape: dims = self.dims else: - removed_axes = (range(self.ndim) if axis is None - else np.atleast_1d(axis) % self.ndim) + removed_axes = ( + range(self.ndim) if axis is None else np.atleast_1d(axis) % self.ndim + ) if keepdims: # Insert np.newaxis for removed dims - slices = tuple(np.newaxis if i in removed_axes else - slice(None, None) for i in range(self.ndim)) - if getattr(data, 'shape', None) is None: + slices = tuple( + np.newaxis if i in removed_axes else slice(None, None) + for i in range(self.ndim) + ) + if getattr(data, "shape", None) is None: # Reduce has produced a scalar value, not an array-like data = np.asanyarray(data)[slices] else: data = data[slices] dims = self.dims else: - dims = [adim for n, adim in enumerate(self.dims) - if n not in removed_axes] + dims = [ + adim for n, adim in enumerate(self.dims) if n not in removed_axes + ] if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) @@ -1434,8 +1481,7 @@ def reduce(self, func, dim=None, axis=None, return Variable(dims, data, attrs=attrs) @classmethod - def concat(cls, variables, dim='concat_dim', positions=None, - shortcut=False): + def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False): """Concatenate variables along a new or existing dimension. Parameters @@ -1482,8 +1528,7 @@ def concat(cls, variables, dim='concat_dim', positions=None, if positions is not None: # TODO: deprecate this option -- we don't need it for groupby # any more. - indices = nputils.inverse_permutation( - np.concatenate(positions)) + indices = nputils.inverse_permutation(np.concatenate(positions)) data = duck_array_ops.take(data, indices, axis=axis) else: axis = 0 @@ -1495,7 +1540,7 @@ def concat(cls, variables, dim='concat_dim', positions=None, if not shortcut: for var in variables: if var.dims != first_var.dims: - raise ValueError('inconsistent dimensions') + raise ValueError("inconsistent dimensions") utils.remove_incompatible_items(attrs, var.attrs) return cls(dims, data, attrs, encoding) @@ -1510,11 +1555,10 @@ def equals(self, other, equiv=duck_array_ops.array_equiv): This method is necessary because `v1 == v2` for Variables does element-wise comparisons (like numpy.ndarrays). """ - other = getattr(other, 'variable', other) + other = getattr(other, "variable", other) try: - return ( - self.dims == other.dims and - (self._data is other._data or equiv(self.data, other.data)) + return self.dims == other.dims and ( + self._data is other._data or equiv(self.data, other.data) ) except (TypeError, AttributeError): return False @@ -1536,8 +1580,7 @@ def identical(self, other): """Like equals, but also checks attributes. """ try: - return (utils.dict_equiv(self.attrs, other.attrs) - and self.equals(other)) + return utils.dict_equiv(self.attrs, other.attrs) and self.equals(other) except (TypeError, AttributeError): return False @@ -1548,10 +1591,9 @@ def no_conflicts(self, other): Variables can thus still be equal if there are locations where either, or both, contain NaN values. 
""" - return self.broadcast_equals( - other, equiv=duck_array_ops.array_notnull_equiv) + return self.broadcast_equals(other, equiv=duck_array_ops.array_notnull_equiv) - def quantile(self, q, dim=None, interpolation='linear'): + def quantile(self, q, dim=None, interpolation="linear"): """Compute the qth quantile of the data along the specified dimension. Returns the qth quantiles(s) of the array elements. @@ -1590,9 +1632,11 @@ def quantile(self, q, dim=None, interpolation='linear'): DataArray.quantile """ if isinstance(self.data, dask_array_type): - raise TypeError("quantile does not work for arrays stored as dask " - "arrays. Load the data via .compute() or .load() " - "prior to calling this method.") + raise TypeError( + "quantile does not work for arrays stored as dask " + "arrays. Load the data via .compute() or .load() " + "prior to calling this method." + ) q = np.asarray(q, dtype=np.float64) @@ -1610,10 +1654,11 @@ def quantile(self, q, dim=None, interpolation='linear'): # only add the quantile dimension if q is array like if q.ndim != 0: - new_dims = ['quantile'] + new_dims + new_dims = ["quantile"] + new_dims - qs = np.nanpercentile(self.data, q * 100., axis=axis, - interpolation=interpolation) + qs = np.nanpercentile( + self.data, q * 100.0, axis=axis, interpolation=interpolation + ) return Variable(new_dims, qs) def rank(self, dim, pct=False): @@ -1645,20 +1690,23 @@ def rank(self, dim, pct=False): import bottleneck as bn if isinstance(self.data, dask_array_type): - raise TypeError("rank does not work for arrays stored as dask " - "arrays. Load the data via .compute() or .load() " - "prior to calling this method.") + raise TypeError( + "rank does not work for arrays stored as dask " + "arrays. Load the data via .compute() or .load() " + "prior to calling this method." + ) axis = self.get_axis_num(dim) - func = bn.nanrankdata if self.dtype.kind == 'f' else bn.rankdata + func = bn.nanrankdata if self.dtype.kind == "f" else bn.rankdata ranked = func(self.data, axis=axis) if pct: count = np.sum(~np.isnan(self.data), axis=axis, keepdims=True) ranked /= count return Variable(self.dims, ranked) - def rolling_window(self, dim, window, window_dim, center=False, - fill_value=dtypes.NA): + def rolling_window( + self, dim, window, window_dim, center=False, fill_value=dtypes.NA + ): """ Make a rolling_window along dim and add a new_dim to the last place. 
@@ -1703,12 +1751,19 @@ def rolling_window(self, dim, window, window_dim, center=False, dtype = self.dtype array = self.data - new_dims = self.dims + (window_dim, ) - return Variable(new_dims, duck_array_ops.rolling_window( - array, axis=self.get_axis_num(dim), window=window, - center=center, fill_value=fill_value)) - - def coarsen(self, windows, func, boundary='exact', side='left'): + new_dims = self.dims + (window_dim,) + return Variable( + new_dims, + duck_array_ops.rolling_window( + array, + axis=self.get_axis_num(dim), + window=window, + center=center, + fill_value=fill_value, + ), + ) + + def coarsen(self, windows, func, boundary="exact", side="left"): """ Apply """ @@ -1721,7 +1776,7 @@ def coarsen(self, windows, func, boundary='exact', side='left'): name = func func = getattr(duck_array_ops, name, None) if func is None: - raise NameError('{} is not a valid method.'.format(name)) + raise NameError("{} is not a valid method.".format(name)) return type(self)(self.dims, func(reshaped, axis=axes), self._attrs) def _coarsen_reshape(self, windows, boundary, side): @@ -1740,29 +1795,30 @@ def _coarsen_reshape(self, windows, boundary, side): for d, window in windows.items(): if window <= 0: - raise ValueError('window must be > 0. Given {}'.format(window)) + raise ValueError("window must be > 0. Given {}".format(window)) variable = self for d, window in windows.items(): # trim or pad the object size = variable.shape[self._get_axis_num(d)] n = int(size / window) - if boundary[d] == 'exact': + if boundary[d] == "exact": if n * window != size: raise ValueError( - 'Could not coarsen a dimension of size {} with ' - 'window {}'.format(size, window)) - elif boundary[d] == 'trim': - if side[d] == 'left': + "Could not coarsen a dimension of size {} with " + "window {}".format(size, window) + ) + elif boundary[d] == "trim": + if side[d] == "left": variable = variable.isel({d: slice(0, window * n)}) else: excess = size - window * n variable = variable.isel({d: slice(excess, None)}) - elif boundary[d] == 'pad': # pad + elif boundary[d] == "pad": # pad pad = window * n - size if pad < 0: pad += window - if side[d] == 'left': + if side[d] == "left": pad_widths = {d: (0, pad)} else: pad_widths = {d: (pad, 0)} @@ -1770,7 +1826,8 @@ def _coarsen_reshape(self, windows, boundary, side): else: raise TypeError( "{} is invalid for boundary. 
Valid option is 'exact', " - "'trim' and 'pad'".format(boundary[d])) + "'trim' and 'pad'".format(boundary[d]) + ) shape = [] axes = [] @@ -1802,8 +1859,9 @@ def __array_wrap__(self, obj, context=None): def _unary_op(f): @functools.wraps(f) def func(self, *args, **kwargs): - with np.errstate(all='ignore'): + with np.errstate(all="ignore"): return self.__array_wrap__(f(self.data, *args, **kwargs)) + return func @staticmethod @@ -1815,12 +1873,15 @@ def func(self, other): self_data, other_data, dims = _broadcast_compat_data(self, other) keep_attrs = _get_keep_attrs(default=False) attrs = self._attrs if keep_attrs else None - with np.errstate(all='ignore'): - new_data = (f(self_data, other_data) - if not reflexive - else f(other_data, self_data)) + with np.errstate(all="ignore"): + new_data = ( + f(self_data, other_data) + if not reflexive + else f(other_data, self_data) + ) result = Variable(dims, new_data, attrs=attrs) return result + return func @staticmethod @@ -1828,14 +1889,14 @@ def _inplace_binary_op(f): @functools.wraps(f) def func(self, other): if isinstance(other, xr.Dataset): - raise TypeError('cannot add a Dataset to a Variable in-place') + raise TypeError("cannot add a Dataset to a Variable in-place") self_data, other_data, dims = _broadcast_compat_data(self, other) if dims != self.dims: - raise ValueError('dimensions cannot change for in-place ' - 'operations') - with np.errstate(all='ignore'): + raise ValueError("dimensions cannot change for in-place " "operations") + with np.errstate(all="ignore"): self.values = f(self_data, other_data) return self + return func def _to_numeric(self, offset=None, datetime_unit=None, dtype=float): @@ -1843,7 +1904,8 @@ def _to_numeric(self, offset=None, datetime_unit=None, dtype=float): See duck_array_ops.datetime_to_numeric """ numeric_array = duck_array_ops.datetime_to_numeric( - self.data, offset, datetime_unit, dtype) + self.data, offset, datetime_unit, dtype + ) return type(self)(self.dims, numeric_array, self._attrs) @@ -1864,8 +1926,7 @@ class IndexVariable(Variable): def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): super().__init__(dims, data, attrs, encoding, fastpath) if self.ndim != 1: - raise ValueError('%s objects must be 1-dimensional' % - type(self).__name__) + raise ValueError("%s objects must be 1-dimensional" % type(self).__name__) # Unlike in Variable, always eagerly load values into memory if not isinstance(self._data, PandasIndexAdapter): @@ -1887,19 +1948,17 @@ def chunk(self, chunks=None, name=None, lock=False): return self.copy(deep=False) def _finalize_indexing_result(self, dims, data): - if getattr(data, 'ndim', 0) != 1: + if getattr(data, "ndim", 0) != 1: # returns Variable rather than IndexVariable if multi-dimensional return Variable(dims, data, self._attrs, self._encoding) else: - return type(self)(dims, data, self._attrs, - self._encoding, fastpath=True) + return type(self)(dims, data, self._attrs, self._encoding, fastpath=True) def __setitem__(self, key, value): - raise TypeError('%s values cannot be modified' % type(self).__name__) + raise TypeError("%s values cannot be modified" % type(self).__name__) @classmethod - def concat(cls, variables, dim='concat_dim', positions=None, - shortcut=False): + def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False): """Specialized version of Variable.concat for IndexVariable objects. 
This exists because we want to avoid converting Index objects to NumPy @@ -1912,8 +1971,10 @@ def concat(cls, variables, dim='concat_dim', positions=None, first_var = variables[0] if any(not isinstance(v, cls) for v in variables): - raise TypeError('IndexVariable.concat requires that all input ' - 'variables be IndexVariable objects') + raise TypeError( + "IndexVariable.concat requires that all input " + "variables be IndexVariable objects" + ) indexes = [v._data.array for v in variables] @@ -1923,15 +1984,14 @@ def concat(cls, variables, dim='concat_dim', positions=None, data = indexes[0].append(indexes[1:]) if positions is not None: - indices = nputils.inverse_permutation( - np.concatenate(positions)) + indices = nputils.inverse_permutation(np.concatenate(positions)) data = data.take(indices) attrs = OrderedDict(first_var.attrs) if not shortcut: for var in variables: if var.dims != first_var.dims: - raise ValueError('inconsistent dimensions') + raise ValueError("inconsistent dimensions") utils.remove_incompatible_items(attrs, var.attrs) return cls(first_var.dims, data, attrs) @@ -1965,10 +2025,12 @@ def copy(self, deep=True, data=None): else: data = as_compatible_data(data) if self.shape != data.shape: - raise ValueError("Data shape {} must match shape of object {}" - .format(data.shape, self.shape)) - return type(self)(self.dims, data, self._attrs, - self._encoding, fastpath=True) + raise ValueError( + "Data shape {} must match shape of object {}".format( + data.shape, self.shape + ) + ) + return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True) def equals(self, other, equiv=None): # if equiv is specified, super up @@ -1976,10 +2038,9 @@ def equals(self, other, equiv=None): return super().equals(other, equiv) # otherwise use the native index equals, rather than looking at _data - other = getattr(other, 'variable', other) + other = getattr(other, "variable", other) try: - return (self.dims == other.dims - and self._data_equals(other)) + return self.dims == other.dims and self._data_equals(other) except (TypeError, AttributeError): return False @@ -1990,7 +2051,7 @@ def to_index_variable(self): """Return this variable as an xarray.IndexVariable""" return self - to_coord = utils.alias(to_index_variable, 'to_coord') + to_coord = utils.alias(to_index_variable, "to_coord") def to_index(self): """Convert this variable to a pandas.Index""" @@ -2001,8 +2062,10 @@ def to_index(self): if isinstance(index, pd.MultiIndex): # set default names for multi-index unnamed levels so that # we can safely rename dimension / coordinate later - valid_level_names = [name or '{}_level_{}'.format(self.dims[0], i) - for i, name in enumerate(index.names)] + valid_level_names = [ + name or "{}_level_{}".format(self.dims[0], i) + for i, name in enumerate(index.names) + ] index = index.set_names(valid_level_names) else: index = index.set_names(self.name) @@ -2032,11 +2095,11 @@ def name(self): @name.setter def name(self, value): - raise AttributeError('cannot modify name of IndexVariable in-place') + raise AttributeError("cannot modify name of IndexVariable in-place") # for backwards compatibility -Coordinate = utils.alias(IndexVariable, 'Coordinate') +Coordinate = utils.alias(IndexVariable, "Coordinate") def _unified_dims(variables): @@ -2045,15 +2108,19 @@ def _unified_dims(variables): for var in variables: var_dims = var.dims if len(set(var_dims)) < len(var_dims): - raise ValueError('broadcasting cannot handle duplicate ' - 'dimensions: %r' % list(var_dims)) + raise ValueError( + "broadcasting 
cannot handle duplicate " + "dimensions: %r" % list(var_dims) + ) for d, s in zip(var_dims, var.shape): if d not in all_dims: all_dims[d] = s elif all_dims[d] != s: - raise ValueError('operands cannot be broadcast together ' - 'with mismatched lengths for dimension %r: %s' - % (d, (all_dims[d], s))) + raise ValueError( + "operands cannot be broadcast together " + "with mismatched lengths for dimension %r: %s" + % (d, (all_dims[d], s)) + ) return all_dims @@ -2064,8 +2131,7 @@ def _broadcast_compat_variables(*variables): dimensions of size 1 instead of the the size of the broadcast dimension. """ dims = tuple(_unified_dims(variables)) - return tuple(var.set_dims(dims) if var.dims != dims else var - for var in variables) + return tuple(var.set_dims(dims) if var.dims != dims else var for var in variables) def broadcast_variables(*variables): @@ -2080,13 +2146,13 @@ def broadcast_variables(*variables): """ dims_map = _unified_dims(variables) dims_tuple = tuple(dims_map) - return tuple(var.set_dims(dims_map) if var.dims != dims_tuple else var - for var in variables) + return tuple( + var.set_dims(dims_map) if var.dims != dims_tuple else var for var in variables + ) def _broadcast_compat_data(self, other): - if all(hasattr(other, attr) for attr - in ['dims', 'data', 'shape', 'encoding']): + if all(hasattr(other, attr) for attr in ["dims", "data", "shape", "encoding"]): # `other` satisfies the necessary Variable API for broadcast_variables new_self, new_other = _broadcast_compat_variables(self, other) self_data = new_self.data @@ -2100,7 +2166,7 @@ def _broadcast_compat_data(self, other): return self_data, other_data, dims -def concat(variables, dim='concat_dim', positions=None, shortcut=False): +def concat(variables, dim="concat_dim", positions=None, shortcut=False): """Concatenate variables along a new or existing dimension. Parameters @@ -2151,22 +2217,23 @@ def assert_unique_multiindex_level_names(variables): idx_level_names = var.to_index_variable().level_names if idx_level_names is not None: for n in idx_level_names: - level_names[n].append('%r (%s)' % (n, var_name)) + level_names[n].append("%r (%s)" % (n, var_name)) if idx_level_names: all_level_names.update(idx_level_names) for k, v in level_names.items(): if k in variables: - v.append('(%s)' % k) + v.append("(%s)" % k) duplicate_names = [v for v in level_names.values() if len(v) > 1] if duplicate_names: - conflict_str = '\n'.join([', '.join(v) for v in duplicate_names]) - raise ValueError('conflicting MultiIndex level name(s):\n%s' - % conflict_str) + conflict_str = "\n".join([", ".join(v) for v in duplicate_names]) + raise ValueError("conflicting MultiIndex level name(s):\n%s" % conflict_str) # Check confliction between level names and dimensions GH:2299 for k, v in variables.items(): for d in v.dims: if d in all_level_names: - raise ValueError('conflicting level / dimension names. {} ' - 'already exists as a level name.'.format(d)) + raise ValueError( + "conflicting level / dimension names. 
{} " + "already exists as a level name.".format(d) + ) diff --git a/xarray/plot/__init__.py b/xarray/plot/__init__.py index adda541c21d..c3333acf7f5 100644 --- a/xarray/plot/__init__.py +++ b/xarray/plot/__init__.py @@ -2,13 +2,13 @@ from .plot import contour, contourf, hist, imshow, line, pcolormesh, plot, step __all__ = [ - 'plot', - 'line', - 'step', - 'contour', - 'contourf', - 'hist', - 'imshow', - 'pcolormesh', - 'FacetGrid', + "plot", + "line", + "step", + "contour", + "contourf", + "hist", + "imshow", + "pcolormesh", + "FacetGrid", ] diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index aa31780a983..176f0c504f6 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -6,8 +6,12 @@ from ..core.alignment import broadcast from .facetgrid import _easy_facetgrid from .utils import ( - _add_colorbar, _is_numeric, _process_cmap_cbar_kwargs, get_axis, - label_from_attrs) + _add_colorbar, + _is_numeric, + _process_cmap_cbar_kwargs, + get_axis, + label_from_attrs, +) # copied from seaborn _MARKERSIZE_RANGE = np.array([18.0, 72.0]) @@ -15,43 +19,44 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide): dvars = set(ds.variables.keys()) - error_msg = (' must be one of ({0:s})' - .format(', '.join(dvars))) + error_msg = " must be one of ({0:s})".format(", ".join(dvars)) if x not in dvars: - raise ValueError('x' + error_msg) + raise ValueError("x" + error_msg) if y not in dvars: - raise ValueError('y' + error_msg) + raise ValueError("y" + error_msg) if hue is not None and hue not in dvars: - raise ValueError('hue' + error_msg) + raise ValueError("hue" + error_msg) if hue: hue_is_numeric = _is_numeric(ds[hue].values) if hue_style is None: - hue_style = 'continuous' if hue_is_numeric else 'discrete' + hue_style = "continuous" if hue_is_numeric else "discrete" - if not hue_is_numeric and (hue_style == 'continuous'): - raise ValueError('Cannot create a colorbar for a non numeric' - ' coordinate: ' + hue) + if not hue_is_numeric and (hue_style == "continuous"): + raise ValueError( + "Cannot create a colorbar for a non numeric" " coordinate: " + hue + ) if add_guide is None or add_guide is True: - add_colorbar = True if hue_style == 'continuous' else False - add_legend = True if hue_style == 'discrete' else False + add_colorbar = True if hue_style == "continuous" else False + add_legend = True if hue_style == "discrete" else False else: add_colorbar = False add_legend = False else: if add_guide is True: - raise ValueError('Cannot set add_guide when hue is None.') + raise ValueError("Cannot set add_guide when hue is None.") add_legend = False add_colorbar = False - if hue_style is not None and hue_style not in ['discrete', 'continuous']: - raise ValueError("hue_style must be either None, 'discrete' " - "or 'continuous'.") + if hue_style is not None and hue_style not in ["discrete", "continuous"]: + raise ValueError( + "hue_style must be either None, 'discrete' " "or 'continuous'." 
+ ) if hue: hue_label = label_from_attrs(ds[hue]) @@ -60,46 +65,44 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide): hue_label = None hue = None - return {'add_colorbar': add_colorbar, - 'add_legend': add_legend, - 'hue_label': hue_label, - 'hue_style': hue_style, - 'xlabel': label_from_attrs(ds[x]), - 'ylabel': label_from_attrs(ds[y]), - 'hue': hue} + return { + "add_colorbar": add_colorbar, + "add_legend": add_legend, + "hue_label": hue_label, + "hue_style": hue_style, + "xlabel": label_from_attrs(ds[x]), + "ylabel": label_from_attrs(ds[y]), + "hue": hue, + } -def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, - size_mapping=None): +def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None): - broadcast_keys = ['x', 'y'] + broadcast_keys = ["x", "y"] to_broadcast = [ds[x], ds[y]] if hue: to_broadcast.append(ds[hue]) - broadcast_keys.append('hue') + broadcast_keys.append("hue") if markersize: to_broadcast.append(ds[markersize]) - broadcast_keys.append('size') + broadcast_keys.append("size") broadcasted = dict(zip(broadcast_keys, broadcast(*to_broadcast))) - data = {'x': broadcasted['x'], - 'y': broadcasted['y'], - 'hue': None, - 'sizes': None} + data = {"x": broadcasted["x"], "y": broadcasted["y"], "hue": None, "sizes": None} if hue: - data['hue'] = broadcasted['hue'] + data["hue"] = broadcasted["hue"] if markersize: - size = broadcasted['size'] + size = broadcasted["size"] if size_mapping is None: size_mapping = _parse_size(size, size_norm) - data['sizes'] = size.copy( - data=np.reshape(size_mapping.loc[size.values.ravel()].values, - size.shape)) + data["sizes"] = size.copy( + data=np.reshape(size_mapping.loc[size.values.ravel()].values, size.shape) + ) return data @@ -128,8 +131,7 @@ def _parse_size(data, norm): elif isinstance(norm, tuple): norm = mpl.colors.Normalize(*norm) elif not isinstance(norm, mpl.colors.Normalize): - err = ("``size_norm`` must be None, tuple, " - "or Normalize object.") + err = "``size_norm`` must be None, tuple, " "or Normalize object." raise ValueError(err) norm.clip = True @@ -156,8 +158,10 @@ def __init__(self, dataset): self._ds = dataset def __call__(self, *args, **kwargs): - raise ValueError('Dataset.plot cannot be called directly. Use ' - 'an explicit plot method, e.g. ds.plot.scatter(...)') + raise ValueError( + "Dataset.plot cannot be called directly. Use " + "an explicit plot method, e.g. 
ds.plot.scatter(...)" + ) def _dsplot(plotfunc): @@ -239,89 +243,145 @@ def _dsplot(plotfunc): """ # Build on the original docstring - plotfunc.__doc__ = '%s\n%s' % (plotfunc.__doc__, commondoc) + plotfunc.__doc__ = "%s\n%s" % (plotfunc.__doc__, commondoc) @functools.wraps(plotfunc) - def newplotfunc(ds, x=None, y=None, hue=None, hue_style=None, - col=None, row=None, ax=None, figsize=None, size=None, - col_wrap=None, sharex=True, sharey=True, aspect=None, - subplot_kws=None, add_guide=None, cbar_kwargs=None, - cbar_ax=None, vmin=None, vmax=None, - norm=None, infer_intervals=None, center=None, levels=None, - robust=None, colors=None, extend=None, cmap=None, - **kwargs): - - _is_facetgrid = kwargs.pop('_is_facetgrid', False) + def newplotfunc( + ds, + x=None, + y=None, + hue=None, + hue_style=None, + col=None, + row=None, + ax=None, + figsize=None, + size=None, + col_wrap=None, + sharex=True, + sharey=True, + aspect=None, + subplot_kws=None, + add_guide=None, + cbar_kwargs=None, + cbar_ax=None, + vmin=None, + vmax=None, + norm=None, + infer_intervals=None, + center=None, + levels=None, + robust=None, + colors=None, + extend=None, + cmap=None, + **kwargs + ): + + _is_facetgrid = kwargs.pop("_is_facetgrid", False) if _is_facetgrid: # facetgrid call - meta_data = kwargs.pop('meta_data') + meta_data = kwargs.pop("meta_data") else: meta_data = _infer_meta_data(ds, x, y, hue, hue_style, add_guide) - hue_style = meta_data['hue_style'] + hue_style = meta_data["hue_style"] # handle facetgrids first if col or row: allargs = locals().copy() - allargs['plotfunc'] = globals()[plotfunc.__name__] - allargs['data'] = ds + allargs["plotfunc"] = globals()[plotfunc.__name__] + allargs["data"] = ds # TODO dcherian: why do I need to remove kwargs? - for arg in ['meta_data', 'kwargs', 'ds']: + for arg in ["meta_data", "kwargs", "ds"]: del allargs[arg] - return _easy_facetgrid(kind='dataset', **allargs, **kwargs) + return _easy_facetgrid(kind="dataset", **allargs, **kwargs) - figsize = kwargs.pop('figsize', None) + figsize = kwargs.pop("figsize", None) ax = get_axis(figsize, size, aspect, ax) - if hue_style == 'continuous' and hue is not None: + if hue_style == "continuous" and hue is not None: if _is_facetgrid: - cbar_kwargs = meta_data['cbar_kwargs'] - cmap_params = meta_data['cmap_params'] + cbar_kwargs = meta_data["cbar_kwargs"] + cmap_params = meta_data["cmap_params"] else: cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( - plotfunc, ds[hue].values, **locals()) + plotfunc, ds[hue].values, **locals() + ) # subset that can be passed to scatter, hist2d cmap_params_subset = dict( - (vv, cmap_params[vv]) - for vv in ['vmin', 'vmax', 'norm', 'cmap']) + (vv, cmap_params[vv]) for vv in ["vmin", "vmax", "norm", "cmap"] + ) else: cmap_params_subset = {} - primitive = plotfunc(ds=ds, x=x, y=y, hue=hue, hue_style=hue_style, - ax=ax, cmap_params=cmap_params_subset, **kwargs) + primitive = plotfunc( + ds=ds, + x=x, + y=y, + hue=hue, + hue_style=hue_style, + ax=ax, + cmap_params=cmap_params_subset, + **kwargs + ) if _is_facetgrid: # if this was called from Facetgrid.map_dataset, - return primitive # finish here. Else, make labels - - if meta_data.get('xlabel', None): - ax.set_xlabel(meta_data.get('xlabel')) - if meta_data.get('ylabel', None): - ax.set_ylabel(meta_data.get('ylabel')) - - if meta_data['add_legend']: - ax.legend(handles=primitive, - labels=list(meta_data['hue'].values), - title=meta_data.get('hue_label', None)) - if meta_data['add_colorbar']: + return primitive # finish here. 
Else, make labels + + if meta_data.get("xlabel", None): + ax.set_xlabel(meta_data.get("xlabel")) + if meta_data.get("ylabel", None): + ax.set_ylabel(meta_data.get("ylabel")) + + if meta_data["add_legend"]: + ax.legend( + handles=primitive, + labels=list(meta_data["hue"].values), + title=meta_data.get("hue_label", None), + ) + if meta_data["add_colorbar"]: cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs - if 'label' not in cbar_kwargs: - cbar_kwargs['label'] = meta_data.get('hue_label', None) + if "label" not in cbar_kwargs: + cbar_kwargs["label"] = meta_data.get("hue_label", None) _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params) return primitive @functools.wraps(newplotfunc) - def plotmethod(_PlotMethods_obj, x=None, y=None, hue=None, - hue_style=None, col=None, row=None, ax=None, - figsize=None, - col_wrap=None, sharex=True, sharey=True, aspect=None, - size=None, subplot_kws=None, add_guide=None, - cbar_kwargs=None, cbar_ax=None, vmin=None, vmax=None, - norm=None, infer_intervals=None, center=None, levels=None, - robust=None, colors=None, extend=None, cmap=None, - **kwargs): + def plotmethod( + _PlotMethods_obj, + x=None, + y=None, + hue=None, + hue_style=None, + col=None, + row=None, + ax=None, + figsize=None, + col_wrap=None, + sharex=True, + sharey=True, + aspect=None, + size=None, + subplot_kws=None, + add_guide=None, + cbar_kwargs=None, + cbar_ax=None, + vmin=None, + vmax=None, + norm=None, + infer_intervals=None, + center=None, + levels=None, + robust=None, + colors=None, + extend=None, + cmap=None, + **kwargs + ): """ The method should have the same signature as the function. @@ -329,9 +389,9 @@ def plotmethod(_PlotMethods_obj, x=None, y=None, hue=None, and passes all the other arguments straight through. """ allargs = locals() - allargs['ds'] = _PlotMethods_obj._ds + allargs["ds"] = _PlotMethods_obj._ds allargs.update(kwargs) - for arg in ['_PlotMethods_obj', 'newplotfunc', 'kwargs']: + for arg in ["_PlotMethods_obj", "newplotfunc", "kwargs"]: del allargs[arg] return newplotfunc(**allargs) @@ -347,43 +407,47 @@ def scatter(ds, x, y, ax, **kwargs): Scatter Dataset data variables against each other. """ - if 'add_colorbar' in kwargs or 'add_legend' in kwargs: - raise ValueError("Dataset.plot.scatter does not accept " - "'add_colorbar' or 'add_legend'. " - "Use 'add_guide' instead.") + if "add_colorbar" in kwargs or "add_legend" in kwargs: + raise ValueError( + "Dataset.plot.scatter does not accept " + "'add_colorbar' or 'add_legend'. " + "Use 'add_guide' instead." 
+ ) - cmap_params = kwargs.pop('cmap_params') - hue = kwargs.pop('hue') - hue_style = kwargs.pop('hue_style') - markersize = kwargs.pop('markersize', None) - size_norm = kwargs.pop('size_norm', None) - size_mapping = kwargs.pop('size_mapping', None) # set by facetgrid + cmap_params = kwargs.pop("cmap_params") + hue = kwargs.pop("hue") + hue_style = kwargs.pop("hue_style") + markersize = kwargs.pop("markersize", None) + size_norm = kwargs.pop("size_norm", None) + size_mapping = kwargs.pop("size_mapping", None) # set by facetgrid # need to infer size_mapping with full dataset - data = _infer_scatter_data(ds, x, y, hue, - markersize, size_norm, size_mapping) + data = _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping) - if hue_style == 'discrete': + if hue_style == "discrete": primitive = [] - for label in np.unique(data['hue'].values): - mask = data['hue'] == label - if data['sizes'] is not None: - kwargs.update( - s=data['sizes'].where(mask, drop=True).values.flatten()) + for label in np.unique(data["hue"].values): + mask = data["hue"] == label + if data["sizes"] is not None: + kwargs.update(s=data["sizes"].where(mask, drop=True).values.flatten()) primitive.append( - ax.scatter(data['x'].where(mask, drop=True).values.flatten(), - data['y'].where(mask, drop=True).values.flatten(), - label=label, **kwargs)) - - elif hue is None or hue_style == 'continuous': - if data['sizes'] is not None: - kwargs.update(s=data['sizes'].values.ravel()) - if data['hue'] is not None: - kwargs.update(c=data['hue'].values.ravel()) - - primitive = ax.scatter(data['x'].values.ravel(), - data['y'].values.ravel(), - **cmap_params, **kwargs) + ax.scatter( + data["x"].where(mask, drop=True).values.flatten(), + data["y"].where(mask, drop=True).values.flatten(), + label=label, + **kwargs + ) + ) + + elif hue is None or hue_style == "continuous": + if data["sizes"] is not None: + kwargs.update(s=data["sizes"].values.ravel()) + if data["hue"] is not None: + kwargs.update(c=data["hue"].values.ravel()) + + primitive = ax.scatter( + data["x"].values.ravel(), data["y"].values.ravel(), **cmap_params, **kwargs + ) return primitive diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index a28be7ce187..79f94077c8f 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -6,11 +6,15 @@ from ..core.formatting import format_item from .utils import ( - _infer_xy_labels, _process_cmap_cbar_kwargs, import_matplotlib_pyplot, - label_from_attrs) + _infer_xy_labels, + _process_cmap_cbar_kwargs, + import_matplotlib_pyplot, + label_from_attrs, +) + # Overrides axes.labelsize, xtick.major.size, ytick.major.size # from mpl.rcParams -_FONTSIZE = 'small' +_FONTSIZE = "small" # For major ticks on x, y axes _NTICKS = 5 @@ -23,7 +27,7 @@ def _nicetitle(coord, value, maxchar, template): title = template.format(coord=coord, value=prettyvalue) if len(title) > maxchar: - title = title[:(maxchar - 3)] + '...' + title = title[: (maxchar - 3)] + "..." 
return title @@ -66,9 +70,19 @@ class FacetGrid: """ - def __init__(self, data, col=None, row=None, col_wrap=None, - sharex=True, sharey=True, figsize=None, aspect=1, size=3, - subplot_kws=None): + def __init__( + self, + data, + col=None, + row=None, + col_wrap=None, + sharex=True, + sharey=True, + figsize=None, + aspect=1, + size=3, + subplot_kws=None, + ): """ Parameters ---------- @@ -102,8 +116,10 @@ def __init__(self, data, col=None, row=None, col_wrap=None, rep_col = col is not None and not data[col].to_index().is_unique rep_row = row is not None and not data[row].to_index().is_unique if rep_col or rep_row: - raise ValueError('Coordinates used for faceting cannot ' - 'contain repeated (nonunique) values.') + raise ValueError( + "Coordinates used for faceting cannot " + "contain repeated (nonunique) values." + ) # single_group is the grouping variable, if there is exactly one if col and row: @@ -112,15 +128,13 @@ def __init__(self, data, col=None, row=None, col_wrap=None, ncol = len(data[col]) nfacet = nrow * ncol if col_wrap is not None: - warnings.warn('Ignoring col_wrap since both col and row ' - 'were passed') + warnings.warn("Ignoring col_wrap since both col and row " "were passed") elif row and not col: single_group = row elif not row and col: single_group = col else: - raise ValueError( - 'Pass a coordinate name as an argument for row or col') + raise ValueError("Pass a coordinate name as an argument for row or col") # Compute grid shape if single_group: @@ -144,17 +158,22 @@ def __init__(self, data, col=None, row=None, col_wrap=None, cbar_space = 1 figsize = (ncol * size * aspect + cbar_space, nrow * size) - fig, axes = plt.subplots(nrow, ncol, - sharex=sharex, sharey=sharey, squeeze=False, - figsize=figsize, subplot_kw=subplot_kws) + fig, axes = plt.subplots( + nrow, + ncol, + sharex=sharex, + sharey=sharey, + squeeze=False, + figsize=figsize, + subplot_kw=subplot_kws, + ) # Set up the lists of names for the row and column facet variables col_names = list(data[col].values) if col else [] row_names = list(data[row].values) if row else [] if single_group: - full = [{single_group: x} for x in - data[single_group].values] + full = [{single_group: x} for x in data[single_group].values] empty = [None for x in range(nrow * ncol - len(full))] name_dicts = full + empty else: @@ -218,26 +237,32 @@ def map_dataarray(self, func, x, y, **kwargs): """ - if kwargs.get('cbar_ax', None) is not None: - raise ValueError('cbar_ax not supported by FacetGrid.') + if kwargs.get("cbar_ax", None) is not None: + raise ValueError("cbar_ax not supported by FacetGrid.") cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( - func, self.data.values, **kwargs) + func, self.data.values, **kwargs + ) - self._cmap_extend = cmap_params.get('extend') + self._cmap_extend = cmap_params.get("extend") # Order is important func_kwargs = { - k: v for k, v in kwargs.items() - if k not in {'cmap', 'colors', 'cbar_kwargs', 'levels'} + k: v + for k, v in kwargs.items() + if k not in {"cmap", "colors", "cbar_kwargs", "levels"} } func_kwargs.update(cmap_params) - func_kwargs.update({'add_colorbar': False, 'add_labels': False}) + func_kwargs.update({"add_colorbar": False, "add_labels": False}) # Get x, y labels for the first subplot x, y = _infer_xy_labels( - darray=self.data.loc[self.name_dicts.flat[0]], x=x, y=y, - imshow=func.__name__ == 'imshow', rgb=kwargs.get('rgb', None)) + darray=self.data.loc[self.name_dicts.flat[0]], + x=x, + y=y, + imshow=func.__name__ == "imshow", + rgb=kwargs.get("rgb", None), + ) for d, ax in 
zip(self.name_dicts.flat, self.axes.flat): # None is the sentinel value @@ -248,26 +273,35 @@ def map_dataarray(self, func, x, y, **kwargs): self._finalize_grid(x, y) - if kwargs.get('add_colorbar', True): + if kwargs.get("add_colorbar", True): self.add_colorbar(**cbar_kwargs) return self - def map_dataarray_line(self, func, x, y, hue, add_legend=True, - _labels=None, **kwargs): + def map_dataarray_line( + self, func, x, y, hue, add_legend=True, _labels=None, **kwargs + ): from .plot import _infer_line_data for d, ax in zip(self.name_dicts.flat, self.axes.flat): # None is the sentinel value if d is not None: subset = self.data.loc[d] - mappable = func(subset, x=x, y=y, ax=ax, - hue=hue, add_legend=False, _labels=False, - **kwargs) + mappable = func( + subset, + x=x, + y=y, + ax=ax, + hue=hue, + add_legend=False, + _labels=False, + **kwargs + ) self._mappables.append(mappable) _, _, hueplt, xlabel, ylabel, huelabel = _infer_line_data( - darray=self.data.loc[self.name_dicts.flat[0]], x=x, y=y, hue=hue) + darray=self.data.loc[self.name_dicts.flat[0]], x=x, y=y, hue=hue + ) self._hue_var = hueplt self._hue_label = huelabel @@ -278,47 +312,48 @@ def map_dataarray_line(self, func, x, y, hue, add_legend=True, return self - def map_dataset(self, func, x=None, y=None, hue=None, hue_style=None, - add_guide=None, **kwargs): + def map_dataset( + self, func, x=None, y=None, hue=None, hue_style=None, add_guide=None, **kwargs + ): from .dataset_plot import _infer_meta_data, _parse_size - kwargs['add_guide'] = False - kwargs['_is_facetgrid'] = True + kwargs["add_guide"] = False + kwargs["_is_facetgrid"] = True - if kwargs.get('markersize', None): - kwargs['size_mapping'] = _parse_size( - self.data[kwargs['markersize']], - kwargs.pop('size_norm', None)) + if kwargs.get("markersize", None): + kwargs["size_mapping"] = _parse_size( + self.data[kwargs["markersize"]], kwargs.pop("size_norm", None) + ) - meta_data = _infer_meta_data(self.data, x, y, hue, hue_style, - add_guide) - kwargs['meta_data'] = meta_data + meta_data = _infer_meta_data(self.data, x, y, hue, hue_style, add_guide) + kwargs["meta_data"] = meta_data - if hue and meta_data['hue_style'] == 'continuous': + if hue and meta_data["hue_style"] == "continuous": cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( - func, self.data[hue].values, **kwargs) - kwargs['meta_data']['cmap_params'] = cmap_params - kwargs['meta_data']['cbar_kwargs'] = cbar_kwargs + func, self.data[hue].values, **kwargs + ) + kwargs["meta_data"]["cmap_params"] = cmap_params + kwargs["meta_data"]["cbar_kwargs"] = cbar_kwargs for d, ax in zip(self.name_dicts.flat, self.axes.flat): # None is the sentinel value if d is not None: subset = self.data.loc[d] - maybe_mappable = func(ds=subset, x=x, y=y, - hue=hue, hue_style=hue_style, - ax=ax, **kwargs) + maybe_mappable = func( + ds=subset, x=x, y=y, hue=hue, hue_style=hue_style, ax=ax, **kwargs + ) # TODO: this is needed to get legends to work. 
# but maybe_mappable is a list in that case :/ self._mappables.append(maybe_mappable) - self._finalize_grid(meta_data['xlabel'], meta_data['ylabel']) + self._finalize_grid(meta_data["xlabel"], meta_data["ylabel"]) if hue: - self._hue_label = meta_data.pop('hue_label', None) - if meta_data['add_legend']: - self._hue_var = meta_data['hue'] + self._hue_label = meta_data.pop("hue_label", None) + if meta_data["add_legend"]: + self._hue_var = meta_data["hue"] self.add_legend() - elif meta_data['add_colorbar']: + elif meta_data["add_colorbar"]: self.add_colorbar(label=self._hue_label, **cbar_kwargs) return self @@ -341,7 +376,9 @@ def add_legend(self, **kwargs): handles=self._mappables[-1], labels=list(self._hue_var.values), title=self._hue_label, - loc="center right", **kwargs) + loc="center right", + **kwargs + ) self.figlegend = figlegend # Draw the plot to set the bounding boxes correctly @@ -370,12 +407,12 @@ def add_colorbar(self, **kwargs): """ kwargs = kwargs.copy() if self._cmap_extend is not None: - kwargs.setdefault('extend', self._cmap_extend) - if 'label' not in kwargs: - kwargs.setdefault('label', label_from_attrs(self.data)) - self.cbar = self.fig.colorbar(self._mappables[-1], - ax=list(self.axes.flat), - **kwargs) + kwargs.setdefault("extend", self._cmap_extend) + if "label" not in kwargs: + kwargs.setdefault("label", label_from_attrs(self.data)) + self.cbar = self.fig.colorbar( + self._mappables[-1], ax=list(self.axes.flat), **kwargs + ) return self def set_axis_labels(self, x_var=None, y_var=None): @@ -412,8 +449,7 @@ def set_ylabels(self, label=None, **kwargs): ax.set_ylabel(label, **kwargs) return self - def set_titles(self, template="{coord} = {value}", maxchar=30, - size=None, **kwargs): + def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwargs): """ Draw titles either above each facet or on the grid margins. 
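
For orientation, the set_titles signature collapsed onto one line above can be exercised as follows; this is a rough sketch with assumed example data, not taken from the patch, and it requires xarray, NumPy and matplotlib:

import numpy as np
import xarray as xr

da = xr.DataArray(
    np.random.rand(2, 3, 4),
    dims=("time", "y", "x"),
    coords={"time": [0, 1]},
)

# Faceting on "time" returns a FacetGrid; set_titles overrides the
# default "{coord} = {value}" template shown in the signature above.
g = da.plot(col="time")
g.set_titles(template="{coord}: {value}", maxchar=20)
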
@@ -436,8 +472,7 @@ def set_titles(self, template="{coord} = {value}", maxchar=30, if size is None: size = mpl.rcParams["axes.labelsize"] - nicetitle = functools.partial(_nicetitle, maxchar=maxchar, - template=template) + nicetitle = functools.partial(_nicetitle, maxchar=maxchar, template=template) if self._single_group: for d, ax in zip(self.name_dicts.flat, self.axes.flat): @@ -449,21 +484,25 @@ def set_titles(self, template="{coord} = {value}", maxchar=30, else: # The row titles on the right edge of the grid for ax, row_name in zip(self.axes[:, -1], self.row_names): - title = nicetitle(coord=self._row_var, value=row_name, - maxchar=maxchar) - ax.annotate(title, xy=(1.02, .5), xycoords="axes fraction", - rotation=270, ha="left", va="center", **kwargs) + title = nicetitle(coord=self._row_var, value=row_name, maxchar=maxchar) + ax.annotate( + title, + xy=(1.02, 0.5), + xycoords="axes fraction", + rotation=270, + ha="left", + va="center", + **kwargs + ) # The column titles on the top row for ax, col_name in zip(self.axes[0, :], self.col_names): - title = nicetitle(coord=self._col_var, value=col_name, - maxchar=maxchar) + title = nicetitle(coord=self._col_var, value=col_name, maxchar=maxchar) ax.set_title(title, size=size, **kwargs) return self - def set_ticks(self, max_xticks=_NTICKS, max_yticks=_NTICKS, - fontsize=_FONTSIZE): + def set_ticks(self, max_xticks=_NTICKS, max_yticks=_NTICKS, fontsize=_FONTSIZE): """ Set and control tick behavior @@ -488,8 +527,9 @@ def set_ticks(self, max_xticks=_NTICKS, max_yticks=_NTICKS, for ax in self.axes.flat: ax.xaxis.set_major_locator(x_major_locator) ax.yaxis.set_major_locator(y_major_locator) - for tick in itertools.chain(ax.xaxis.get_major_ticks(), - ax.yaxis.get_major_ticks()): + for tick in itertools.chain( + ax.xaxis.get_major_ticks(), ax.yaxis.get_major_ticks() + ): tick.label1.set_fontsize(fontsize) return self @@ -527,8 +567,7 @@ def map(self, func, *args, **kwargs): maybe_mappable = func(*innerargs, **kwargs) # TODO: better way to verify that an artist is mappable? 
# https://stackoverflow.com/questions/33023036/is-it-possible-to-detect-if-a-matplotlib-artist-is-a-mappable-suitable-for-use-w#33023522 - if (maybe_mappable and - hasattr(maybe_mappable, 'autoscale_None')): + if maybe_mappable and hasattr(maybe_mappable, "autoscale_None"): self._mappables.append(maybe_mappable) self._finalize_grid(*args[:2]) @@ -536,10 +575,24 @@ def map(self, func, *args, **kwargs): return self -def _easy_facetgrid(data, plotfunc, kind, x=None, y=None, row=None, - col=None, col_wrap=None, sharex=True, sharey=True, - aspect=None, size=None, subplot_kws=None, ax=None, - figsize=None, **kwargs): +def _easy_facetgrid( + data, + plotfunc, + kind, + x=None, + y=None, + row=None, + col=None, + col_wrap=None, + sharex=True, + sharey=True, + aspect=None, + size=None, + subplot_kws=None, + ax=None, + figsize=None, + **kwargs +): """ Convenience method to call xarray.plot.FacetGrid from 2d plotting methods @@ -552,17 +605,26 @@ def _easy_facetgrid(data, plotfunc, kind, x=None, y=None, row=None, if size is None: size = 3 elif figsize is not None: - raise ValueError('cannot provide both `figsize` and `size` arguments') - - g = FacetGrid(data=data, col=col, row=row, col_wrap=col_wrap, - sharex=sharex, sharey=sharey, figsize=figsize, - aspect=aspect, size=size, subplot_kws=subplot_kws) - - if kind == 'line': + raise ValueError("cannot provide both `figsize` and `size` arguments") + + g = FacetGrid( + data=data, + col=col, + row=row, + col_wrap=col_wrap, + sharex=sharex, + sharey=sharey, + figsize=figsize, + aspect=aspect, + size=size, + subplot_kws=subplot_kws, + ) + + if kind == "line": return g.map_dataarray_line(plotfunc, x, y, **kwargs) - if kind == 'dataarray': + if kind == "dataarray": return g.map_dataarray(plotfunc, x, y, **kwargs) - if kind == 'dataset': + if kind == "dataset": return g.map_dataset(plotfunc, x, y, **kwargs) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 34cb56f54e0..14f03d42fe7 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -13,32 +13,42 @@ from .facetgrid import _easy_facetgrid from .utils import ( - _add_colorbar, _ensure_plottable, _infer_interval_breaks, _infer_xy_labels, - _interval_to_double_bound_points, _interval_to_mid_points, - _process_cmap_cbar_kwargs, _rescale_imshow_rgb, _resolve_intervals_2dplot, - _update_axes, _valid_other_type, get_axis, import_matplotlib_pyplot, - label_from_attrs) + _add_colorbar, + _ensure_plottable, + _infer_interval_breaks, + _infer_xy_labels, + _interval_to_double_bound_points, + _interval_to_mid_points, + _process_cmap_cbar_kwargs, + _rescale_imshow_rgb, + _resolve_intervals_2dplot, + _update_axes, + _valid_other_type, + get_axis, + import_matplotlib_pyplot, + label_from_attrs, +) def _infer_line_data(darray, x, y, hue): - error_msg = ('must be either None or one of ({:s})' - .format(', '.join([repr(dd) for dd in darray.dims]))) + error_msg = "must be either None or one of ({:s})".format( + ", ".join([repr(dd) for dd in darray.dims]) + ) ndims = len(darray.dims) if x is not None and x not in darray.dims and x not in darray.coords: - raise ValueError('x ' + error_msg) + raise ValueError("x " + error_msg) if y is not None and y not in darray.dims and y not in darray.coords: - raise ValueError('y ' + error_msg) + raise ValueError("y " + error_msg) if x is not None and y is not None: - raise ValueError('You cannot specify both x and y kwargs' - 'for line plots.') + raise ValueError("You cannot specify both x and y kwargs" "for line plots.") if ndims == 1: huename = None hueplt = None - huelabel 
= '' + huelabel = "" if x is not None: xplt = darray[x] @@ -55,8 +65,7 @@ def _infer_line_data(darray, x, y, hue): else: if x is None and y is None and hue is None: - raise ValueError('For 2D inputs, please' - 'specify either hue, x or y.') + raise ValueError("For 2D inputs, please" "specify either hue, x or y.") if y is None: xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) @@ -65,13 +74,13 @@ def _infer_line_data(darray, x, y, hue): if huename in darray.dims: otherindex = 1 if darray.dims.index(huename) == 0 else 0 otherdim = darray.dims[otherindex] - yplt = darray.transpose( - otherdim, huename, transpose_coords=False) - xplt = xplt.transpose( - otherdim, huename, transpose_coords=False) + yplt = darray.transpose(otherdim, huename, transpose_coords=False) + xplt = xplt.transpose(otherdim, huename, transpose_coords=False) else: - raise ValueError('For 2D inputs, hue must be a dimension' - ' i.e. one of ' + repr(darray.dims)) + raise ValueError( + "For 2D inputs, hue must be a dimension" + " i.e. one of " + repr(darray.dims) + ) else: yplt = darray.transpose(xname, huename) @@ -83,11 +92,12 @@ def _infer_line_data(darray, x, y, hue): if huename in darray.dims: otherindex = 1 if darray.dims.index(huename) == 0 else 0 otherdim = darray.dims[otherindex] - xplt = darray.transpose( - otherdim, huename, transpose_coords=False) + xplt = darray.transpose(otherdim, huename, transpose_coords=False) else: - raise ValueError('For 2D inputs, hue must be a dimension' - ' i.e. one of ' + repr(darray.dims)) + raise ValueError( + "For 2D inputs, hue must be a dimension" + " i.e. one of " + repr(darray.dims) + ) else: xplt = darray.transpose(yname, huename) @@ -101,8 +111,17 @@ def _infer_line_data(darray, x, y, hue): return xplt, yplt, hueplt, xlabel, ylabel, huelabel -def plot(darray, row=None, col=None, col_wrap=None, ax=None, hue=None, - rtol=0.01, subplot_kws=None, **kwargs): +def plot( + darray, + row=None, + col=None, + col_wrap=None, + ax=None, + hue=None, + rtol=0.01, + subplot_kws=None, + **kwargs +): """ Default plot of DataArray using matplotlib.pyplot. @@ -149,22 +168,24 @@ def plot(darray, row=None, col=None, col_wrap=None, ax=None, hue=None, ndims = len(plot_dims) - error_msg = ('Only 1d and 2d plots are supported for facets in xarray. ' - 'See the package `Seaborn` for more options.') + error_msg = ( + "Only 1d and 2d plots are supported for facets in xarray. " + "See the package `Seaborn` for more options." 
+ ) if ndims in [1, 2]: if row or col: - kwargs['row'] = row - kwargs['col'] = col - kwargs['col_wrap'] = col_wrap - kwargs['subplot_kws'] = subplot_kws + kwargs["row"] = row + kwargs["col"] = col + kwargs["col_wrap"] = col_wrap + kwargs["subplot_kws"] = subplot_kws if ndims == 1: plotfunc = line - kwargs['hue'] = hue + kwargs["hue"] = hue elif ndims == 2: if hue: plotfunc = line - kwargs['hue'] = hue + kwargs["hue"] = hue else: plotfunc = pcolormesh else: @@ -172,17 +193,37 @@ def plot(darray, row=None, col=None, col_wrap=None, ax=None, hue=None, raise ValueError(error_msg) plotfunc = hist - kwargs['ax'] = ax + kwargs["ax"] = ax return plotfunc(darray, **kwargs) # This function signature should not change so that it can use # matplotlib format strings -def line(darray, *args, row=None, col=None, figsize=None, aspect=None, - size=None, ax=None, hue=None, x=None, y=None, xincrease=None, - yincrease=None, xscale=None, yscale=None, xticks=None, yticks=None, - xlim=None, ylim=None, add_legend=True, _labels=True, **kwargs): +def line( + darray, + *args, + row=None, + col=None, + figsize=None, + aspect=None, + size=None, + ax=None, + hue=None, + x=None, + y=None, + xincrease=None, + yincrease=None, + xscale=None, + yscale=None, + xticks=None, + yticks=None, + xlim=None, + ylim=None, + add_legend=True, + _labels=True, + **kwargs +): """ Line plot of DataArray index against values @@ -230,43 +271,47 @@ def line(darray, *args, row=None, col=None, figsize=None, aspect=None, # Handle facetgrids first if row or col: allargs = locals().copy() - allargs.update(allargs.pop('kwargs')) - allargs.pop('darray') - return _easy_facetgrid(darray, line, kind='line', **allargs) + allargs.update(allargs.pop("kwargs")) + allargs.pop("darray") + return _easy_facetgrid(darray, line, kind="line", **allargs) ndims = len(darray.dims) if ndims > 2: - raise ValueError('Line plots are for 1- or 2-dimensional DataArrays. ' - 'Passed DataArray has {ndims} ' - 'dimensions'.format(ndims=ndims)) + raise ValueError( + "Line plots are for 1- or 2-dimensional DataArrays. " + "Passed DataArray has {ndims} " + "dimensions".format(ndims=ndims) + ) # The allargs dict passed to _easy_facetgrid above contains args if args is (): - args = kwargs.pop('args', ()) + args = kwargs.pop("args", ()) else: - assert 'args' not in kwargs + assert "args" not in kwargs ax = get_axis(figsize, size, aspect, ax) - xplt, yplt, hueplt, xlabel, ylabel, hue_label = \ - _infer_line_data(darray, x, y, hue) + xplt, yplt, hueplt, xlabel, ylabel, hue_label = _infer_line_data(darray, x, y, hue) # Remove pd.Intervals if contained in xplt.values. if _valid_other_type(xplt.values, [pd.Interval]): # Is it a step plot? 
(see matplotlib.Axes.step) - if kwargs.get('linestyle', '').startswith('steps-'): - xplt_val, yplt_val = _interval_to_double_bound_points(xplt.values, - yplt.values) + if kwargs.get("linestyle", "").startswith("steps-"): + xplt_val, yplt_val = _interval_to_double_bound_points( + xplt.values, yplt.values + ) # Remove steps-* to be sure that matplotlib is not confused - kwargs['linestyle'] = (kwargs['linestyle'] - .replace('steps-pre', '') - .replace('steps-post', '') - .replace('steps-mid', '')) - if kwargs['linestyle'] == '': - del kwargs['linestyle'] + kwargs["linestyle"] = ( + kwargs["linestyle"] + .replace("steps-pre", "") + .replace("steps-post", "") + .replace("steps-mid", "") + ) + if kwargs["linestyle"] == "": + del kwargs["linestyle"] else: xplt_val = _interval_to_mid_points(xplt.values) yplt_val = yplt.values - xlabel += '_center' + xlabel += "_center" else: xplt_val = xplt.values yplt_val = yplt.values @@ -285,9 +330,7 @@ def line(darray, *args, row=None, col=None, figsize=None, aspect=None, ax.set_title(darray._title_for_slice()) if darray.ndim == 2 and add_legend: - ax.legend(handles=primitive, - labels=list(hueplt.values), - title=hue_label) + ax.legend(handles=primitive, labels=list(hueplt.values), title=hue_label) # Rotate dates on xlabels # Do this without calling autofmt_xdate so that x-axes ticks @@ -296,15 +339,14 @@ def line(darray, *args, row=None, col=None, figsize=None, aspect=None, if np.issubdtype(xplt.dtype, np.datetime64): for xlabels in ax.get_xticklabels(): xlabels.set_rotation(30) - xlabels.set_ha('right') + xlabels.set_ha("right") - _update_axes(ax, xincrease, yincrease, xscale, yscale, - xticks, yticks, xlim, ylim) + _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) return primitive -def step(darray, *args, where='pre', linestyle=None, ls=None, **kwargs): +def step(darray, *args, where="pre", linestyle=None, ls=None, **kwargs): """ Step plot of DataArray index against values @@ -329,25 +371,37 @@ def step(darray, *args, where='pre', linestyle=None, ls=None, **kwargs): *args, **kwargs : optional Additional arguments following :py:func:`xarray.plot.line` """ - if where not in {'pre', 'post', 'mid'}: - raise ValueError("'where' argument to step must be " - "'pre', 'post' or 'mid'") + if where not in {"pre", "post", "mid"}: + raise ValueError("'where' argument to step must be " "'pre', 'post' or 'mid'") if ls is not None: if linestyle is None: linestyle = ls else: - raise TypeError('ls and linestyle are mutually exclusive') + raise TypeError("ls and linestyle are mutually exclusive") if linestyle is None: - linestyle = '' - linestyle = 'steps-' + where + linestyle + linestyle = "" + linestyle = "steps-" + where + linestyle return line(darray, *args, linestyle=linestyle, **kwargs) -def hist(darray, figsize=None, size=None, aspect=None, ax=None, - xincrease=None, yincrease=None, xscale=None, yscale=None, - xticks=None, yticks=None, xlim=None, ylim=None, **kwargs): +def hist( + darray, + figsize=None, + size=None, + aspect=None, + ax=None, + xincrease=None, + yincrease=None, + xscale=None, + yscale=None, + xticks=None, + yticks=None, + xlim=None, + ylim=None, + **kwargs +): """ Histogram of DataArray @@ -382,11 +436,10 @@ def hist(darray, figsize=None, size=None, aspect=None, ax=None, primitive = ax.hist(no_nan, **kwargs) - ax.set_title('Histogram') + ax.set_title("Histogram") ax.set_xlabel(label_from_attrs(darray)) - _update_axes(ax, xincrease, yincrease, xscale, yscale, - xticks, yticks, xlim, ylim) + _update_axes(ax, xincrease, 
yincrease, xscale, yscale, xticks, yticks, xlim, ylim) return primitive @@ -525,27 +578,54 @@ def _plot2d(plotfunc): """ # Build on the original docstring - plotfunc.__doc__ = '%s\n%s' % (plotfunc.__doc__, commondoc) + plotfunc.__doc__ = "%s\n%s" % (plotfunc.__doc__, commondoc) @functools.wraps(plotfunc) - def newplotfunc(darray, x=None, y=None, figsize=None, size=None, - aspect=None, ax=None, row=None, col=None, - col_wrap=None, xincrease=True, yincrease=True, - add_colorbar=None, add_labels=True, vmin=None, vmax=None, - cmap=None, center=None, robust=False, extend=None, - levels=None, infer_intervals=None, colors=None, - subplot_kws=None, cbar_ax=None, cbar_kwargs=None, - xscale=None, yscale=None, xticks=None, yticks=None, - xlim=None, ylim=None, norm=None, **kwargs): + def newplotfunc( + darray, + x=None, + y=None, + figsize=None, + size=None, + aspect=None, + ax=None, + row=None, + col=None, + col_wrap=None, + xincrease=True, + yincrease=True, + add_colorbar=None, + add_labels=True, + vmin=None, + vmax=None, + cmap=None, + center=None, + robust=False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws=None, + cbar_ax=None, + cbar_kwargs=None, + xscale=None, + yscale=None, + xticks=None, + yticks=None, + xlim=None, + ylim=None, + norm=None, + **kwargs + ): # All 2d plots in xarray share this function signature. # Method signature below should be consistent. # Decide on a default for the colorbar before facetgrids if add_colorbar is None: - add_colorbar = plotfunc.__name__ != 'contour' - imshow_rgb = ( - plotfunc.__name__ == 'imshow' and - darray.ndim == (3 + (row is not None) + (col is not None))) + add_colorbar = plotfunc.__name__ != "contour" + imshow_rgb = plotfunc.__name__ == "imshow" and darray.ndim == ( + 3 + (row is not None) + (col is not None) + ) if imshow_rgb: # Don't add a colorbar when showing an image with explicit colors add_colorbar = False @@ -558,24 +638,27 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, # Handle facetgrids first if row or col: allargs = locals().copy() - del allargs['darray'] - del allargs['imshow_rgb'] - allargs.update(allargs.pop('kwargs')) + del allargs["darray"] + del allargs["imshow_rgb"] + allargs.update(allargs.pop("kwargs")) # Need the decorated plotting function - allargs['plotfunc'] = globals()[plotfunc.__name__] - return _easy_facetgrid(darray, kind='dataarray', **allargs) + allargs["plotfunc"] = globals()[plotfunc.__name__] + return _easy_facetgrid(darray, kind="dataarray", **allargs) plt = import_matplotlib_pyplot() - rgb = kwargs.pop('rgb', None) - if rgb is not None and plotfunc.__name__ != 'imshow': + rgb = kwargs.pop("rgb", None) + if rgb is not None and plotfunc.__name__ != "imshow": raise ValueError('The "rgb" keyword is only valid for imshow()') elif rgb is not None and not imshow_rgb: - raise ValueError('The "rgb" keyword is only valid for imshow()' - 'with a three-dimensional array (per facet)') + raise ValueError( + 'The "rgb" keyword is only valid for imshow()' + "with a three-dimensional array (per facet)" + ) xlab, ylab = _infer_xy_labels( - darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb) + darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb + ) # better to pass the ndarrays directly to plotting functions xval = darray[xlab].values @@ -611,34 +694,42 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, _ensure_plottable(xplt, yplt) cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( - plotfunc, zval.data, **locals()) + plotfunc, zval.data, **locals() + 
) - if 'contour' in plotfunc.__name__: + if "contour" in plotfunc.__name__: # extend is a keyword argument only for contour and contourf, but # passing it to the colorbar is sufficient for imshow and # pcolormesh - kwargs['extend'] = cmap_params['extend'] - kwargs['levels'] = cmap_params['levels'] + kwargs["extend"] = cmap_params["extend"] + kwargs["levels"] = cmap_params["levels"] # if colors == a single color, matplotlib draws dashed negative # contours. we lose this feature if we pass cmap and not colors if isinstance(colors, str): - cmap_params['cmap'] = None - kwargs['colors'] = colors + cmap_params["cmap"] = None + kwargs["colors"] = colors - if 'pcolormesh' == plotfunc.__name__: - kwargs['infer_intervals'] = infer_intervals + if "pcolormesh" == plotfunc.__name__: + kwargs["infer_intervals"] = infer_intervals - if 'imshow' == plotfunc.__name__ and isinstance(aspect, str): + if "imshow" == plotfunc.__name__ and isinstance(aspect, str): # forbid usage of mpl strings - raise ValueError("plt.imshow's `aspect` kwarg is not available " - "in xarray") + raise ValueError( + "plt.imshow's `aspect` kwarg is not available " "in xarray" + ) ax = get_axis(figsize, size, aspect, ax) - primitive = plotfunc(xplt, yplt, zval, ax=ax, cmap=cmap_params['cmap'], - vmin=cmap_params['vmin'], - vmax=cmap_params['vmax'], - norm=cmap_params['norm'], - **kwargs) + primitive = plotfunc( + xplt, + yplt, + zval, + ax=ax, + cmap=cmap_params["cmap"], + vmin=cmap_params["vmin"], + vmax=cmap_params["vmax"], + norm=cmap_params["norm"], + **kwargs + ) # Label the plot with metadata if add_labels: @@ -647,21 +738,22 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, ax.set_title(darray._title_for_slice()) if add_colorbar: - if add_labels and 'label' not in cbar_kwargs: - cbar_kwargs['label'] = label_from_attrs(darray) - cbar = _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, - cmap_params) + if add_labels and "label" not in cbar_kwargs: + cbar_kwargs["label"] = label_from_attrs(darray) + cbar = _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params) elif cbar_ax is not None or cbar_kwargs: # inform the user about keywords which aren't used - raise ValueError("cbar_ax and cbar_kwargs can't be used with " - "add_colorbar=False.") + raise ValueError( + "cbar_ax and cbar_kwargs can't be used with " "add_colorbar=False." 
+ ) # origin kwarg overrides yincrease - if 'origin' in kwargs: + if "origin" in kwargs: yincrease = None - _update_axes(ax, xincrease, yincrease, xscale, yscale, - xticks, yticks, xlim, ylim) + _update_axes( + ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim + ) # Rotate dates on xlabels # Do this without calling autofmt_xdate so that x-axes ticks @@ -670,21 +762,48 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, if np.issubdtype(xplt.dtype, np.datetime64): for xlabels in ax.get_xticklabels(): xlabels.set_rotation(30) - xlabels.set_ha('right') + xlabels.set_ha("right") return primitive # For use as DataArray.plot.plotmethod @functools.wraps(newplotfunc) - def plotmethod(_PlotMethods_obj, x=None, y=None, figsize=None, size=None, - aspect=None, ax=None, row=None, col=None, col_wrap=None, - xincrease=True, yincrease=True, add_colorbar=None, - add_labels=True, vmin=None, vmax=None, cmap=None, - colors=None, center=None, robust=False, extend=None, - levels=None, infer_intervals=None, subplot_kws=None, - cbar_ax=None, cbar_kwargs=None, - xscale=None, yscale=None, xticks=None, yticks=None, - xlim=None, ylim=None, norm=None, **kwargs): + def plotmethod( + _PlotMethods_obj, + x=None, + y=None, + figsize=None, + size=None, + aspect=None, + ax=None, + row=None, + col=None, + col_wrap=None, + xincrease=True, + yincrease=True, + add_colorbar=None, + add_labels=True, + vmin=None, + vmax=None, + cmap=None, + colors=None, + center=None, + robust=False, + extend=None, + levels=None, + infer_intervals=None, + subplot_kws=None, + cbar_ax=None, + cbar_kwargs=None, + xscale=None, + yscale=None, + xticks=None, + yticks=None, + xlim=None, + ylim=None, + norm=None, + **kwargs + ): """ The method should have the same signature as the function. @@ -692,9 +811,9 @@ def plotmethod(_PlotMethods_obj, x=None, y=None, figsize=None, size=None, and passes all the other arguments straight through. 
""" allargs = locals() - allargs['darray'] = _PlotMethods_obj._da + allargs["darray"] = _PlotMethods_obj._da allargs.update(kwargs) - for arg in ['_PlotMethods_obj', 'newplotfunc', 'kwargs']: + for arg in ["_PlotMethods_obj", "newplotfunc", "kwargs"]: del allargs[arg] return newplotfunc(**allargs) @@ -730,36 +849,36 @@ def imshow(x, y, z, ax, **kwargs): """ if x.ndim != 1 or y.ndim != 1: - raise ValueError('imshow requires 1D coordinates, try using ' - 'pcolormesh or contour(f)') + raise ValueError( + "imshow requires 1D coordinates, try using " "pcolormesh or contour(f)" + ) # Centering the pixels- Assumes uniform spacing try: xstep = (x[1] - x[0]) / 2.0 except IndexError: # Arbitrary default value, similar to matplotlib behaviour - xstep = .1 + xstep = 0.1 try: ystep = (y[1] - y[0]) / 2.0 except IndexError: - ystep = .1 + ystep = 0.1 left, right = x[0] - xstep, x[-1] + xstep bottom, top = y[-1] + ystep, y[0] - ystep - defaults = {'origin': 'upper', - 'interpolation': 'nearest'} + defaults = {"origin": "upper", "interpolation": "nearest"} - if not hasattr(ax, 'projection'): + if not hasattr(ax, "projection"): # not for cartopy geoaxes - defaults['aspect'] = 'auto' + defaults["aspect"] = "auto" # Allow user to override these defaults defaults.update(kwargs) - if defaults['origin'] == 'upper': - defaults['extent'] = [left, right, bottom, top] + if defaults["origin"] == "upper": + defaults["extent"] = [left, right, bottom, top] else: - defaults['extent'] = [left, right, top, bottom] + defaults["extent"] = [left, right, top, bottom] if z.ndim == 3: # matplotlib imshow uses black for missing data, but Xarray makes @@ -812,7 +931,7 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs): # decide on a default for infer_intervals (GH781) x = np.asarray(x) if infer_intervals is None: - if hasattr(ax, 'projection'): + if hasattr(ax, "projection"): if len(x.shape) == 1: infer_intervals = True else: @@ -820,9 +939,10 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs): else: infer_intervals = True - if (infer_intervals and - ((np.shape(x)[0] == np.shape(z)[1]) or - ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1])))): + if infer_intervals and ( + (np.shape(x)[0] == np.shape(z)[1]) + or ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1])) + ): if len(x.shape) == 1: x = _infer_interval_breaks(x, check_monotonic=True) else: @@ -830,8 +950,7 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs): x = _infer_interval_breaks(x, axis=1) x = _infer_interval_breaks(x, axis=0) - if (infer_intervals and - (np.shape(y)[0] == np.shape(z)[0])): + if infer_intervals and (np.shape(y)[0] == np.shape(z)[0]): if len(y.shape) == 1: y = _infer_interval_breaks(y, check_monotonic=True) else: @@ -843,7 +962,7 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs): # by default, pcolormesh picks "round" values for bounds # this results in ugly looking plots with lots of surrounding whitespace - if not hasattr(ax, 'projection') and x.ndim == 1 and y.ndim == 1: + if not hasattr(ax, "projection") and x.ndim == 1 and y.ndim == 1: # not a cartopy geoaxis ax.set_xlim(x[0], x[-1]) ax.set_ylim(y[0], y[-1]) diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 23789d0cbb0..2d50734f519 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -14,7 +14,8 @@ try: import nc_time_axis - if LooseVersion(nc_time_axis.__version__) < LooseVersion('1.2.0'): + + if LooseVersion(nc_time_axis.__version__) < LooseVersion("1.2.0"): nc_time_axis_available = False else: nc_time_axis_available = 
True @@ -25,13 +26,17 @@ def import_seaborn(): - '''import seaborn and handle deprecation of apionly module''' + """import seaborn and handle deprecation of apionly module""" with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") try: import seaborn.apionly as sns - if (w and issubclass(w[-1].category, UserWarning) and - ("seaborn.apionly module" in str(w[-1].message))): + + if ( + w + and issubclass(w[-1].category, UserWarning) + and ("seaborn.apionly module" in str(w[-1].message)) + ): raise ImportError except ImportError: import seaborn as sns @@ -49,10 +54,12 @@ def register_pandas_datetime_converter_if_needed(): if not _registered: try: from pandas.plotting import register_matplotlib_converters + register_matplotlib_converters() except ImportError: # register_matplotlib_converters new in pandas 0.22 from pandas.tseries import converter + converter.register() _registered = True @@ -61,6 +68,7 @@ def import_matplotlib_pyplot(): """Import pyplot as register appropriate converters.""" register_pandas_datetime_converter_if_needed() import matplotlib.pyplot as plt + return plt @@ -68,13 +76,13 @@ def _determine_extend(calc_data, vmin, vmax): extend_min = calc_data.min() < vmin extend_max = calc_data.max() > vmax if extend_min and extend_max: - extend = 'both' + extend = "both" elif extend_min: - extend = 'min' + extend = "min" elif extend_max: - extend = 'max' + extend = "max" else: - extend = 'neither' + extend = "neither" return extend @@ -86,11 +94,11 @@ def _build_discrete_cmap(cmap, levels, extend, filled): if not filled: # non-filled contour plots - extend = 'max' + extend = "max" - if extend == 'both': + if extend == "both": ext_n = 2 - elif extend in ['min', 'max']: + elif extend in ["min", "max"]: ext_n = 1 else: ext_n = 0 @@ -98,10 +106,9 @@ def _build_discrete_cmap(cmap, levels, extend, filled): n_colors = len(levels) + ext_n - 1 pal = _color_palette(cmap, n_colors) - new_cmap, cnorm = mpl.colors.from_levels_and_colors( - levels, pal, extend=extend) + new_cmap, cnorm = mpl.colors.from_levels_and_colors(levels, pal, extend=extend) # copy the old cmap name, for easier testing - new_cmap.name = getattr(cmap, 'name', cmap) + new_cmap.name = getattr(cmap, "name", cmap) return new_cmap, cnorm @@ -109,7 +116,8 @@ def _build_discrete_cmap(cmap, levels, extend, filled): def _color_palette(cmap, n_colors): import matplotlib.pyplot as plt from matplotlib.colors import ListedColormap - colors_i = np.linspace(0, 1., n_colors) + + colors_i = np.linspace(0, 1.0, n_colors) if isinstance(cmap, (list, tuple)): # we have a list of colors cmap = ListedColormap(cmap, N=n_colors) @@ -124,6 +132,7 @@ def _color_palette(cmap, n_colors): # ValueError happens when mpl doesn't like a colormap, try seaborn try: from seaborn.apionly import color_palette + pal = color_palette(cmap, n_colors=n_colors) except (ValueError, ImportError): # or maybe we just got a single color as a string @@ -140,9 +149,19 @@ def _color_palette(cmap, n_colors): # https://github.com/mwaskom/seaborn/blob/v0.6/seaborn/matrix.py#L158 # Used under the terms of Seaborn's license, see licenses/SEABORN_LICENSE. -def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, - center=None, robust=False, extend=None, - levels=None, filled=True, norm=None): + +def _determine_cmap_params( + plot_data, + vmin=None, + vmax=None, + cmap=None, + center=None, + robust=False, + extend=None, + levels=None, + filled=True, + norm=None, +): """ Use some heuristics to set good defaults for colorbar and range. 
@@ -226,16 +245,18 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, norm.vmin = vmin else: if not vmin_was_none and vmin != norm.vmin: - raise ValueError('Cannot supply vmin and a norm' - + ' with a different vmin.') + raise ValueError( + "Cannot supply vmin and a norm" + " with a different vmin." + ) vmin = norm.vmin if norm.vmax is None: norm.vmax = vmax else: if not vmax_was_none and vmax != norm.vmax: - raise ValueError('Cannot supply vmax and a norm' - + ' with a different vmax.') + raise ValueError( + "Cannot supply vmax and a norm" + " with a different vmax." + ) vmax = norm.vmax # if BoundaryNorm, then set levels @@ -245,9 +266,9 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, # Choose default colormaps if not provided if cmap is None: if divergent: - cmap = OPTIONS['cmap_divergent'] + cmap = OPTIONS["cmap_divergent"] else: - cmap = OPTIONS['cmap_sequential'] + cmap = OPTIONS["cmap_sequential"] # Handle discrete levels if levels is not None and norm is None: @@ -269,8 +290,9 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, cmap, newnorm = _build_discrete_cmap(cmap, levels, extend, filled) norm = newnorm if norm is None else norm - return dict(vmin=vmin, vmax=vmax, cmap=cmap, extend=extend, - levels=levels, norm=norm) + return dict( + vmin=vmin, vmax=vmax, cmap=cmap, extend=extend, levels=levels, norm=norm + ) def _infer_xy_labels_3d(darray, x, y, rgb): @@ -287,25 +309,32 @@ def _infer_xy_labels_3d(darray, x, y, rgb): not_none = [a for a in (x, y, rgb) if a is not None] if len(set(not_none)) < len(not_none): raise ValueError( - 'Dimension names must be None or unique strings, but imshow was ' - 'passed x=%r, y=%r, and rgb=%r.' % (x, y, rgb)) + "Dimension names must be None or unique strings, but imshow was " + "passed x=%r, y=%r, and rgb=%r." % (x, y, rgb) + ) for label in not_none: if label not in darray.dims: - raise ValueError('%r is not a dimension' % (label,)) + raise ValueError("%r is not a dimension" % (label,)) # Then calculate rgb dimension if certain and check validity - could_be_color = [label for label in darray.dims - if darray[label].size in (3, 4) and label not in (x, y)] + could_be_color = [ + label + for label in darray.dims + if darray[label].size in (3, 4) and label not in (x, y) + ] if rgb is None and not could_be_color: raise ValueError( - 'A 3-dimensional array was passed to imshow(), but there is no ' - 'dimension that could be color. At least one dimension must be ' - 'of size 3 (RGB) or 4 (RGBA), and not given as x or y.') + "A 3-dimensional array was passed to imshow(), but there is no " + "dimension that could be color. At least one dimension must be " + "of size 3 (RGB) or 4 (RGBA), and not given as x or y." + ) if rgb is None and len(could_be_color) == 1: rgb = could_be_color[0] if rgb is not None and darray[rgb].size not in (3, 4): - raise ValueError('Cannot interpret dim %r of size %s as RGB or RGBA.' - % (rgb, darray[rgb].size)) + raise ValueError( + "Cannot interpret dim %r of size %s as RGB or RGBA." + % (rgb, darray[rgb].size) + ) # If rgb dimension is still unknown, there must be two or three dimensions # in could_be_color. We therefore warn, and use a heuristic to break ties. @@ -313,10 +342,11 @@ def _infer_xy_labels_3d(darray, x, y, rgb): assert len(could_be_color) in (2, 3) rgb = could_be_color[-1] warnings.warn( - 'Several dimensions of this array could be colors. Xarray ' - 'will use the last possible dimension (%r) to match ' - 'matplotlib.pyplot.imshow. 
You can pass names of x, y, ' - 'and/or rgb dimensions to override this guess.' % rgb) + "Several dimensions of this array could be colors. Xarray " + "will use the last possible dimension (%r) to match " + "matplotlib.pyplot.imshow. You can pass names of x, y, " + "and/or rgb dimensions to override this guess." % rgb + ) assert rgb is not None # Finally, we pick out the red slice and delegate to the 2D version: @@ -335,18 +365,18 @@ def _infer_xy_labels(darray, x, y, imshow=False, rgb=None): if x is None and y is None: if darray.ndim != 2: - raise ValueError('DataArray must be 2d') + raise ValueError("DataArray must be 2d") y, x = darray.dims elif x is None: if y not in darray.dims and y not in darray.coords: - raise ValueError('y must be a dimension name if x is not supplied') + raise ValueError("y must be a dimension name if x is not supplied") x = darray.dims[0] if y == darray.dims[1] else darray.dims[1] elif y is None: if x not in darray.dims and x not in darray.coords: - raise ValueError('x must be a dimension name if y is not supplied') + raise ValueError("x must be a dimension name if y is not supplied") y = darray.dims[0] if x == darray.dims[1] else darray.dims[1] elif any(k not in darray.coords and k not in darray.dims for k in (x, y)): - raise ValueError('x and y must be coordinate variables') + raise ValueError("x and y must be coordinate variables") return x, y @@ -356,22 +386,20 @@ def get_axis(figsize, size, aspect, ax): if figsize is not None: if ax is not None: - raise ValueError('cannot provide both `figsize` and ' - '`ax` arguments') + raise ValueError("cannot provide both `figsize` and " "`ax` arguments") if size is not None: - raise ValueError('cannot provide both `figsize` and ' - '`size` arguments') + raise ValueError("cannot provide both `figsize` and " "`size` arguments") _, ax = plt.subplots(figsize=figsize) elif size is not None: if ax is not None: - raise ValueError('cannot provide both `size` and `ax` arguments') + raise ValueError("cannot provide both `size` and `ax` arguments") if aspect is None: - width, height = mpl.rcParams['figure.figsize'] + width, height = mpl.rcParams["figure.figsize"] aspect = width / height figsize = (size * aspect, size) _, ax = plt.subplots(figsize=figsize) elif aspect is not None: - raise ValueError('cannot provide `aspect` argument without `size`') + raise ValueError("cannot provide `aspect` argument without `size`") if ax is None: ax = plt.gca() @@ -379,25 +407,25 @@ def get_axis(figsize, size, aspect, ax): return ax -def label_from_attrs(da, extra=''): - ''' Makes informative labels if variable metadata (attrs) follows - CF conventions. ''' +def label_from_attrs(da, extra=""): + """ Makes informative labels if variable metadata (attrs) follows + CF conventions. 
""" - if da.attrs.get('long_name'): - name = da.attrs['long_name'] - elif da.attrs.get('standard_name'): - name = da.attrs['standard_name'] + if da.attrs.get("long_name"): + name = da.attrs["long_name"] + elif da.attrs.get("standard_name"): + name = da.attrs["standard_name"] elif da.name is not None: name = da.name else: - name = '' + name = "" - if da.attrs.get('units'): - units = ' [{}]'.format(da.attrs['units']) + if da.attrs.get("units"): + units = " [{}]".format(da.attrs["units"]) else: - units = '' + units = "" - return '\n'.join(textwrap.wrap(name + extra + units, 30)) + return "\n".join(textwrap.wrap(name + extra + units, 30)) def _interval_to_mid_points(array): @@ -416,8 +444,7 @@ def _interval_to_bound_points(array): """ array_boundaries = np.array([x.left for x in array]) - array_boundaries = np.concatenate( - (array_boundaries, np.array([array[-1].right]))) + array_boundaries = np.concatenate((array_boundaries, np.array([array[-1].right]))) return array_boundaries @@ -444,13 +471,13 @@ def _resolve_intervals_2dplot(val, func_name): pd.Interval with their mid-points or - for pcolormesh - boundaries which increases length by 1. """ - label_extra = '' + label_extra = "" if _valid_other_type(val, [pd.Interval]): - if func_name == 'pcolormesh': + if func_name == "pcolormesh": val = _interval_to_bound_points(val) else: val = _interval_to_mid_points(val) - label_extra = '_center' + label_extra = "_center" return val, label_extra @@ -483,24 +510,33 @@ def _ensure_plottable(*args): other_types = [datetime] try: import cftime + cftime_datetime = [cftime.datetime] except ImportError: cftime_datetime = [] other_types = other_types + cftime_datetime for x in args: - if not (_valid_numpy_subdtype(np.array(x), numpy_types) - or _valid_other_type(np.array(x), other_types)): - raise TypeError('Plotting requires coordinates to be numeric ' - 'or dates of type np.datetime64, ' - 'datetime.datetime, cftime.datetime or ' - 'pd.Interval.') - if (_valid_other_type(np.array(x), cftime_datetime) - and not nc_time_axis_available): - raise ImportError('Plotting of arrays of cftime.datetime ' - 'objects or arrays indexed by ' - 'cftime.datetime objects requires the ' - 'optional `nc-time-axis` (v1.2.0 or later) ' - 'package.') + if not ( + _valid_numpy_subdtype(np.array(x), numpy_types) + or _valid_other_type(np.array(x), other_types) + ): + raise TypeError( + "Plotting requires coordinates to be numeric " + "or dates of type np.datetime64, " + "datetime.datetime, cftime.datetime or " + "pd.Interval." + ) + if ( + _valid_other_type(np.array(x), cftime_datetime) + and not nc_time_axis_available + ): + raise ImportError( + "Plotting of arrays of cftime.datetime " + "objects or arrays indexed by " + "cftime.datetime objects requires the " + "optional `nc-time-axis` (v1.2.0 or later) " + "package." 
+ ) def _is_numeric(arr): @@ -510,11 +546,11 @@ def _is_numeric(arr): def _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params): plt = import_matplotlib_pyplot() - cbar_kwargs.setdefault('extend', cmap_params['extend']) + cbar_kwargs.setdefault("extend", cmap_params["extend"]) if cbar_ax is None: - cbar_kwargs.setdefault('ax', ax) + cbar_kwargs.setdefault("ax", ax) else: - cbar_kwargs.setdefault('cax', cbar_ax) + cbar_kwargs.setdefault("cax", cbar_ax) cbar = plt.colorbar(primitive, **cbar_kwargs) @@ -540,29 +576,37 @@ def _rescale_imshow_rgb(darray, vmin, vmax, robust): vmax = 255 if np.issubdtype(darray.dtype, np.integer) else 1 if vmax < vmin: raise ValueError( - 'vmin=%r is less than the default vmax (%r) - you must supply ' - 'a vmax > vmin in this case.' % (vmin, vmax)) + "vmin=%r is less than the default vmax (%r) - you must supply " + "a vmax > vmin in this case." % (vmin, vmax) + ) elif vmin is None: vmin = 0 if vmin > vmax: raise ValueError( - 'vmax=%r is less than the default vmin (0) - you must supply ' - 'a vmin < vmax in this case.' % vmax) + "vmax=%r is less than the default vmin (0) - you must supply " + "a vmin < vmax in this case." % vmax + ) # Scale interval [vmin .. vmax] to [0 .. 1], with darray as 64-bit float # to avoid precision loss, integer over/underflow, etc with extreme inputs. # After scaling, downcast to 32-bit float. This substantially reduces # memory usage after we hand `darray` off to matplotlib. - darray = ((darray.astype('f8') - vmin) / (vmax - vmin)).astype('f4') + darray = ((darray.astype("f8") - vmin) / (vmax - vmin)).astype("f4") with warnings.catch_warnings(): - warnings.filterwarnings('ignore', 'xarray.ufuncs', - PendingDeprecationWarning) + warnings.filterwarnings("ignore", "xarray.ufuncs", PendingDeprecationWarning) return minimum(maximum(darray, 0), 1) -def _update_axes(ax, xincrease, yincrease, - xscale=None, yscale=None, - xticks=None, yticks=None, - xlim=None, ylim=None): +def _update_axes( + ax, + xincrease, + yincrease, + xscale=None, + yscale=None, + xticks=None, + yticks=None, + xlim=None, + ylim=None, +): """ Update axes with provided parameters """ @@ -614,10 +658,12 @@ def _is_monotonic(coord, axis=0): return True else: n = coord.shape[axis] - delta_pos = (coord.take(np.arange(1, n), axis=axis) >= - coord.take(np.arange(0, n - 1), axis=axis)) - delta_neg = (coord.take(np.arange(1, n), axis=axis) <= - coord.take(np.arange(0, n - 1), axis=axis)) + delta_pos = coord.take(np.arange(1, n), axis=axis) >= coord.take( + np.arange(0, n - 1), axis=axis + ) + delta_neg = coord.take(np.arange(1, n), axis=axis) <= coord.take( + np.arange(0, n - 1), axis=axis + ) return np.all(delta_pos) or np.all(delta_neg) @@ -632,28 +678,35 @@ def _infer_interval_breaks(coord, axis=0, check_monotonic=False): coord = np.asarray(coord) if check_monotonic and not _is_monotonic(coord, axis=axis): - raise ValueError("The input coordinate is not sorted in increasing " - "order along axis %d. This can lead to unexpected " - "results. Consider calling the `sortby` method on " - "the input DataArray. To plot data with categorical " - "axes, consider using the `heatmap` function from " - "the `seaborn` statistical plotting library." % axis) + raise ValueError( + "The input coordinate is not sorted in increasing " + "order along axis %d. This can lead to unexpected " + "results. Consider calling the `sortby` method on " + "the input DataArray. 
To plot data with categorical " + "axes, consider using the `heatmap` function from " + "the `seaborn` statistical plotting library." % axis + ) deltas = 0.5 * np.diff(coord, axis=axis) if deltas.size == 0: deltas = np.array(0.0) first = np.take(coord, [0], axis=axis) - np.take(deltas, [0], axis=axis) last = np.take(coord, [-1], axis=axis) + np.take(deltas, [-1], axis=axis) - trim_last = tuple(slice(None, -1) if n == axis else slice(None) - for n in range(coord.ndim)) + trim_last = tuple( + slice(None, -1) if n == axis else slice(None) for n in range(coord.ndim) + ) return np.concatenate([first, coord[trim_last] + deltas, last], axis=axis) def _process_cmap_cbar_kwargs( - func, data, cmap=None, colors=None, - cbar_kwargs: Union[Iterable[Tuple[str, Any]], - Mapping[str, Any]] = None, - levels=None, **kwargs): + func, + data, + cmap=None, + colors=None, + cbar_kwargs: Union[Iterable[Tuple[str, Any]], Mapping[str, Any]] = None, + levels=None, + **kwargs +): """ Parameters ========== @@ -669,7 +722,7 @@ def _process_cmap_cbar_kwargs( """ cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs) - if 'contour' in func.__name__ and levels is None: + if "contour" in func.__name__ and levels is None: levels = 7 # this is the matplotlib default # colors is mutually exclusive with cmap @@ -678,20 +731,25 @@ def _process_cmap_cbar_kwargs( # colors is only valid when levels is supplied or the plot is of type # contour or contourf - if colors and (('contour' not in func.__name__) and (not levels)): + if colors and (("contour" not in func.__name__) and (not levels)): raise ValueError("Can only specify colors with contour or levels") # we should not be getting a list of colors in cmap anymore # is there a better way to do this test? if isinstance(cmap, (list, tuple)): - warnings.warn("Specifying a list of colors in cmap is deprecated. " - "Use colors keyword instead.", - DeprecationWarning, stacklevel=3) - - cmap_kwargs = {'plot_data': data, - 'levels': levels, - 'cmap': colors if colors else cmap, - 'filled': func.__name__ != 'contour'} + warnings.warn( + "Specifying a list of colors in cmap is deprecated. 
" + "Use colors keyword instead.", + DeprecationWarning, + stacklevel=3, + ) + + cmap_kwargs = { + "plot_data": data, + "levels": levels, + "cmap": colors if colors else cmap, + "filled": func.__name__ != "contour", + } cmap_args = getfullargspec(_determine_cmap_params).args cmap_kwargs.update((a, kwargs[a]) for a in cmap_args if a in kwargs) diff --git a/xarray/testing.py b/xarray/testing.py index 42c91b1eda2..3c92eef04c6 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -14,22 +14,20 @@ def _decode_string_data(data): - if data.dtype.kind == 'S': - return np.core.defchararray.decode(data, 'utf-8', 'replace') + if data.dtype.kind == "S": + return np.core.defchararray.decode(data, "utf-8", "replace") return data -def _data_allclose_or_equiv(arr1, arr2, rtol=1e-05, atol=1e-08, - decode_bytes=True): - if any(arr.dtype.kind == 'S' for arr in [arr1, arr2]) and decode_bytes: +def _data_allclose_or_equiv(arr1, arr2, rtol=1e-05, atol=1e-08, decode_bytes=True): + if any(arr.dtype.kind == "S" for arr in [arr1, arr2]) and decode_bytes: arr1 = _decode_string_data(arr1) arr2 = _decode_string_data(arr2) - exact_dtypes = ['M', 'm', 'O', 'S', 'U'] + exact_dtypes = ["M", "m", "O", "S", "U"] if any(arr.dtype.kind in exact_dtypes for arr in [arr1, arr2]): return duck_array_ops.array_equiv(arr1, arr2) else: - return duck_array_ops.allclose_or_equiv( - arr1, arr2, rtol=rtol, atol=atol) + return duck_array_ops.allclose_or_equiv(arr1, arr2, rtol=rtol, atol=atol) def assert_equal(a, b): @@ -56,12 +54,11 @@ def assert_equal(a, b): __tracebackhide__ = True # noqa: F841 assert type(a) == type(b) # noqa if isinstance(a, (Variable, DataArray)): - assert a.equals(b), formatting.diff_array_repr(a, b, 'equals') + assert a.equals(b), formatting.diff_array_repr(a, b, "equals") elif isinstance(a, Dataset): - assert a.equals(b), formatting.diff_dataset_repr(a, b, 'equals') + assert a.equals(b), formatting.diff_dataset_repr(a, b, "equals") else: - raise TypeError('{} not supported by assertion comparison' - .format(type(a))) + raise TypeError("{} not supported by assertion comparison".format(type(a))) def assert_identical(a, b): @@ -84,15 +81,14 @@ def assert_identical(a, b): __tracebackhide__ = True # noqa: F841 assert type(a) == type(b) # noqa if isinstance(a, Variable): - assert a.identical(b), formatting.diff_array_repr(a, b, 'identical') + assert a.identical(b), formatting.diff_array_repr(a, b, "identical") elif isinstance(a, DataArray): assert a.name == b.name - assert a.identical(b), formatting.diff_array_repr(a, b, 'identical') + assert a.identical(b), formatting.diff_array_repr(a, b, "identical") elif isinstance(a, (Dataset, Variable)): - assert a.identical(b), formatting.diff_dataset_repr(a, b, 'identical') + assert a.identical(b), formatting.diff_dataset_repr(a, b, "identical") else: - raise TypeError('{} not supported by assertion comparison' - .format(type(a))) + raise TypeError("{} not supported by assertion comparison".format(type(a))) def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): @@ -126,17 +122,17 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): if isinstance(a, Variable): assert a.dims == b.dims allclose = _data_allclose_or_equiv(a.values, b.values, **kwargs) - assert allclose, '{}\n{}'.format(a.values, b.values) + assert allclose, "{}\n{}".format(a.values, b.values) elif isinstance(a, DataArray): assert_allclose(a.variable, b.variable, **kwargs) assert set(a.coords) == set(b.coords) for v in a.coords.variables: # can't recurse with this function as 
coord is sometimes a # DataArray, so call into _data_allclose_or_equiv directly - allclose = _data_allclose_or_equiv(a.coords[v].values, - b.coords[v].values, **kwargs) - assert allclose, '{}\n{}'.format(a.coords[v].values, - b.coords[v].values) + allclose = _data_allclose_or_equiv( + a.coords[v].values, b.coords[v].values, **kwargs + ) + assert allclose, "{}\n{}".format(a.coords[v].values, b.coords[v].values) elif isinstance(a, Dataset): assert set(a.data_vars) == set(b.data_vars) assert set(a.coords) == set(b.coords) @@ -144,26 +140,25 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): assert_allclose(a[k], b[k], **kwargs) else: - raise TypeError('{} not supported by assertion comparison' - .format(type(a))) + raise TypeError("{} not supported by assertion comparison".format(type(a))) def _assert_indexes_invariants_checks(indexes, possible_coord_variables, dims): assert isinstance(indexes, OrderedDict), indexes - assert all(isinstance(v, pd.Index) for v in indexes.values()), \ - {k: type(v) for k, v in indexes.items()} + assert all(isinstance(v, pd.Index) for v in indexes.values()), { + k: type(v) for k, v in indexes.items() + } - index_vars = {k for k, v in possible_coord_variables.items() - if isinstance(v, IndexVariable)} + index_vars = { + k for k, v in possible_coord_variables.items() if isinstance(v, IndexVariable) + } assert indexes.keys() <= index_vars, (set(indexes), index_vars) # Note: when we support non-default indexes, these checks should be opt-in # only! defaults = default_indexes(possible_coord_variables, dims) - assert indexes.keys() == defaults.keys(), \ - (set(indexes), set(defaults)) - assert all(v.equals(defaults[k]) for k, v in indexes.items()), \ - (indexes, defaults) + assert indexes.keys() == defaults.keys(), (set(indexes), set(defaults)) + assert all(v.equals(defaults[k]) for k, v in indexes.items()), (indexes, defaults) def _assert_variable_invariants(var: Variable, name: Hashable = None): @@ -172,12 +167,16 @@ def _assert_variable_invariants(var: Variable, name: Hashable = None): else: name_or_empty = (name,) assert isinstance(var._dims, tuple), name_or_empty + (var._dims,) - assert len(var._dims) == len(var._data.shape), \ - name_or_empty + (var._dims, var._data.shape) - assert isinstance(var._encoding, (type(None), dict)), \ - name_or_empty + (var._encoding,) - assert isinstance(var._attrs, (type(None), OrderedDict)), \ - name_or_empty + (var._attrs,) + assert len(var._dims) == len(var._data.shape), name_or_empty + ( + var._dims, + var._data.shape, + ) + assert isinstance(var._encoding, (type(None), dict)), name_or_empty + ( + var._encoding, + ) + assert isinstance(var._attrs, (type(None), OrderedDict)), name_or_empty + ( + var._attrs, + ) def _assert_dataarray_invariants(da: DataArray): @@ -185,14 +184,14 @@ def _assert_dataarray_invariants(da: DataArray): _assert_variable_invariants(da._variable) assert isinstance(da._coords, OrderedDict), da._coords + assert all(isinstance(v, Variable) for v in da._coords.values()), da._coords + assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), ( + da.dims, + {k: v.dims for k, v in da._coords.items()}, + ) assert all( - isinstance(v, Variable) for v in da._coords.values()), da._coords - assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), \ - (da.dims, {k: v.dims for k, v in da._coords.items()}) - assert all(isinstance(v, IndexVariable) - for (k, v) in da._coords.items() - if v.dims == (k,)), \ - {k: type(v) for k, v in da._coords.items()} + isinstance(v, 
IndexVariable) for (k, v) in da._coords.items() if v.dims == (k,) + ), {k: type(v) for k, v in da._coords.items()} for k, v in da._coords.items(): _assert_variable_invariants(v, k) @@ -204,15 +203,15 @@ def _assert_dataarray_invariants(da: DataArray): def _assert_dataset_invariants(ds: Dataset): assert isinstance(ds._variables, OrderedDict), type(ds._variables) - assert all( - isinstance(v, Variable) for v in ds._variables.values()), \ - ds._variables + assert all(isinstance(v, Variable) for v in ds._variables.values()), ds._variables for k, v in ds._variables.items(): _assert_variable_invariants(v, k) assert isinstance(ds._coord_names, set), ds._coord_names - assert ds._coord_names <= ds._variables.keys(), \ - (ds._coord_names, set(ds._variables)) + assert ds._coord_names <= ds._variables.keys(), ( + ds._coord_names, + set(ds._variables), + ) assert type(ds._dims) is dict, ds._dims assert all(isinstance(v, int) for v in ds._dims.values()), ds._dims @@ -220,18 +219,17 @@ def _assert_dataset_invariants(ds: Dataset): for v in ds._variables.values(): var_dims.update(v.dims) assert ds._dims.keys() == var_dims, (set(ds._dims), var_dims) - assert all(ds._dims[k] == v.sizes[k] - for v in ds._variables.values() - for k in v.sizes), \ - (ds._dims, {k: v.sizes for k, v in ds._variables.items()}) - assert all(isinstance(v, IndexVariable) - for (k, v) in ds._variables.items() - if v.dims == (k,)), \ - {k: type(v) for k, v in ds._variables.items() if v.dims == (k,)} - assert all(v.dims == (k,) - for (k, v) in ds._variables.items() - if k in ds._dims), \ - {k: v.dims for k, v in ds._variables.items() if k in ds._dims} + assert all( + ds._dims[k] == v.sizes[k] for v in ds._variables.values() for k in v.sizes + ), (ds._dims, {k: v.sizes for k, v in ds._variables.items()}) + assert all( + isinstance(v, IndexVariable) + for (k, v) in ds._variables.items() + if v.dims == (k,) + ), {k: type(v) for k, v in ds._variables.items() if v.dims == (k,)} + assert all(v.dims == (k,) for (k, v) in ds._variables.items() if k in ds._dims), { + k: v.dims for k, v in ds._variables.items() if k in ds._dims + } if ds._indexes is not None: _assert_indexes_invariants_checks(ds._indexes, ds._variables, ds._dims) @@ -241,9 +239,7 @@ def _assert_dataset_invariants(ds: Dataset): assert ds._initialized is True -def _assert_internal_invariants( - xarray_obj: Union[DataArray, Dataset, Variable], -): +def _assert_internal_invariants(xarray_obj: Union[DataArray, Dataset, Variable],): """Validate that an xarray object satisfies its own internal invariants. This exists for the benefit of xarray's own test suite, but may be useful @@ -258,5 +254,7 @@ def _assert_internal_invariants( _assert_dataset_invariants(xarray_obj) else: raise TypeError( - '{} is not a supported type for xarray invariant checks' - .format(type(xarray_obj))) + "{} is not a supported type for xarray invariant checks".format( + type(xarray_obj) + ) + ) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 4e50a8bcfe1..044ba75e87f 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -25,16 +25,19 @@ # import mpl and change the backend before other mpl imports try: import matplotlib as mpl + # Order of imports is important here. 
# Using a different backend makes Travis CI work - mpl.use('Agg') + mpl.use("Agg") except ImportError: pass import platform -arm_xfail = pytest.mark.xfail(platform.machine() == 'aarch64' or - 'arm' in platform.machine(), - reason='expected failure on ARM') + +arm_xfail = pytest.mark.xfail( + platform.machine() == "aarch64" or "arm" in platform.machine(), + reason="expected failure on ARM", +) def _importorskip(modname, minversion=None): @@ -43,76 +46,79 @@ def _importorskip(modname, minversion=None): has = True if minversion is not None: if LooseVersion(mod.__version__) < LooseVersion(minversion): - raise ImportError('Minimum version not satisfied') + raise ImportError("Minimum version not satisfied") except ImportError: has = False - func = pytest.mark.skipif(not has, reason='requires {}'.format(modname)) + func = pytest.mark.skipif(not has, reason="requires {}".format(modname)) return has, func def LooseVersion(vstring): # Our development version is something like '0.10.9+aac7bfc' # This function just ignored the git commit id. - vstring = vstring.split('+')[0] + vstring = vstring.split("+")[0] return version.LooseVersion(vstring) -has_matplotlib, requires_matplotlib = _importorskip('matplotlib') -has_matplotlib2, requires_matplotlib2 = _importorskip('matplotlib', - minversion='2') -has_scipy, requires_scipy = _importorskip('scipy') -has_pydap, requires_pydap = _importorskip('pydap.client') -has_netCDF4, requires_netCDF4 = _importorskip('netCDF4') -has_h5netcdf, requires_h5netcdf = _importorskip('h5netcdf') -has_pynio, requires_pynio = _importorskip('Nio') -has_pseudonetcdf, requires_pseudonetcdf = _importorskip('PseudoNetCDF') -has_cftime, requires_cftime = _importorskip('cftime') -has_nc_time_axis, requires_nc_time_axis = _importorskip('nc_time_axis', - minversion='1.2.0') +has_matplotlib, requires_matplotlib = _importorskip("matplotlib") +has_matplotlib2, requires_matplotlib2 = _importorskip("matplotlib", minversion="2") +has_scipy, requires_scipy = _importorskip("scipy") +has_pydap, requires_pydap = _importorskip("pydap.client") +has_netCDF4, requires_netCDF4 = _importorskip("netCDF4") +has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf") +has_pynio, requires_pynio = _importorskip("Nio") +has_pseudonetcdf, requires_pseudonetcdf = _importorskip("PseudoNetCDF") +has_cftime, requires_cftime = _importorskip("cftime") +has_nc_time_axis, requires_nc_time_axis = _importorskip( + "nc_time_axis", minversion="1.2.0" +) has_cftime_1_0_2_1, requires_cftime_1_0_2_1 = _importorskip( - 'cftime', minversion='1.0.2.1') -has_dask, requires_dask = _importorskip('dask') -has_bottleneck, requires_bottleneck = _importorskip('bottleneck') -has_rasterio, requires_rasterio = _importorskip('rasterio') -has_pathlib, requires_pathlib = _importorskip('pathlib') -has_zarr, requires_zarr = _importorskip('zarr', minversion='2.2') -has_np113, requires_np113 = _importorskip('numpy', minversion='1.13.0') -has_iris, requires_iris = _importorskip('iris') -has_cfgrib, requires_cfgrib = _importorskip('cfgrib') -has_numbagg, requires_numbagg = _importorskip('numbagg') + "cftime", minversion="1.0.2.1" +) +has_dask, requires_dask = _importorskip("dask") +has_bottleneck, requires_bottleneck = _importorskip("bottleneck") +has_rasterio, requires_rasterio = _importorskip("rasterio") +has_pathlib, requires_pathlib = _importorskip("pathlib") +has_zarr, requires_zarr = _importorskip("zarr", minversion="2.2") +has_np113, requires_np113 = _importorskip("numpy", minversion="1.13.0") +has_iris, requires_iris = 
_importorskip("iris") +has_cfgrib, requires_cfgrib = _importorskip("cfgrib") +has_numbagg, requires_numbagg = _importorskip("numbagg") # some special cases -has_h5netcdf07, requires_h5netcdf07 = _importorskip('h5netcdf', - minversion='0.7') -has_h5py29, requires_h5py29 = _importorskip('h5py', minversion='2.9.0') +has_h5netcdf07, requires_h5netcdf07 = _importorskip("h5netcdf", minversion="0.7") +has_h5py29, requires_h5py29 = _importorskip("h5py", minversion="2.9.0") has_h5fileobj = has_h5netcdf07 and has_h5py29 requires_h5fileobj = pytest.mark.skipif( - not has_h5fileobj, reason='requires h5py>2.9.0 & h5netcdf>0.7') + not has_h5fileobj, reason="requires h5py>2.9.0 & h5netcdf>0.7" +) has_scipy_or_netCDF4 = has_scipy or has_netCDF4 requires_scipy_or_netCDF4 = pytest.mark.skipif( - not has_scipy_or_netCDF4, reason='requires scipy or netCDF4') + not has_scipy_or_netCDF4, reason="requires scipy or netCDF4" +) has_cftime_or_netCDF4 = has_cftime or has_netCDF4 requires_cftime_or_netCDF4 = pytest.mark.skipif( - not has_cftime_or_netCDF4, reason='requires cftime or netCDF4') + not has_cftime_or_netCDF4, reason="requires cftime or netCDF4" +) if not has_pathlib: - has_pathlib, requires_pathlib = _importorskip('pathlib2') + has_pathlib, requires_pathlib = _importorskip("pathlib2") try: import_seaborn() has_seaborn = True except ImportError: has_seaborn = False -requires_seaborn = pytest.mark.skipif(not has_seaborn, - reason='requires seaborn') +requires_seaborn = pytest.mark.skipif(not has_seaborn, reason="requires seaborn") # change some global options for tests set_options(warn_for_unclosed_files=True) if has_dask: import dask - if LooseVersion(dask.__version__) < '0.18': + + if LooseVersion(dask.__version__) < "0.18": dask.set_options(get=dask.get) else: - dask.config.set(scheduler='single-threaded') + dask.config.set(scheduler="single-threaded") flaky = pytest.mark.flaky network = pytest.mark.network @@ -125,8 +131,9 @@ def raises_regex(error, pattern): yield message = str(excinfo.value) if not re.search(pattern, message): - raise AssertionError('exception %r did not match pattern %r' - % (excinfo.value, pattern)) + raise AssertionError( + "exception %r did not match pattern %r" % (excinfo.value, pattern) + ) class UnexpectedDataAccess(Exception): @@ -134,7 +141,6 @@ class UnexpectedDataAccess(Exception): class InaccessibleArray(utils.NDArrayMixin, ExplicitlyIndexed): - def __init__(self, array): self.array = array @@ -143,13 +149,11 @@ def __getitem__(self, key): class ReturnItem: - def __getitem__(self, key): return key class IndexerMaker: - def __init__(self, indexer_cls): self._indexer_cls = indexer_cls @@ -164,9 +168,9 @@ def source_ndarray(array): object itself. 
""" with warnings.catch_warnings(): - warnings.filterwarnings('ignore', 'DatetimeIndex.base') - warnings.filterwarnings('ignore', 'TimedeltaIndex.base') - base = getattr(array, 'base', np.asarray(array).base) + warnings.filterwarnings("ignore", "DatetimeIndex.base") + warnings.filterwarnings("ignore", "TimedeltaIndex.base") + base = getattr(array, "base", np.asarray(array).base) if base is None: base = array return base @@ -175,6 +179,7 @@ def source_ndarray(array): # Internal versions of xarray's test functions that validate additional # invariants + def assert_equal(a, b): xarray.testing.assert_equal(a, b) xarray.testing._assert_internal_invariants(a) diff --git a/xarray/tests/test_accessor_dt.py b/xarray/tests/test_accessor_dt.py index 09041a6a69f..0058747db71 100644 --- a/xarray/tests/test_accessor_dt.py +++ b/xarray/tests/test_accessor_dt.py @@ -5,8 +5,14 @@ import xarray as xr from . import ( - assert_array_equal, assert_equal, has_cftime, has_cftime_or_netCDF4, - has_dask, raises_regex, requires_dask) + assert_array_equal, + assert_equal, + has_cftime, + has_cftime_or_netCDF4, + has_dask, + raises_regex, + requires_dask, +) class TestDatetimeAccessor: @@ -16,26 +22,36 @@ def setup(self): data = np.random.rand(10, 10, nt) lons = np.linspace(0, 11, 10) lats = np.linspace(0, 20, 10) - self.times = pd.date_range(start="2000/01/01", freq='H', periods=nt) + self.times = pd.date_range(start="2000/01/01", freq="H", periods=nt) - self.data = xr.DataArray(data, coords=[lons, lats, self.times], - dims=['lon', 'lat', 'time'], name='data') + self.data = xr.DataArray( + data, + coords=[lons, lats, self.times], + dims=["lon", "lat", "time"], + name="data", + ) self.times_arr = np.random.choice(self.times, size=(10, 10, nt)) - self.times_data = xr.DataArray(self.times_arr, - coords=[lons, lats, self.times], - dims=['lon', 'lat', 'time'], - name='data') + self.times_data = xr.DataArray( + self.times_arr, + coords=[lons, lats, self.times], + dims=["lon", "lat", "time"], + name="data", + ) def test_field_access(self): - years = xr.DataArray(self.times.year, name='year', - coords=[self.times, ], dims=['time', ]) - months = xr.DataArray(self.times.month, name='month', - coords=[self.times, ], dims=['time', ]) - days = xr.DataArray(self.times.day, name='day', - coords=[self.times, ], dims=['time', ]) - hours = xr.DataArray(self.times.hour, name='hour', - coords=[self.times, ], dims=['time', ]) + years = xr.DataArray( + self.times.year, name="year", coords=[self.times], dims=["time"] + ) + months = xr.DataArray( + self.times.month, name="month", coords=[self.times], dims=["time"] + ) + days = xr.DataArray( + self.times.day, name="day", coords=[self.times], dims=["time"] + ) + hours = xr.DataArray( + self.times.hour, name="hour", coords=[self.times], dims=["time"] + ) assert_equal(years, self.data.time.dt.year) assert_equal(months, self.data.time.dt.month) @@ -43,14 +59,15 @@ def test_field_access(self): assert_equal(hours, self.data.time.dt.hour) def test_strftime(self): - assert ('2000-01-01 01:00:00' == self.data.time.dt.strftime( - '%Y-%m-%d %H:%M:%S')[1]) + assert ( + "2000-01-01 01:00:00" == self.data.time.dt.strftime("%Y-%m-%d %H:%M:%S")[1] + ) def test_not_datetime_type(self): nontime_data = self.data.copy() - int_data = np.arange(len(self.data.time)).astype('int8') - nontime_data['time'].values = int_data - with raises_regex(TypeError, 'dt'): + int_data = np.arange(len(self.data.time)).astype("int8") + nontime_data["time"].values = int_data + with raises_regex(TypeError, "dt"): 
nontime_data.time.dt @requires_dask @@ -61,24 +78,23 @@ def test_dask_field_access(self): months = self.times_data.dt.month hours = self.times_data.dt.hour days = self.times_data.dt.day - floor = self.times_data.dt.floor('D') - ceil = self.times_data.dt.ceil('D') - round = self.times_data.dt.round('D') - strftime = self.times_data.dt.strftime('%Y-%m-%d %H:%M:%S') + floor = self.times_data.dt.floor("D") + ceil = self.times_data.dt.ceil("D") + round = self.times_data.dt.round("D") + strftime = self.times_data.dt.strftime("%Y-%m-%d %H:%M:%S") dask_times_arr = da.from_array(self.times_arr, chunks=(5, 5, 50)) - dask_times_2d = xr.DataArray(dask_times_arr, - coords=self.data.coords, - dims=self.data.dims, - name='data') + dask_times_2d = xr.DataArray( + dask_times_arr, coords=self.data.coords, dims=self.data.dims, name="data" + ) dask_year = dask_times_2d.dt.year dask_month = dask_times_2d.dt.month dask_day = dask_times_2d.dt.day dask_hour = dask_times_2d.dt.hour - dask_floor = dask_times_2d.dt.floor('D') - dask_ceil = dask_times_2d.dt.ceil('D') - dask_round = dask_times_2d.dt.round('D') - dask_strftime = dask_times_2d.dt.strftime('%Y-%m-%d %H:%M:%S') + dask_floor = dask_times_2d.dt.floor("D") + dask_ceil = dask_times_2d.dt.ceil("D") + dask_round = dask_times_2d.dt.round("D") + dask_strftime = dask_times_2d.dt.strftime("%Y-%m-%d %H:%M:%S") # Test that the data isn't eagerly evaluated assert isinstance(dask_year.data, da.Array) @@ -108,26 +124,41 @@ def test_dask_field_access(self): def test_seasons(self): dates = pd.date_range(start="2000/01/01", freq="M", periods=12) dates = xr.DataArray(dates) - seasons = ["DJF", "DJF", "MAM", "MAM", "MAM", "JJA", "JJA", "JJA", - "SON", "SON", "SON", "DJF"] + seasons = [ + "DJF", + "DJF", + "MAM", + "MAM", + "MAM", + "JJA", + "JJA", + "JJA", + "SON", + "SON", + "SON", + "DJF", + ] seasons = xr.DataArray(seasons) assert_array_equal(seasons.values, dates.dt.season.values) def test_rounders(self): - dates = pd.date_range("2014-01-01", "2014-05-01", freq='H') - xdates = xr.DataArray(np.arange(len(dates)), - dims=['time'], coords=[dates]) - assert_array_equal(dates.floor('D').values, - xdates.time.dt.floor('D').values) - assert_array_equal(dates.ceil('D').values, - xdates.time.dt.ceil('D').values) - assert_array_equal(dates.round('D').values, - xdates.time.dt.round('D').values) - - -_CFTIME_CALENDARS = ['365_day', '360_day', 'julian', 'all_leap', - '366_day', 'gregorian', 'proleptic_gregorian'] + dates = pd.date_range("2014-01-01", "2014-05-01", freq="H") + xdates = xr.DataArray(np.arange(len(dates)), dims=["time"], coords=[dates]) + assert_array_equal(dates.floor("D").values, xdates.time.dt.floor("D").values) + assert_array_equal(dates.ceil("D").values, xdates.time.dt.ceil("D").values) + assert_array_equal(dates.round("D").values, xdates.time.dt.round("D").values) + + +_CFTIME_CALENDARS = [ + "365_day", + "360_day", + "julian", + "all_leap", + "366_day", + "gregorian", + "proleptic_gregorian", +] _NT = 100 @@ -141,8 +172,11 @@ def times(calendar): import cftime return cftime.num2date( - np.arange(_NT), units='hours since 2000-01-01', calendar=calendar, - only_use_cftime_datetimes=True) + np.arange(_NT), + units="hours since 2000-01-01", + calendar=calendar, + only_use_cftime_datetimes=True, + ) @pytest.fixture() @@ -150,8 +184,9 @@ def data(times): data = np.random.rand(10, 10, _NT) lons = np.linspace(0, 11, 10) lats = np.linspace(0, 20, 10) - return xr.DataArray(data, coords=[lons, lats, times], - dims=['lon', 'lat', 'time'], name='data') + return xr.DataArray( + 
data, coords=[lons, lats, times], dims=["lon", "lat", "time"], name="data" + ) @pytest.fixture() @@ -159,73 +194,85 @@ def times_3d(times): lons = np.linspace(0, 11, 10) lats = np.linspace(0, 20, 10) times_arr = np.random.choice(times, size=(10, 10, _NT)) - return xr.DataArray(times_arr, coords=[lons, lats, times], - dims=['lon', 'lat', 'time'], - name='data') + return xr.DataArray( + times_arr, coords=[lons, lats, times], dims=["lon", "lat", "time"], name="data" + ) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('field', ['year', 'month', 'day', 'hour', - 'dayofyear', 'dayofweek']) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize( + "field", ["year", "month", "day", "hour", "dayofyear", "dayofweek"] +) def test_field_access(data, field): - if field == 'dayofyear' or field == 'dayofweek': - pytest.importorskip('cftime', minversion='1.0.2.1') + if field == "dayofyear" or field == "dayofweek": + pytest.importorskip("cftime", minversion="1.0.2.1") result = getattr(data.time.dt, field) expected = xr.DataArray( getattr(xr.coding.cftimeindex.CFTimeIndex(data.time.values), field), - name=field, coords=data.time.coords, dims=data.time.dims) + name=field, + coords=data.time.coords, + dims=data.time.dims, + ) assert_equal(result, expected) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_cftime_strftime_access(data): """ compare cftime formatting against datetime formatting """ - date_format = '%Y%m%d%H' + date_format = "%Y%m%d%H" result = data.time.dt.strftime(date_format) datetime_array = xr.DataArray( - xr.coding.cftimeindex.CFTimeIndex( - data.time.values).to_datetimeindex(), + xr.coding.cftimeindex.CFTimeIndex(data.time.values).to_datetimeindex(), name="stftime", coords=data.time.coords, - dims=data.time.dims) + dims=data.time.dims, + ) expected = datetime_array.dt.strftime(date_format) assert_equal(result, expected) -@pytest.mark.skipif(not has_dask, reason='dask not installed') -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('field', ['year', 'month', 'day', 'hour', - 'dayofyear', 'dayofweek']) +@pytest.mark.skipif(not has_dask, reason="dask not installed") +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize( + "field", ["year", "month", "day", "hour", "dayofyear", "dayofweek"] +) def test_dask_field_access_1d(data, field): import dask.array as da - if field == 'dayofyear' or field == 'dayofweek': - pytest.importorskip('cftime', minversion='1.0.2.1') + if field == "dayofyear" or field == "dayofweek": + pytest.importorskip("cftime", minversion="1.0.2.1") expected = xr.DataArray( getattr(xr.coding.cftimeindex.CFTimeIndex(data.time.values), field), - name=field, dims=['time']) - times = xr.DataArray(data.time.values, dims=['time']).chunk({'time': 50}) + name=field, + dims=["time"], + ) + times = xr.DataArray(data.time.values, dims=["time"]).chunk({"time": 50}) result = getattr(times.dt, field) assert isinstance(result.data, da.Array) assert result.chunks == times.chunks assert_equal(result.compute(), expected) -@pytest.mark.skipif(not has_dask, reason='dask not installed') -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('field', ['year', 'month', 'day', 'hour', 'dayofyear', - 'dayofweek']) +@pytest.mark.skipif(not has_dask, reason="dask not installed") +@pytest.mark.skipif(not 
has_cftime, reason="cftime not installed") +@pytest.mark.parametrize( + "field", ["year", "month", "day", "hour", "dayofyear", "dayofweek"] +) def test_dask_field_access(times_3d, data, field): import dask.array as da - if field == 'dayofyear' or field == 'dayofweek': - pytest.importorskip('cftime', minversion='1.0.2.1') + if field == "dayofyear" or field == "dayofweek": + pytest.importorskip("cftime", minversion="1.0.2.1") expected = xr.DataArray( - getattr(xr.coding.cftimeindex.CFTimeIndex(times_3d.values.ravel()), - field).reshape(times_3d.shape), - name=field, coords=times_3d.coords, dims=times_3d.dims) - times_3d = times_3d.chunk({'lon': 5, 'lat': 5, 'time': 50}) + getattr( + xr.coding.cftimeindex.CFTimeIndex(times_3d.values.ravel()), field + ).reshape(times_3d.shape), + name=field, + coords=times_3d.coords, + dims=times_3d.dims, + ) + times_3d = times_3d.chunk({"lon": 5, "lat": 5, "time": 50}) result = getattr(times_3d.dt, field) assert isinstance(result.data, da.Array) assert result.chunks == times_3d.chunks @@ -239,24 +286,34 @@ def cftime_date_type(calendar): return _all_cftime_date_types()[calendar] -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_seasons(cftime_date_type): - dates = np.array([cftime_date_type(2000, month, 15) - for month in range(1, 13)]) + dates = np.array([cftime_date_type(2000, month, 15) for month in range(1, 13)]) dates = xr.DataArray(dates) - seasons = ['DJF', 'DJF', 'MAM', 'MAM', 'MAM', 'JJA', - 'JJA', 'JJA', 'SON', 'SON', 'SON', 'DJF'] + seasons = [ + "DJF", + "DJF", + "MAM", + "MAM", + "MAM", + "JJA", + "JJA", + "JJA", + "SON", + "SON", + "SON", + "DJF", + ] seasons = xr.DataArray(seasons) assert_array_equal(seasons.values, dates.dt.season.values) -@pytest.mark.skipif(not has_cftime_or_netCDF4, - reason='cftime or netCDF4 not installed') +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime or netCDF4 not installed") def test_dt_accessor_error_netCDF4(cftime_date_type): da = xr.DataArray( - [cftime_date_type(1, 1, 1), cftime_date_type(2, 1, 1)], - dims=['time']) + [cftime_date_type(1, 1, 1), cftime_date_type(2, 1, 1)], dims=["time"] + ) if not has_cftime: with pytest.raises(TypeError): da.dt.month diff --git a/xarray/tests/test_accessor_str.py b/xarray/tests/test_accessor_str.py index 800096b806b..360653b229b 100644 --- a/xarray/tests/test_accessor_str.py +++ b/xarray/tests/test_accessor_str.py @@ -55,7 +55,8 @@ def dtype(request): @requires_dask def test_dask(): import dask.array as da - arr = da.from_array(['a', 'b', 'c']) + + arr = da.from_array(["a", "b", "c"]) xarr = xr.DataArray(arr) result = xarr.str.len().compute() @@ -64,42 +65,41 @@ def test_dask(): def test_count(dtype): - values = xr.DataArray(['foo', 'foofoo', 'foooofooofommmfoo']).astype(dtype) - result = values.str.count('f[o]+') + values = xr.DataArray(["foo", "foofoo", "foooofooofommmfoo"]).astype(dtype) + result = values.str.count("f[o]+") expected = xr.DataArray([1, 2, 4]) assert_equal(result, expected) def test_contains(dtype): - values = xr.DataArray(['Foo', 'xYz', 'fOOomMm__fOo', 'MMM_']).astype(dtype) + values = xr.DataArray(["Foo", "xYz", "fOOomMm__fOo", "MMM_"]).astype(dtype) # case insensitive using regex - result = values.str.contains('FOO|mmm', case=False) + result = values.str.contains("FOO|mmm", case=False) expected = xr.DataArray([True, False, True, True]) assert_equal(result, expected) # case insensitive without regex - result = values.str.contains('foo', regex=False, 
case=False) + result = values.str.contains("foo", regex=False, case=False) expected = xr.DataArray([True, False, True, False]) assert_equal(result, expected) def test_starts_ends_with(dtype): - values = xr.DataArray( - ['om', 'foo_nom', 'nom', 'bar_foo', 'foo']).astype(dtype) - result = values.str.startswith('foo') + values = xr.DataArray(["om", "foo_nom", "nom", "bar_foo", "foo"]).astype(dtype) + result = values.str.startswith("foo") expected = xr.DataArray([False, True, False, False, True]) assert_equal(result, expected) - result = values.str.endswith('foo') + result = values.str.endswith("foo") expected = xr.DataArray([False, False, False, True, True]) assert_equal(result, expected) def test_case(dtype): - da = xr.DataArray(['SOme word']).astype(dtype) - capitalized = xr.DataArray(['Some word']).astype(dtype) - lowered = xr.DataArray(['some word']).astype(dtype) - swapped = xr.DataArray(['soME WORD']).astype(dtype) - titled = xr.DataArray(['Some Word']).astype(dtype) - uppered = xr.DataArray(['SOME WORD']).astype(dtype) + da = xr.DataArray(["SOme word"]).astype(dtype) + capitalized = xr.DataArray(["Some word"]).astype(dtype) + lowered = xr.DataArray(["some word"]).astype(dtype) + swapped = xr.DataArray(["soME WORD"]).astype(dtype) + titled = xr.DataArray(["Some Word"]).astype(dtype) + uppered = xr.DataArray(["SOME WORD"]).astype(dtype) assert_equal(da.str.capitalize(), capitalized) assert_equal(da.str.lower(), lowered) assert_equal(da.str.swapcase(), swapped) @@ -108,46 +108,50 @@ def test_case(dtype): def test_replace(dtype): - values = xr.DataArray(['fooBAD__barBAD']).astype(dtype) - result = values.str.replace('BAD[_]*', '') - expected = xr.DataArray(['foobar']).astype(dtype) + values = xr.DataArray(["fooBAD__barBAD"]).astype(dtype) + result = values.str.replace("BAD[_]*", "") + expected = xr.DataArray(["foobar"]).astype(dtype) assert_equal(result, expected) - result = values.str.replace('BAD[_]*', '', n=1) - expected = xr.DataArray(['foobarBAD']).astype(dtype) + result = values.str.replace("BAD[_]*", "", n=1) + expected = xr.DataArray(["foobarBAD"]).astype(dtype) assert_equal(result, expected) - s = xr.DataArray(['A', 'B', 'C', 'Aaba', 'Baca', '', - 'CABA', 'dog', 'cat']).astype(dtype) - result = s.str.replace('A', 'YYY') - expected = xr.DataArray(['YYY', 'B', 'C', 'YYYaba', 'Baca', '', 'CYYYBYYY', - 'dog', 'cat']).astype(dtype) + s = xr.DataArray(["A", "B", "C", "Aaba", "Baca", "", "CABA", "dog", "cat"]).astype( + dtype + ) + result = s.str.replace("A", "YYY") + expected = xr.DataArray( + ["YYY", "B", "C", "YYYaba", "Baca", "", "CYYYBYYY", "dog", "cat"] + ).astype(dtype) assert_equal(result, expected) - result = s.str.replace('A', 'YYY', case=False) - expected = xr.DataArray(['YYY', 'B', 'C', 'YYYYYYbYYY', 'BYYYcYYY', - '', 'CYYYBYYY', 'dog', 'cYYYt']).astype(dtype) + result = s.str.replace("A", "YYY", case=False) + expected = xr.DataArray( + ["YYY", "B", "C", "YYYYYYbYYY", "BYYYcYYY", "", "CYYYBYYY", "dog", "cYYYt"] + ).astype(dtype) assert_equal(result, expected) - result = s.str.replace('^.a|dog', 'XX-XX ', case=False) - expected = xr.DataArray(['A', 'B', 'C', 'XX-XX ba', 'XX-XX ca', '', - 'XX-XX BA', 'XX-XX ', 'XX-XX t']).astype(dtype) + result = s.str.replace("^.a|dog", "XX-XX ", case=False) + expected = xr.DataArray( + ["A", "B", "C", "XX-XX ba", "XX-XX ca", "", "XX-XX BA", "XX-XX ", "XX-XX t"] + ).astype(dtype) assert_equal(result, expected) def test_replace_callable(): - values = xr.DataArray(['fooBAD__barBAD']) + values = xr.DataArray(["fooBAD__barBAD"]) # test with 
callable repl = lambda m: m.group(0).swapcase() # noqa - result = values.str.replace('[a-z][A-Z]{2}', repl, n=2) - exp = xr.DataArray(['foObaD__baRbaD']) + result = values.str.replace("[a-z][A-Z]{2}", repl, n=2) + exp = xr.DataArray(["foObaD__baRbaD"]) assert_equal(result, exp) # test regex named groups - values = xr.DataArray(['Foo Bar Baz']) + values = xr.DataArray(["Foo Bar Baz"]) pat = r"(?P<first>\w+) (?P<middle>\w+) (?P<last>\w+)" - repl = lambda m: m.group('middle').swapcase() # noqa + repl = lambda m: m.group("middle").swapcase() # noqa result = values.str.replace(pat, repl) - exp = xr.DataArray(['bAR']) + exp = xr.DataArray(["bAR"]) assert_equal(result, exp) @@ -161,105 +165,105 @@ def test_replace_unicode(): def test_replace_compiled_regex(dtype): - values = xr.DataArray(['fooBAD__barBAD']).astype(dtype) + values = xr.DataArray(["fooBAD__barBAD"]).astype(dtype) # test with compiled regex - pat = re.compile(dtype('BAD[_]*')) - result = values.str.replace(pat, '') - expected = xr.DataArray(['foobar']).astype(dtype) + pat = re.compile(dtype("BAD[_]*")) + result = values.str.replace(pat, "") + expected = xr.DataArray(["foobar"]).astype(dtype) assert_equal(result, expected) - result = values.str.replace(pat, '', n=1) - expected = xr.DataArray(['foobarBAD']).astype(dtype) + result = values.str.replace(pat, "", n=1) + expected = xr.DataArray(["foobarBAD"]).astype(dtype) assert_equal(result, expected) # case and flags provided to str.replace will have no effect # and will produce warnings - values = xr.DataArray(['fooBAD__barBAD__bad']).astype(dtype) - pat = re.compile(dtype('BAD[_]*')) + values = xr.DataArray(["fooBAD__barBAD__bad"]).astype(dtype) + pat = re.compile(dtype("BAD[_]*")) with pytest.raises(ValueError, match="case and flags cannot be"): - result = values.str.replace(pat, '', flags=re.IGNORECASE) + result = values.str.replace(pat, "", flags=re.IGNORECASE) with pytest.raises(ValueError, match="case and flags cannot be"): - result = values.str.replace(pat, '', case=False) + result = values.str.replace(pat, "", case=False) with pytest.raises(ValueError, match="case and flags cannot be"): - result = values.str.replace(pat, '', case=True) + result = values.str.replace(pat, "", case=True) # test with callable - values = xr.DataArray(['fooBAD__barBAD']).astype(dtype) + values = xr.DataArray(["fooBAD__barBAD"]).astype(dtype) repl = lambda m: m.group(0).swapcase() - pat = re.compile(dtype('[a-z][A-Z]{2}')) + pat = re.compile(dtype("[a-z][A-Z]{2}")) result = values.str.replace(pat, repl, n=2) - expected = xr.DataArray(['foObaD__baRbaD']).astype(dtype) + expected = xr.DataArray(["foObaD__baRbaD"]).astype(dtype) assert_equal(result, expected) def test_replace_literal(dtype): # GH16808 literal replace (regex=False vs regex=True) - values = xr.DataArray(['f.o', 'foo']).astype(dtype) - expected = xr.DataArray(['bao', 'bao']).astype(dtype) - result = values.str.replace('f.', 'ba') + values = xr.DataArray(["f.o", "foo"]).astype(dtype) + expected = xr.DataArray(["bao", "bao"]).astype(dtype) + result = values.str.replace("f.", "ba") assert_equal(result, expected) - expected = xr.DataArray(['bao', 'foo']).astype(dtype) - result = values.str.replace('f.', 'ba', regex=False) + expected = xr.DataArray(["bao", "foo"]).astype(dtype) + result = values.str.replace("f.", "ba", regex=False) assert_equal(result, expected) # Cannot do a literal replace if given a callable repl or compiled # pattern callable_repl = lambda m: m.group(0).swapcase() - compiled_pat = re.compile('[a-z][A-Z]{2}') + compiled_pat = 
re.compile("[a-z][A-Z]{2}") msg = "Cannot use a callable replacement when regex=False" with pytest.raises(ValueError, match=msg): - values.str.replace('abc', callable_repl, regex=False) + values.str.replace("abc", callable_repl, regex=False) msg = "Cannot use a compiled regex as replacement pattern with regex=False" with pytest.raises(ValueError, match=msg): - values.str.replace(compiled_pat, '', regex=False) + values.str.replace(compiled_pat, "", regex=False) def test_repeat(dtype): - values = xr.DataArray(['a', 'b', 'c', 'd']).astype(dtype) + values = xr.DataArray(["a", "b", "c", "d"]).astype(dtype) result = values.str.repeat(3) - expected = xr.DataArray(['aaa', 'bbb', 'ccc', 'ddd']).astype(dtype) + expected = xr.DataArray(["aaa", "bbb", "ccc", "ddd"]).astype(dtype) assert_equal(result, expected) def test_match(dtype): # New match behavior introduced in 0.13 - values = xr.DataArray(['fooBAD__barBAD', 'foo']).astype(dtype) - result = values.str.match('.*(BAD[_]+).*(BAD)') + values = xr.DataArray(["fooBAD__barBAD", "foo"]).astype(dtype) + result = values.str.match(".*(BAD[_]+).*(BAD)") expected = xr.DataArray([True, False]) assert_equal(result, expected) - values = xr.DataArray(['fooBAD__barBAD', 'foo']).astype(dtype) - result = values.str.match('.*BAD[_]+.*BAD') + values = xr.DataArray(["fooBAD__barBAD", "foo"]).astype(dtype) + result = values.str.match(".*BAD[_]+.*BAD") expected = xr.DataArray([True, False]) assert_equal(result, expected) def test_empty_str_methods(): - empty = xr.DataArray(np.empty(shape=(0,), dtype='U')) + empty = xr.DataArray(np.empty(shape=(0,), dtype="U")) empty_str = empty empty_int = xr.DataArray(np.empty(shape=(0,), dtype=int)) empty_bool = xr.DataArray(np.empty(shape=(0,), dtype=bool)) - empty_bytes = xr.DataArray(np.empty(shape=(0,), dtype='S')) + empty_bytes = xr.DataArray(np.empty(shape=(0,), dtype="S")) assert_equal(empty_str, empty.str.title()) - assert_equal(empty_int, empty.str.count('a')) - assert_equal(empty_bool, empty.str.contains('a')) - assert_equal(empty_bool, empty.str.startswith('a')) - assert_equal(empty_bool, empty.str.endswith('a')) + assert_equal(empty_int, empty.str.count("a")) + assert_equal(empty_bool, empty.str.contains("a")) + assert_equal(empty_bool, empty.str.startswith("a")) + assert_equal(empty_bool, empty.str.endswith("a")) assert_equal(empty_str, empty.str.lower()) assert_equal(empty_str, empty.str.upper()) - assert_equal(empty_str, empty.str.replace('a', 'b')) + assert_equal(empty_str, empty.str.replace("a", "b")) assert_equal(empty_str, empty.str.repeat(3)) - assert_equal(empty_bool, empty.str.match('^a')) + assert_equal(empty_bool, empty.str.match("^a")) assert_equal(empty_int, empty.str.len()) - assert_equal(empty_int, empty.str.find('a')) - assert_equal(empty_int, empty.str.rfind('a')) + assert_equal(empty_int, empty.str.find("a")) + assert_equal(empty_int, empty.str.rfind("a")) assert_equal(empty_str, empty.str.pad(42)) assert_equal(empty_str, empty.str.center(42)) assert_equal(empty_str, empty.str.slice(stop=1)) @@ -269,8 +273,8 @@ def test_empty_str_methods(): assert_equal(empty_str, empty.str.rstrip()) assert_equal(empty_str, empty.str.wrap(42)) assert_equal(empty_str, empty.str.get(0)) - assert_equal(empty_str, empty_bytes.str.decode('ascii')) - assert_equal(empty_bytes, empty.str.encode('ascii')) + assert_equal(empty_str, empty_bytes.str.decode("ascii")) + assert_equal(empty_bytes, empty.str.encode("ascii")) assert_equal(empty_str, empty.str.isalnum()) assert_equal(empty_str, empty.str.isalpha()) assert_equal(empty_str, 
empty.str.isdigit()) @@ -282,26 +286,20 @@ def test_empty_str_methods(): assert_equal(empty_str, empty.str.isdecimal()) assert_equal(empty_str, empty.str.capitalize()) assert_equal(empty_str, empty.str.swapcase()) - table = str.maketrans('a', 'b') + table = str.maketrans("a", "b") assert_equal(empty_str, empty.str.translate(table)) def test_ismethods(dtype): - values = ['A', 'b', 'Xy', '4', '3A', '', 'TT', '55', '-', ' '] + values = ["A", "b", "Xy", "4", "3A", "", "TT", "55", "-", " "] str_s = xr.DataArray(values).astype(dtype) alnum_e = [True, True, True, True, True, False, True, True, False, False] - alpha_e = [True, True, True, False, False, False, True, False, False, - False] - digit_e = [False, False, False, True, False, False, False, True, False, - False] - space_e = [False, False, False, False, False, False, False, False, - False, True] - lower_e = [False, True, False, False, False, False, False, False, - False, False] - upper_e = [True, False, False, False, True, False, True, False, False, - False] - title_e = [True, False, True, False, True, False, False, False, False, - False] + alpha_e = [True, True, True, False, False, False, True, False, False, False] + digit_e = [False, False, False, True, False, False, False, True, False, False] + space_e = [False, False, False, False, False, False, False, False, False, True] + lower_e = [False, True, False, False, False, False, False, False, False, False] + upper_e = [True, False, False, False, True, False, True, False, False, False] + title_e = [True, False, True, False, True, False, False, False, False, False] assert_equal(str_s.str.isalnum(), xr.DataArray(alnum_e)) assert_equal(str_s.str.isalpha(), xr.DataArray(alpha_e)) @@ -317,7 +315,7 @@ def test_isnumeric(): # 0x2605: ★ not number # 0x1378: ፸ ETHIOPIC NUMBER SEVENTY # 0xFF13: 3 Em 3 - values = ['A', '3', '¼', '★', '፸', '3', 'four'] + values = ["A", "3", "¼", "★", "፸", "3", "four"] s = xr.DataArray(values) numeric_e = [False, True, True, False, True, True, False] decimal_e = [False, True, False, False, False, True, False] @@ -326,144 +324,143 @@ def test_isnumeric(): def test_len(dtype): - values = ['foo', 'fooo', 'fooooo', 'fooooooo'] + values = ["foo", "fooo", "fooooo", "fooooooo"] result = xr.DataArray(values).astype(dtype).str.len() expected = xr.DataArray([len(x) for x in values]) assert_equal(result, expected) def test_find(dtype): - values = xr.DataArray(['ABCDEFG', 'BCDEFEF', 'DEFGHIJEF', 'EFGHEF', 'XXX']) + values = xr.DataArray(["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF", "XXX"]) values = values.astype(dtype) - result = values.str.find('EF') + result = values.str.find("EF") assert_equal(result, xr.DataArray([4, 3, 1, 0, -1])) - expected = xr.DataArray([v.find(dtype('EF')) for v in values.values]) + expected = xr.DataArray([v.find(dtype("EF")) for v in values.values]) assert_equal(result, expected) - result = values.str.rfind('EF') + result = values.str.rfind("EF") assert_equal(result, xr.DataArray([4, 5, 7, 4, -1])) - expected = xr.DataArray([v.rfind(dtype('EF')) for v in values.values]) + expected = xr.DataArray([v.rfind(dtype("EF")) for v in values.values]) assert_equal(result, expected) - result = values.str.find('EF', 3) + result = values.str.find("EF", 3) assert_equal(result, xr.DataArray([4, 3, 7, 4, -1])) - expected = xr.DataArray([v.find(dtype('EF'), 3) for v in values.values]) + expected = xr.DataArray([v.find(dtype("EF"), 3) for v in values.values]) assert_equal(result, expected) - result = values.str.rfind('EF', 3) + result = values.str.rfind("EF", 3) 
assert_equal(result, xr.DataArray([4, 5, 7, 4, -1])) - expected = xr.DataArray([v.rfind(dtype('EF'), 3) for v in values.values]) + expected = xr.DataArray([v.rfind(dtype("EF"), 3) for v in values.values]) assert_equal(result, expected) - result = values.str.find('EF', 3, 6) + result = values.str.find("EF", 3, 6) assert_equal(result, xr.DataArray([4, 3, -1, 4, -1])) - expected = xr.DataArray([v.find(dtype('EF'), 3, 6) for v in values.values]) + expected = xr.DataArray([v.find(dtype("EF"), 3, 6) for v in values.values]) assert_equal(result, expected) - result = values.str.rfind('EF', 3, 6) + result = values.str.rfind("EF", 3, 6) assert_equal(result, xr.DataArray([4, 3, -1, 4, -1])) - xp = xr.DataArray([v.rfind(dtype('EF'), 3, 6) for v in values.values]) + xp = xr.DataArray([v.rfind(dtype("EF"), 3, 6) for v in values.values]) assert_equal(result, xp) def test_index(dtype): - s = xr.DataArray(['ABCDEFG', 'BCDEFEF', 'DEFGHIJEF', - 'EFGHEF']).astype(dtype) + s = xr.DataArray(["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF"]).astype(dtype) - result = s.str.index('EF') + result = s.str.index("EF") assert_equal(result, xr.DataArray([4, 3, 1, 0])) - result = s.str.rindex('EF') + result = s.str.rindex("EF") assert_equal(result, xr.DataArray([4, 5, 7, 4])) - result = s.str.index('EF', 3) + result = s.str.index("EF", 3) assert_equal(result, xr.DataArray([4, 3, 7, 4])) - result = s.str.rindex('EF', 3) + result = s.str.rindex("EF", 3) assert_equal(result, xr.DataArray([4, 5, 7, 4])) - result = s.str.index('E', 4, 8) + result = s.str.index("E", 4, 8) assert_equal(result, xr.DataArray([4, 5, 7, 4])) - result = s.str.rindex('E', 0, 5) + result = s.str.rindex("E", 0, 5) assert_equal(result, xr.DataArray([4, 3, 1, 4])) with pytest.raises(ValueError): - result = s.str.index('DE') + result = s.str.index("DE") def test_pad(dtype): - values = xr.DataArray(['a', 'b', 'c', 'eeeee']).astype(dtype) + values = xr.DataArray(["a", "b", "c", "eeeee"]).astype(dtype) - result = values.str.pad(5, side='left') - expected = xr.DataArray([' a', ' b', ' c', 'eeeee']).astype(dtype) + result = values.str.pad(5, side="left") + expected = xr.DataArray([" a", " b", " c", "eeeee"]).astype(dtype) assert_equal(result, expected) - result = values.str.pad(5, side='right') - expected = xr.DataArray(['a ', 'b ', 'c ', 'eeeee']).astype(dtype) + result = values.str.pad(5, side="right") + expected = xr.DataArray(["a ", "b ", "c ", "eeeee"]).astype(dtype) assert_equal(result, expected) - result = values.str.pad(5, side='both') - expected = xr.DataArray([' a ', ' b ', ' c ', 'eeeee']).astype(dtype) + result = values.str.pad(5, side="both") + expected = xr.DataArray([" a ", " b ", " c ", "eeeee"]).astype(dtype) assert_equal(result, expected) def test_pad_fillchar(dtype): - values = xr.DataArray(['a', 'b', 'c', 'eeeee']).astype(dtype) + values = xr.DataArray(["a", "b", "c", "eeeee"]).astype(dtype) - result = values.str.pad(5, side='left', fillchar='X') - expected = xr.DataArray(['XXXXa', 'XXXXb', 'XXXXc', 'eeeee']).astype(dtype) + result = values.str.pad(5, side="left", fillchar="X") + expected = xr.DataArray(["XXXXa", "XXXXb", "XXXXc", "eeeee"]).astype(dtype) assert_equal(result, expected) - result = values.str.pad(5, side='right', fillchar='X') - expected = xr.DataArray(['aXXXX', 'bXXXX', 'cXXXX', 'eeeee']).astype(dtype) + result = values.str.pad(5, side="right", fillchar="X") + expected = xr.DataArray(["aXXXX", "bXXXX", "cXXXX", "eeeee"]).astype(dtype) assert_equal(result, expected) - result = values.str.pad(5, side='both', fillchar='X') - 
expected = xr.DataArray(['XXaXX', 'XXbXX', 'XXcXX', 'eeeee']).astype(dtype) + result = values.str.pad(5, side="both", fillchar="X") + expected = xr.DataArray(["XXaXX", "XXbXX", "XXcXX", "eeeee"]).astype(dtype) assert_equal(result, expected) msg = "fillchar must be a character, not str" with pytest.raises(TypeError, match=msg): - result = values.str.pad(5, fillchar='XY') + result = values.str.pad(5, fillchar="XY") def test_translate(): - values = xr.DataArray(['abcdefg', 'abcc', 'cdddfg', 'cdefggg']) - table = str.maketrans('abc', 'cde') + values = xr.DataArray(["abcdefg", "abcc", "cdddfg", "cdefggg"]) + table = str.maketrans("abc", "cde") result = values.str.translate(table) - expected = xr.DataArray(['cdedefg', 'cdee', 'edddfg', 'edefggg']) + expected = xr.DataArray(["cdedefg", "cdee", "edddfg", "edefggg"]) assert_equal(result, expected) def test_center_ljust_rjust(dtype): - values = xr.DataArray(['a', 'b', 'c', 'eeeee']).astype(dtype) + values = xr.DataArray(["a", "b", "c", "eeeee"]).astype(dtype) result = values.str.center(5) - expected = xr.DataArray([' a ', ' b ', ' c ', 'eeeee']).astype(dtype) + expected = xr.DataArray([" a ", " b ", " c ", "eeeee"]).astype(dtype) assert_equal(result, expected) result = values.str.ljust(5) - expected = xr.DataArray(['a ', 'b ', 'c ', 'eeeee']).astype(dtype) + expected = xr.DataArray(["a ", "b ", "c ", "eeeee"]).astype(dtype) assert_equal(result, expected) result = values.str.rjust(5) - expected = xr.DataArray([' a', ' b', ' c', 'eeeee']).astype(dtype) + expected = xr.DataArray([" a", " b", " c", "eeeee"]).astype(dtype) assert_equal(result, expected) def test_center_ljust_rjust_fillchar(dtype): - values = xr.DataArray(['a', 'bb', 'cccc', 'ddddd', 'eeeeee']).astype(dtype) - result = values.str.center(5, fillchar='X') - expected = xr.DataArray(['XXaXX', 'XXbbX', 'Xcccc', 'ddddd', 'eeeeee']) + values = xr.DataArray(["a", "bb", "cccc", "ddddd", "eeeeee"]).astype(dtype) + result = values.str.center(5, fillchar="X") + expected = xr.DataArray(["XXaXX", "XXbbX", "Xcccc", "ddddd", "eeeeee"]) assert_equal(result, expected.astype(dtype)) - result = values.str.ljust(5, fillchar='X') - expected = xr.DataArray(['aXXXX', 'bbXXX', 'ccccX', 'ddddd', 'eeeeee']) + result = values.str.ljust(5, fillchar="X") + expected = xr.DataArray(["aXXXX", "bbXXX", "ccccX", "ddddd", "eeeeee"]) assert_equal(result, expected.astype(dtype)) - result = values.str.rjust(5, fillchar='X') - expected = xr.DataArray(['XXXXa', 'XXXbb', 'Xcccc', 'ddddd', 'eeeeee']) + result = values.str.rjust(5, fillchar="X") + expected = xr.DataArray(["XXXXa", "XXXbb", "Xcccc", "ddddd", "eeeeee"]) assert_equal(result, expected.astype(dtype)) # If fillchar is not a charatter, normal str raises TypeError @@ -472,111 +469,110 @@ def test_center_ljust_rjust_fillchar(dtype): template = "fillchar must be a character, not {dtype}" with pytest.raises(TypeError, match=template.format(dtype="str")): - values.str.center(5, fillchar='XY') + values.str.center(5, fillchar="XY") with pytest.raises(TypeError, match=template.format(dtype="str")): - values.str.ljust(5, fillchar='XY') + values.str.ljust(5, fillchar="XY") with pytest.raises(TypeError, match=template.format(dtype="str")): - values.str.rjust(5, fillchar='XY') + values.str.rjust(5, fillchar="XY") def test_zfill(dtype): - values = xr.DataArray(['1', '22', 'aaa', '333', '45678']).astype(dtype) + values = xr.DataArray(["1", "22", "aaa", "333", "45678"]).astype(dtype) result = values.str.zfill(5) - expected = xr.DataArray(['00001', '00022', '00aaa', '00333', '45678']) + 
expected = xr.DataArray(["00001", "00022", "00aaa", "00333", "45678"]) assert_equal(result, expected.astype(dtype)) result = values.str.zfill(3) - expected = xr.DataArray(['001', '022', 'aaa', '333', '45678']) + expected = xr.DataArray(["001", "022", "aaa", "333", "45678"]) assert_equal(result, expected.astype(dtype)) def test_slice(dtype): - arr = xr.DataArray(['aafootwo', 'aabartwo', 'aabazqux']).astype(dtype) + arr = xr.DataArray(["aafootwo", "aabartwo", "aabazqux"]).astype(dtype) result = arr.str.slice(2, 5) - exp = xr.DataArray(['foo', 'bar', 'baz']).astype(dtype) + exp = xr.DataArray(["foo", "bar", "baz"]).astype(dtype) assert_equal(result, exp) - for start, stop, step in [(0, 3, -1), (None, None, -1), - (3, 10, 2), (3, 0, -1)]: + for start, stop, step in [(0, 3, -1), (None, None, -1), (3, 10, 2), (3, 0, -1)]: try: result = arr.str[start:stop:step] expected = xr.DataArray([s[start:stop:step] for s in arr.values]) assert_equal(result, expected.astype(dtype)) except IndexError: - print('failed on %s:%s:%s' % (start, stop, step)) + print("failed on %s:%s:%s" % (start, stop, step)) raise def test_slice_replace(dtype): da = lambda x: xr.DataArray(x).astype(dtype) - values = da(['short', 'a bit longer', 'evenlongerthanthat', '']) + values = da(["short", "a bit longer", "evenlongerthanthat", ""]) - expected = da(['shrt', 'a it longer', 'evnlongerthanthat', '']) + expected = da(["shrt", "a it longer", "evnlongerthanthat", ""]) result = values.str.slice_replace(2, 3) assert_equal(result, expected) - expected = da(['shzrt', 'a zit longer', 'evznlongerthanthat', 'z']) - result = values.str.slice_replace(2, 3, 'z') + expected = da(["shzrt", "a zit longer", "evznlongerthanthat", "z"]) + result = values.str.slice_replace(2, 3, "z") assert_equal(result, expected) - expected = da(['shzort', 'a zbit longer', 'evzenlongerthanthat', 'z']) - result = values.str.slice_replace(2, 2, 'z') + expected = da(["shzort", "a zbit longer", "evzenlongerthanthat", "z"]) + result = values.str.slice_replace(2, 2, "z") assert_equal(result, expected) - expected = da(['shzort', 'a zbit longer', 'evzenlongerthanthat', 'z']) - result = values.str.slice_replace(2, 1, 'z') + expected = da(["shzort", "a zbit longer", "evzenlongerthanthat", "z"]) + result = values.str.slice_replace(2, 1, "z") assert_equal(result, expected) - expected = da(['shorz', 'a bit longez', 'evenlongerthanthaz', 'z']) - result = values.str.slice_replace(-1, None, 'z') + expected = da(["shorz", "a bit longez", "evenlongerthanthaz", "z"]) + result = values.str.slice_replace(-1, None, "z") assert_equal(result, expected) - expected = da(['zrt', 'zer', 'zat', 'z']) - result = values.str.slice_replace(None, -2, 'z') + expected = da(["zrt", "zer", "zat", "z"]) + result = values.str.slice_replace(None, -2, "z") assert_equal(result, expected) - expected = da(['shortz', 'a bit znger', 'evenlozerthanthat', 'z']) - result = values.str.slice_replace(6, 8, 'z') + expected = da(["shortz", "a bit znger", "evenlozerthanthat", "z"]) + result = values.str.slice_replace(6, 8, "z") assert_equal(result, expected) - expected = da(['zrt', 'a zit longer', 'evenlongzerthanthat', 'z']) - result = values.str.slice_replace(-10, 3, 'z') + expected = da(["zrt", "a zit longer", "evenlongzerthanthat", "z"]) + result = values.str.slice_replace(-10, 3, "z") assert_equal(result, expected) def test_strip_lstrip_rstrip(dtype): - values = xr.DataArray([' aa ', ' bb \n', 'cc ']).astype(dtype) + values = xr.DataArray([" aa ", " bb \n", "cc "]).astype(dtype) result = values.str.strip() - 
expected = xr.DataArray(['aa', 'bb', 'cc']).astype(dtype) + expected = xr.DataArray(["aa", "bb", "cc"]).astype(dtype) assert_equal(result, expected) result = values.str.lstrip() - expected = xr.DataArray(['aa ', 'bb \n', 'cc ']).astype(dtype) + expected = xr.DataArray(["aa ", "bb \n", "cc "]).astype(dtype) assert_equal(result, expected) result = values.str.rstrip() - expected = xr.DataArray([' aa', ' bb', 'cc']).astype(dtype) + expected = xr.DataArray([" aa", " bb", "cc"]).astype(dtype) assert_equal(result, expected) def test_strip_lstrip_rstrip_args(dtype): - values = xr.DataArray(['xxABCxx', 'xx BNSD', 'LDFJH xx']).astype(dtype) + values = xr.DataArray(["xxABCxx", "xx BNSD", "LDFJH xx"]).astype(dtype) - rs = values.str.strip('x') - xp = xr.DataArray(['ABC', ' BNSD', 'LDFJH ']).astype(dtype) + rs = values.str.strip("x") + xp = xr.DataArray(["ABC", " BNSD", "LDFJH "]).astype(dtype) assert_equal(rs, xp) - rs = values.str.lstrip('x') - xp = xr.DataArray(['ABCxx', ' BNSD', 'LDFJH xx']).astype(dtype) + rs = values.str.lstrip("x") + xp = xr.DataArray(["ABCxx", " BNSD", "LDFJH xx"]).astype(dtype) assert_equal(rs, xp) - rs = values.str.rstrip('x') - xp = xr.DataArray(['xxABC', 'xx BNSD', 'LDFJH ']).astype(dtype) + rs = values.str.rstrip("x") + xp = xr.DataArray(["xxABC", "xx BNSD", "LDFJH "]).astype(dtype) assert_equal(rs, xp) @@ -585,75 +581,99 @@ def test_wrap(): # two words greater than width, one word less than width, one word # equal to width, one word greater than width, multiple tokens with # trailing whitespace equal to width - values = xr.DataArray(['hello world', 'hello world!', 'hello world!!', - 'abcdefabcde', 'abcdefabcdef', 'abcdefabcdefa', - 'ab ab ab ab ', 'ab ab ab ab a', '\t']) + values = xr.DataArray( + [ + "hello world", + "hello world!", + "hello world!!", + "abcdefabcde", + "abcdefabcdef", + "abcdefabcdefa", + "ab ab ab ab ", + "ab ab ab ab a", + "\t", + ] + ) # expected values - xp = xr.DataArray(['hello world', 'hello world!', 'hello\nworld!!', - 'abcdefabcde', 'abcdefabcdef', 'abcdefabcdef\na', - 'ab ab ab ab', 'ab ab ab ab\na', '']) + xp = xr.DataArray( + [ + "hello world", + "hello world!", + "hello\nworld!!", + "abcdefabcde", + "abcdefabcdef", + "abcdefabcdef\na", + "ab ab ab ab", + "ab ab ab ab\na", + "", + ] + ) rs = values.str.wrap(12, break_long_words=True) assert_equal(rs, xp) # test with pre and post whitespace (non-unicode), NaN, and non-ascii # Unicode - values = xr.DataArray([' pre ', '\xac\u20ac\U00008000 abadcafe']) - xp = xr.DataArray([' pre', '\xac\u20ac\U00008000 ab\nadcafe']) + values = xr.DataArray([" pre ", "\xac\u20ac\U00008000 abadcafe"]) + xp = xr.DataArray([" pre", "\xac\u20ac\U00008000 ab\nadcafe"]) rs = values.str.wrap(6) assert_equal(rs, xp) def test_get(dtype): - values = xr.DataArray(['a_b_c', 'c_d_e', 'f_g_h']).astype(dtype) + values = xr.DataArray(["a_b_c", "c_d_e", "f_g_h"]).astype(dtype) result = values.str[2] - expected = xr.DataArray(['b', 'd', 'g']).astype(dtype) + expected = xr.DataArray(["b", "d", "g"]).astype(dtype) assert_equal(result, expected) # bounds testing - values = xr.DataArray(['1_2_3_4_5', '6_7_8_9_10', '11_12']).astype(dtype) + values = xr.DataArray(["1_2_3_4_5", "6_7_8_9_10", "11_12"]).astype(dtype) # positive index result = values.str[5] - expected = xr.DataArray(['_', '_', '']).astype(dtype) + expected = xr.DataArray(["_", "_", ""]).astype(dtype) assert_equal(result, expected) # negative index result = values.str[-6] - expected = xr.DataArray(['_', '8', '']).astype(dtype) + expected = xr.DataArray(["_", "8", 
""]).astype(dtype) assert_equal(result, expected) def test_encode_decode(): - data = xr.DataArray(['a', 'b', 'a\xe4']) - encoded = data.str.encode('utf-8') - decoded = encoded.str.decode('utf-8') + data = xr.DataArray(["a", "b", "a\xe4"]) + encoded = data.str.encode("utf-8") + decoded = encoded.str.decode("utf-8") assert_equal(data, decoded) def test_encode_decode_errors(): - encodeBase = xr.DataArray(['a', 'b', 'a\x9d']) + encodeBase = xr.DataArray(["a", "b", "a\x9d"]) - msg = (r"'charmap' codec can't encode character '\\x9d' in position 1:" - " character maps to ") + msg = ( + r"'charmap' codec can't encode character '\\x9d' in position 1:" + " character maps to " + ) with pytest.raises(UnicodeEncodeError, match=msg): - encodeBase.str.encode('cp1252') + encodeBase.str.encode("cp1252") - f = lambda x: x.encode('cp1252', 'ignore') - result = encodeBase.str.encode('cp1252', 'ignore') + f = lambda x: x.encode("cp1252", "ignore") + result = encodeBase.str.encode("cp1252", "ignore") expected = xr.DataArray([f(x) for x in encodeBase.values.tolist()]) assert_equal(result, expected) - decodeBase = xr.DataArray([b'a', b'b', b'a\x9d']) + decodeBase = xr.DataArray([b"a", b"b", b"a\x9d"]) - msg = ("'charmap' codec can't decode byte 0x9d in position 1:" - " character maps to ") + msg = ( + "'charmap' codec can't decode byte 0x9d in position 1:" + " character maps to " + ) with pytest.raises(UnicodeDecodeError, match=msg): - decodeBase.str.decode('cp1252') + decodeBase.str.decode("cp1252") - f = lambda x: x.decode('cp1252', 'ignore') - result = decodeBase.str.decode('cp1252', 'ignore') + f = lambda x: x.decode("cp1252", "ignore") + result = decodeBase.str.decode("cp1252", "ignore") expected = xr.DataArray([f(x) for x in decodeBase.values.tolist()]) assert_equal(result, expected) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 92f516b8c3b..83ff832f7fd 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -18,8 +18,16 @@ import xarray as xr from xarray import ( - DataArray, Dataset, backends, load_dataarray, load_dataset, open_dataarray, - open_dataset, open_mfdataset, save_mfdataset) + DataArray, + Dataset, + backends, + load_dataarray, + load_dataset, + open_dataarray, + open_dataset, + open_mfdataset, + save_mfdataset, +) from xarray.backends.common import robust_getitem from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding from xarray.backends.pydap_ import PydapDataStore @@ -30,14 +38,36 @@ from xarray.tests import mock from . 
import ( - assert_allclose, assert_array_equal, assert_equal, assert_identical, - has_dask, has_netCDF4, has_scipy, network, raises_regex, requires_cfgrib, - requires_cftime, requires_dask, requires_h5fileobj, requires_h5netcdf, - requires_netCDF4, requires_pathlib, requires_pseudonetcdf, requires_pydap, - requires_pynio, requires_rasterio, requires_scipy, - requires_scipy_or_netCDF4, requires_zarr, arm_xfail) + assert_allclose, + assert_array_equal, + assert_equal, + assert_identical, + has_dask, + has_netCDF4, + has_scipy, + network, + raises_regex, + requires_cfgrib, + requires_cftime, + requires_dask, + requires_h5fileobj, + requires_h5netcdf, + requires_netCDF4, + requires_pathlib, + requires_pseudonetcdf, + requires_pydap, + requires_pynio, + requires_rasterio, + requires_scipy, + requires_scipy_or_netCDF4, + requires_zarr, + arm_xfail, +) from .test_coding_times import ( - _ALL_CALENDARS, _NON_STANDARD_CALENDARS, _STANDARD_CALENDARS) + _ALL_CALENDARS, + _NON_STANDARD_CALENDARS, + _STANDARD_CALENDARS, +) from .test_dataset import create_test_data, create_append_test_data try: @@ -57,93 +87,124 @@ from pandas.tslib import OutOfBoundsDatetime -ON_WINDOWS = sys.platform == 'win32' +ON_WINDOWS = sys.platform == "win32" def open_example_dataset(name, *args, **kwargs): - return open_dataset(os.path.join(os.path.dirname(__file__), 'data', name), - *args, **kwargs) + return open_dataset( + os.path.join(os.path.dirname(__file__), "data", name), *args, **kwargs + ) def open_example_mfdataset(names, *args, **kwargs): return open_mfdataset( - [os.path.join(os.path.dirname(__file__), 'data', name) - for name in names], - *args, **kwargs) + [os.path.join(os.path.dirname(__file__), "data", name) for name in names], + *args, + **kwargs + ) def create_masked_and_scaled_data(): x = np.array([np.nan, np.nan, 10, 10.1, 10.2], dtype=np.float32) - encoding = {'_FillValue': -1, 'add_offset': 10, - 'scale_factor': np.float32(0.1), 'dtype': 'i2'} - return Dataset({'x': ('t', x, {}, encoding)}) + encoding = { + "_FillValue": -1, + "add_offset": 10, + "scale_factor": np.float32(0.1), + "dtype": "i2", + } + return Dataset({"x": ("t", x, {}, encoding)}) def create_encoded_masked_and_scaled_data(): - attributes = {'_FillValue': -1, 'add_offset': 10, - 'scale_factor': np.float32(0.1)} - return Dataset({'x': ('t', np.int16([-1, -1, 0, 1, 2]), attributes)}) + attributes = {"_FillValue": -1, "add_offset": 10, "scale_factor": np.float32(0.1)} + return Dataset({"x": ("t", np.int16([-1, -1, 0, 1, 2]), attributes)}) def create_unsigned_masked_scaled_data(): - encoding = {'_FillValue': 255, '_Unsigned': 'true', 'dtype': 'i1', - 'add_offset': 10, 'scale_factor': np.float32(0.1)} + encoding = { + "_FillValue": 255, + "_Unsigned": "true", + "dtype": "i1", + "add_offset": 10, + "scale_factor": np.float32(0.1), + } x = np.array([10.0, 10.1, 22.7, 22.8, np.nan], dtype=np.float32) - return Dataset({'x': ('t', x, {}, encoding)}) + return Dataset({"x": ("t", x, {}, encoding)}) def create_encoded_unsigned_masked_scaled_data(): # These are values as written to the file: the _FillValue will # be represented in the signed form. 
- attributes = {'_FillValue': -1, '_Unsigned': 'true', - 'add_offset': 10, 'scale_factor': np.float32(0.1)} + attributes = { + "_FillValue": -1, + "_Unsigned": "true", + "add_offset": 10, + "scale_factor": np.float32(0.1), + } # Create unsigned data corresponding to [0, 1, 127, 128, 255] unsigned - sb = np.asarray([0, 1, 127, -128, -1], dtype='i1') - return Dataset({'x': ('t', sb, attributes)}) + sb = np.asarray([0, 1, 127, -128, -1], dtype="i1") + return Dataset({"x": ("t", sb, attributes)}) def create_bad_unsigned_masked_scaled_data(): - encoding = {'_FillValue': 255, '_Unsigned': True, 'dtype': 'i1', - 'add_offset': 10, 'scale_factor': np.float32(0.1)} + encoding = { + "_FillValue": 255, + "_Unsigned": True, + "dtype": "i1", + "add_offset": 10, + "scale_factor": np.float32(0.1), + } x = np.array([10.0, 10.1, 22.7, 22.8, np.nan], dtype=np.float32) - return Dataset({'x': ('t', x, {}, encoding)}) + return Dataset({"x": ("t", x, {}, encoding)}) def create_bad_encoded_unsigned_masked_scaled_data(): # These are values as written to the file: the _FillValue will # be represented in the signed form. - attributes = {'_FillValue': -1, '_Unsigned': True, - 'add_offset': 10, 'scale_factor': np.float32(0.1)} + attributes = { + "_FillValue": -1, + "_Unsigned": True, + "add_offset": 10, + "scale_factor": np.float32(0.1), + } # Create signed data corresponding to [0, 1, 127, 128, 255] unsigned - sb = np.asarray([0, 1, 127, -128, -1], dtype='i1') - return Dataset({'x': ('t', sb, attributes)}) + sb = np.asarray([0, 1, 127, -128, -1], dtype="i1") + return Dataset({"x": ("t", sb, attributes)}) def create_signed_masked_scaled_data(): - encoding = {'_FillValue': -127, '_Unsigned': 'false', 'dtype': 'i1', - 'add_offset': 10, 'scale_factor': np.float32(0.1)} + encoding = { + "_FillValue": -127, + "_Unsigned": "false", + "dtype": "i1", + "add_offset": 10, + "scale_factor": np.float32(0.1), + } x = np.array([-1.0, 10.1, 22.7, np.nan], dtype=np.float32) - return Dataset({'x': ('t', x, {}, encoding)}) + return Dataset({"x": ("t", x, {}, encoding)}) def create_encoded_signed_masked_scaled_data(): # These are values as written to the file: the _FillValue will # be represented in the signed form. 
- attributes = {'_FillValue': -127, '_Unsigned': 'false', - 'add_offset': 10, 'scale_factor': np.float32(0.1)} + attributes = { + "_FillValue": -127, + "_Unsigned": "false", + "add_offset": 10, + "scale_factor": np.float32(0.1), + } # Create signed data corresponding to [0, 1, 127, 128, 255] unsigned - sb = np.asarray([-110, 1, 127, -127], dtype='i1') - return Dataset({'x': ('t', sb, attributes)}) + sb = np.asarray([-110, 1, 127, -127], dtype="i1") + return Dataset({"x": ("t", sb, attributes)}) def create_boolean_data(): - attributes = {'units': '-'} - return Dataset({'x': ('t', [True, False, False, True], attributes)}) + attributes = {"units": "-"} + return Dataset({"x": ("t", [True, False, False, True], attributes)}) class TestCommon: def test_robust_getitem(self): - class UnreliableArrayFailure(Exception): pass @@ -163,8 +224,7 @@ def __getitem__(self, key): array[0] assert array[0] == 0 - actual = robust_getitem(array, 0, catch=UnreliableArrayFailure, - initial_delay=0) + actual = robust_getitem(array, 0, catch=UnreliableArrayFailure, initial_delay=0) assert actual == 0 @@ -180,37 +240,38 @@ def create_store(self): raise NotImplementedError @contextlib.contextmanager - def roundtrip(self, data, save_kwargs=None, open_kwargs=None, - allow_cleanup_failure=False): + def roundtrip( + self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False + ): if save_kwargs is None: save_kwargs = {} if open_kwargs is None: open_kwargs = {} - with create_tmp_file( - allow_cleanup_failure=allow_cleanup_failure) as path: + with create_tmp_file(allow_cleanup_failure=allow_cleanup_failure) as path: self.save(data, path, **save_kwargs) with self.open(path, **open_kwargs) as ds: yield ds @contextlib.contextmanager - def roundtrip_append(self, data, save_kwargs=None, open_kwargs=None, - allow_cleanup_failure=False): + def roundtrip_append( + self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False + ): if save_kwargs is None: save_kwargs = {} if open_kwargs is None: open_kwargs = {} - with create_tmp_file( - allow_cleanup_failure=allow_cleanup_failure) as path: + with create_tmp_file(allow_cleanup_failure=allow_cleanup_failure) as path: for i, key in enumerate(data.variables): - mode = 'a' if i > 0 else 'w' + mode = "a" if i > 0 else "w" self.save(data[[key]], path, mode=mode, **save_kwargs) with self.open(path, **open_kwargs) as ds: yield ds # The save/open methods may be overwritten below def save(self, dataset, path, **kwargs): - return dataset.to_netcdf(path, engine=self.engine, - format=self.file_format, **kwargs) + return dataset.to_netcdf( + path, engine=self.engine, format=self.file_format, **kwargs + ) @contextlib.contextmanager def open(self, path, **kwargs): @@ -219,9 +280,9 @@ def open(self, path, **kwargs): def test_zero_dimensional_variable(self): expected = create_test_data() - expected['float_var'] = ([], 1.0e9, {'units': 'units of awesome'}) - expected['bytes_var'] = ([], b'foobar') - expected['string_var'] = ([], 'foobar') + expected["float_var"] = ([], 1.0e9, {"units": "units of awesome"}) + expected["bytes_var"] = ([], b"foobar") + expected["string_var"] = ([], "foobar") with self.roundtrip(expected) as actual: assert_identical(expected, actual) @@ -237,15 +298,16 @@ def test_write_store(self): def check_dtypes_roundtripped(self, expected, actual): for k in expected.variables: expected_dtype = expected.variables[k].dtype - if (isinstance(self, NetCDF3Only) and expected_dtype == 'int64'): + if isinstance(self, NetCDF3Only) and expected_dtype == "int64": # 
downcast - expected_dtype = np.dtype('int32') + expected_dtype = np.dtype("int32") actual_dtype = actual.variables[k].dtype # TODO: check expected behavior for string dtypes more carefully - string_kinds = {'O', 'S', 'U'} - assert (expected_dtype == actual_dtype - or (expected_dtype.kind in string_kinds and - actual_dtype.kind in string_kinds)) + string_kinds = {"O", "S", "U"} + assert expected_dtype == actual_dtype or ( + expected_dtype.kind in string_kinds + and actual_dtype.kind in string_kinds + ) def test_roundtrip_test_data(self): expected = create_test_data() @@ -278,8 +340,8 @@ def assert_loads(vars=None): with assert_loads() as ds: ds.load() - with assert_loads(['var1', 'dim1', 'dim2']) as ds: - ds['var1'].load() + with assert_loads(["var1", "dim1", "dim2"]) as ds: + ds["var1"].load() # verify we can read data even after closing the file with self.roundtrip(expected) as ds: @@ -307,10 +369,9 @@ def test_dataset_compute(self): def test_pickle(self): if not has_dask: - pytest.xfail('pickling requires dask for SerializableLock') - expected = Dataset({'foo': ('x', [42])}) - with self.roundtrip( - expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped: + pytest.xfail("pickling requires dask for SerializableLock") + expected = Dataset({"foo": ("x", [42])}) + with self.roundtrip(expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped: with roundtripped: # Windows doesn't like reopening an already open file raw_pickle = pickle.dumps(roundtripped) @@ -320,54 +381,55 @@ def test_pickle(self): @pytest.mark.filterwarnings("ignore:deallocating CachingFileManager") def test_pickle_dataarray(self): if not has_dask: - pytest.xfail('pickling requires dask for SerializableLock') - expected = Dataset({'foo': ('x', [42])}) - with self.roundtrip( - expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped: + pytest.xfail("pickling requires dask for SerializableLock") + expected = Dataset({"foo": ("x", [42])}) + with self.roundtrip(expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped: with roundtripped: - raw_pickle = pickle.dumps(roundtripped['foo']) + raw_pickle = pickle.dumps(roundtripped["foo"]) # TODO: figure out how to explicitly close the file for the # unpickled DataArray? 
unpickled = pickle.loads(raw_pickle) - assert_identical(expected['foo'], unpickled) + assert_identical(expected["foo"], unpickled) def test_dataset_caching(self): - expected = Dataset({'foo': ('x', [5, 6, 7])}) + expected = Dataset({"foo": ("x", [5, 6, 7])}) with self.roundtrip(expected) as actual: - assert isinstance(actual.foo.variable._data, - indexing.MemoryCachedArray) + assert isinstance(actual.foo.variable._data, indexing.MemoryCachedArray) assert not actual.foo.variable._in_memory actual.foo.values # cache assert actual.foo.variable._in_memory - with self.roundtrip(expected, open_kwargs={'cache': False}) as actual: - assert isinstance(actual.foo.variable._data, - indexing.CopyOnWriteArray) + with self.roundtrip(expected, open_kwargs={"cache": False}) as actual: + assert isinstance(actual.foo.variable._data, indexing.CopyOnWriteArray) assert not actual.foo.variable._in_memory actual.foo.values # no caching assert not actual.foo.variable._in_memory def test_roundtrip_None_variable(self): - expected = Dataset({None: (('x', 'y'), [[0, 1], [2, 3]])}) + expected = Dataset({None: (("x", "y"), [[0, 1], [2, 3]])}) with self.roundtrip(expected) as actual: assert_identical(expected, actual) def test_roundtrip_object_dtype(self): floats = np.array([0.0, 0.0, 1.0, 2.0, 3.0], dtype=object) floats_nans = np.array([np.nan, np.nan, 1.0, 2.0, 3.0], dtype=object) - bytes_ = np.array([b'ab', b'cdef', b'g'], dtype=object) - bytes_nans = np.array([b'ab', b'cdef', np.nan], dtype=object) - strings = np.array(['ab', 'cdef', 'g'], dtype=object) - strings_nans = np.array(['ab', 'cdef', np.nan], dtype=object) + bytes_ = np.array([b"ab", b"cdef", b"g"], dtype=object) + bytes_nans = np.array([b"ab", b"cdef", np.nan], dtype=object) + strings = np.array(["ab", "cdef", "g"], dtype=object) + strings_nans = np.array(["ab", "cdef", np.nan], dtype=object) all_nans = np.array([np.nan, np.nan], dtype=object) - original = Dataset({'floats': ('a', floats), - 'floats_nans': ('a', floats_nans), - 'bytes': ('b', bytes_), - 'bytes_nans': ('b', bytes_nans), - 'strings': ('b', strings), - 'strings_nans': ('b', strings_nans), - 'all_nans': ('c', all_nans), - 'nan': ([], np.nan)}) + original = Dataset( + { + "floats": ("a", floats), + "floats_nans": ("a", floats_nans), + "bytes": ("b", bytes_), + "bytes_nans": ("b", bytes_nans), + "strings": ("b", strings), + "strings_nans": ("b", strings_nans), + "all_nans": ("c", all_nans), + "nan": ([], np.nan), + } + ) expected = original.copy(deep=True) with self.roundtrip(original) as actual: try: @@ -379,35 +441,35 @@ def test_roundtrip_object_dtype(self): # This currently includes all netCDF files when encoding is not # explicitly set. 
# https://github.com/pydata/xarray/issues/1647 - expected['bytes_nans'][-1] = b'' - expected['strings_nans'][-1] = '' + expected["bytes_nans"][-1] = b"" + expected["strings_nans"][-1] = "" assert_identical(expected, actual) def test_roundtrip_string_data(self): - expected = Dataset({'x': ('t', ['ab', 'cdef'])}) + expected = Dataset({"x": ("t", ["ab", "cdef"])}) with self.roundtrip(expected) as actual: assert_identical(expected, actual) def test_roundtrip_string_encoded_characters(self): - expected = Dataset({'x': ('t', ['ab', 'cdef'])}) - expected['x'].encoding['dtype'] = 'S1' + expected = Dataset({"x": ("t", ["ab", "cdef"])}) + expected["x"].encoding["dtype"] = "S1" with self.roundtrip(expected) as actual: assert_identical(expected, actual) - assert actual['x'].encoding['_Encoding'] == 'utf-8' + assert actual["x"].encoding["_Encoding"] == "utf-8" - expected['x'].encoding['_Encoding'] = 'ascii' + expected["x"].encoding["_Encoding"] = "ascii" with self.roundtrip(expected) as actual: assert_identical(expected, actual) - assert actual['x'].encoding['_Encoding'] == 'ascii' + assert actual["x"].encoding["_Encoding"] == "ascii" @arm_xfail def test_roundtrip_numpy_datetime_data(self): - times = pd.to_datetime(['2000-01-01', '2000-01-02', 'NaT']) - expected = Dataset({'t': ('t', times), 't0': times[0]}) - kwargs = {'encoding': {'t0': {'units': 'days since 1950-01-01'}}} + times = pd.to_datetime(["2000-01-01", "2000-01-02", "NaT"]) + expected = Dataset({"t": ("t", times), "t0": times[0]}) + kwargs = {"encoding": {"t0": {"units": "days since 1950-01-01"}}} with self.roundtrip(expected, save_kwargs=kwargs) as actual: assert_identical(expected, actual) - assert actual.t0.encoding['units'] == 'days since 1950-01-01' + assert actual.t0.encoding["units"] == "days since 1950-01-01" @requires_cftime def test_roundtrip_cftime_datetime_data(self): @@ -416,46 +478,43 @@ def test_roundtrip_cftime_datetime_data(self): date_types = _all_cftime_date_types() for date_type in date_types.values(): times = [date_type(1, 1, 1), date_type(1, 1, 2)] - expected = Dataset({'t': ('t', times), 't0': times[0]}) - kwargs = {'encoding': {'t0': {'units': 'days since 0001-01-01'}}} + expected = Dataset({"t": ("t", times), "t0": times[0]}) + kwargs = {"encoding": {"t0": {"units": "days since 0001-01-01"}}} expected_decoded_t = np.array(times) expected_decoded_t0 = np.array([date_type(1, 1, 1)]) expected_calendar = times[0].calendar with warnings.catch_warnings(): - if expected_calendar in {'proleptic_gregorian', 'gregorian'}: - warnings.filterwarnings( - 'ignore', 'Unable to decode time axis') + if expected_calendar in {"proleptic_gregorian", "gregorian"}: + warnings.filterwarnings("ignore", "Unable to decode time axis") with self.roundtrip(expected, save_kwargs=kwargs) as actual: abs_diff = abs(actual.t.values - expected_decoded_t) - assert (abs_diff <= np.timedelta64(1, 's')).all() - assert (actual.t.encoding['units'] - == 'days since 0001-01-01 00:00:00.000000') - assert (actual.t.encoding['calendar'] - == expected_calendar) + assert (abs_diff <= np.timedelta64(1, "s")).all() + assert ( + actual.t.encoding["units"] + == "days since 0001-01-01 00:00:00.000000" + ) + assert actual.t.encoding["calendar"] == expected_calendar abs_diff = abs(actual.t0.values - expected_decoded_t0) - assert (abs_diff <= np.timedelta64(1, 's')).all() - assert (actual.t0.encoding['units'] - == 'days since 0001-01-01') - assert (actual.t.encoding['calendar'] - == expected_calendar) + assert (abs_diff <= np.timedelta64(1, "s")).all() + assert 
actual.t0.encoding["units"] == "days since 0001-01-01" + assert actual.t.encoding["calendar"] == expected_calendar def test_roundtrip_timedelta_data(self): - time_deltas = pd.to_timedelta(['1h', '2h', 'NaT']) - expected = Dataset({'td': ('td', time_deltas), 'td0': time_deltas[0]}) + time_deltas = pd.to_timedelta(["1h", "2h", "NaT"]) + expected = Dataset({"td": ("td", time_deltas), "td0": time_deltas[0]}) with self.roundtrip(expected) as actual: assert_identical(expected, actual) def test_roundtrip_float64_data(self): - expected = Dataset({'x': ('y', np.array([1.0, 2.0, np.pi], - dtype='float64'))}) + expected = Dataset({"x": ("y", np.array([1.0, 2.0, np.pi], dtype="float64"))}) with self.roundtrip(expected) as actual: assert_identical(expected, actual) def test_roundtrip_example_1_netcdf(self): - with open_example_dataset('example_1.nc') as expected: + with open_example_dataset("example_1.nc") as expected: with self.roundtrip(expected) as actual: # we allow the attributes to differ since that # will depend on the encoding used. For example, @@ -464,40 +523,40 @@ def test_roundtrip_example_1_netcdf(self): assert_equal(expected, actual) def test_roundtrip_coordinates(self): - original = Dataset({'foo': ('x', [0, 1])}, - {'x': [2, 3], 'y': ('a', [42]), 'z': ('x', [4, 5])}) + original = Dataset( + {"foo": ("x", [0, 1])}, {"x": [2, 3], "y": ("a", [42]), "z": ("x", [4, 5])} + ) with self.roundtrip(original) as actual: assert_identical(original, actual) def test_roundtrip_global_coordinates(self): - original = Dataset({'x': [2, 3], 'y': ('a', [42]), 'z': ('x', [4, 5])}) + original = Dataset({"x": [2, 3], "y": ("a", [42]), "z": ("x", [4, 5])}) with self.roundtrip(original) as actual: assert_identical(original, actual) def test_roundtrip_coordinates_with_space(self): - original = Dataset(coords={'x': 0, 'y z': 1}) - expected = Dataset({'y z': 1}, {'x': 0}) + original = Dataset(coords={"x": 0, "y z": 1}) + expected = Dataset({"y z": 1}, {"x": 0}) with pytest.warns(xr.SerializationWarning): with self.roundtrip(original) as actual: assert_identical(expected, actual) def test_roundtrip_boolean_dtype(self): original = create_boolean_data() - assert original['x'].dtype == 'bool' + assert original["x"].dtype == "bool" with self.roundtrip(original) as actual: assert_identical(original, actual) - assert actual['x'].dtype == 'bool' + assert actual["x"].dtype == "bool" def test_orthogonal_indexing(self): in_memory = create_test_data() with self.roundtrip(in_memory) as on_disk: - indexers = {'dim1': [1, 2, 0], 'dim2': [3, 2, 0, 3], - 'dim3': np.arange(5)} + indexers = {"dim1": [1, 2, 0], "dim2": [3, 2, 0, 3], "dim3": np.arange(5)} expected = in_memory.isel(**indexers) actual = on_disk.isel(**indexers) # make sure the array is not yet loaded into memory - assert not actual['var1'].variable._in_memory + assert not actual["var1"].variable._in_memory assert_identical(expected, actual) # do it twice, to make sure we're switched from orthogonal -> numpy # when we cached the values @@ -507,12 +566,14 @@ def test_orthogonal_indexing(self): def test_vectorized_indexing(self): in_memory = create_test_data() with self.roundtrip(in_memory) as on_disk: - indexers = {'dim1': DataArray([0, 2, 0], dims='a'), - 'dim2': DataArray([0, 2, 3], dims='a')} + indexers = { + "dim1": DataArray([0, 2, 0], dims="a"), + "dim2": DataArray([0, 2, 3], dims="a"), + } expected = in_memory.isel(**indexers) actual = on_disk.isel(**indexers) # make sure the array is not yet loaded into memory - assert not actual['var1'].variable._in_memory + 
assert not actual["var1"].variable._in_memory assert_identical(expected, actual.load()) # do it twice, to make sure we're switched from # vectorized -> numpy when we cached the values @@ -522,8 +583,8 @@ def test_vectorized_indexing(self): def multiple_indexing(indexers): # make sure a sequence of lazy indexings certainly works. with self.roundtrip(in_memory) as on_disk: - actual = on_disk['var3'] - expected = in_memory['var3'] + actual = on_disk["var3"] + expected = in_memory["var3"] for ind in indexers: actual = actual.isel(**ind) expected = expected.isel(**ind) @@ -533,47 +594,54 @@ def multiple_indexing(indexers): # two-staged vectorized-indexing indexers = [ - {'dim1': DataArray([[0, 7], [2, 6], [3, 5]], dims=['a', 'b']), - 'dim3': DataArray([[0, 4], [1, 3], [2, 2]], dims=['a', 'b'])}, - {'a': DataArray([0, 1], dims=['c']), - 'b': DataArray([0, 1], dims=['c'])} + { + "dim1": DataArray([[0, 7], [2, 6], [3, 5]], dims=["a", "b"]), + "dim3": DataArray([[0, 4], [1, 3], [2, 2]], dims=["a", "b"]), + }, + {"a": DataArray([0, 1], dims=["c"]), "b": DataArray([0, 1], dims=["c"])}, ] multiple_indexing(indexers) # vectorized-slice mixed indexers = [ - {'dim1': DataArray([[0, 7], [2, 6], [3, 5]], dims=['a', 'b']), - 'dim3': slice(None, 10)} + { + "dim1": DataArray([[0, 7], [2, 6], [3, 5]], dims=["a", "b"]), + "dim3": slice(None, 10), + } ] multiple_indexing(indexers) # vectorized-integer mixed indexers = [ - {'dim3': 0}, - {'dim1': DataArray([[0, 7], [2, 6], [3, 5]], dims=['a', 'b'])}, - {'a': slice(None, None, 2)} + {"dim3": 0}, + {"dim1": DataArray([[0, 7], [2, 6], [3, 5]], dims=["a", "b"])}, + {"a": slice(None, None, 2)}, ] multiple_indexing(indexers) # vectorized-integer mixed indexers = [ - {'dim3': 0}, - {'dim1': DataArray([[0, 7], [2, 6], [3, 5]], dims=['a', 'b'])}, - {'a': 1, 'b': 0} + {"dim3": 0}, + {"dim1": DataArray([[0, 7], [2, 6], [3, 5]], dims=["a", "b"])}, + {"a": 1, "b": 0}, ] multiple_indexing(indexers) # with negative step slice. indexers = [ - {'dim1': DataArray([[0, 7], [2, 6], [3, 5]], dims=['a', 'b']), - 'dim3': slice(-1, 1, -1)}, + { + "dim1": DataArray([[0, 7], [2, 6], [3, 5]], dims=["a", "b"]), + "dim3": slice(-1, 1, -1), + } ] multiple_indexing(indexers) # with negative step slice. indexers = [ - {'dim1': DataArray([[0, 7], [2, 6], [3, 5]], dims=['a', 'b']), - 'dim3': slice(-1, 1, -2)}, + { + "dim1": DataArray([[0, 7], [2, 6], [3, 5]], dims=["a", "b"]), + "dim3": slice(-1, 1, -2), + } ] multiple_indexing(indexers) @@ -581,15 +649,15 @@ def test_isel_dataarray(self): # Make sure isel works lazily. GH:issue:1688 in_memory = create_test_data() with self.roundtrip(in_memory) as on_disk: - expected = in_memory.isel(dim2=in_memory['dim2'] < 3) - actual = on_disk.isel(dim2=on_disk['dim2'] < 3) + expected = in_memory.isel(dim2=in_memory["dim2"] < 3) + actual = on_disk.isel(dim2=on_disk["dim2"] < 3) assert_identical(expected, actual) def validate_array_type(self, ds): # Make sure that only NumpyIndexingAdapter stores a bare np.ndarray. def find_and_validate_array(obj): # recursively called function. obj: array or array wrapper. 
- if hasattr(obj, 'array'): + if hasattr(obj, "array"): if isinstance(obj.array, indexing.ExplicitlyIndexed): find_and_validate_array(obj.array) else: @@ -600,8 +668,9 @@ def find_and_validate_array(obj): elif isinstance(obj.array, pd.Index): assert isinstance(obj, indexing.PandasIndexAdapter) else: - raise TypeError('{} is wrapped by {}'.format( - type(obj.array), type(obj))) + raise TypeError( + "{} is wrapped by {}".format(type(obj.array), type(obj)) + ) for k, v in ds.variables.items(): find_and_validate_array(v._data) @@ -610,8 +679,7 @@ def test_array_type_after_indexing(self): in_memory = create_test_data() with self.roundtrip(in_memory) as on_disk: self.validate_array_type(on_disk) - indexers = {'dim1': [1, 2, 0], 'dim2': [3, 2, 0, 3], - 'dim3': np.arange(5)} + indexers = {"dim1": [1, 2, 0], "dim2": [3, 2, 0, 3], "dim3": np.arange(5)} expected = in_memory.isel(**indexers) actual = on_disk.isel(**indexers) assert_identical(expected, actual) @@ -626,16 +694,18 @@ def test_dropna(self): # regression test for GH:issue:1694 a = np.random.randn(4, 3) a[1, 1] = np.NaN - in_memory = xr.Dataset({'a': (('y', 'x'), a)}, - coords={'y': np.arange(4), 'x': np.arange(3)}) + in_memory = xr.Dataset( + {"a": (("y", "x"), a)}, coords={"y": np.arange(4), "x": np.arange(3)} + ) - assert_identical(in_memory.dropna(dim='x'), - in_memory.isel(x=slice(None, None, 2))) + assert_identical( + in_memory.dropna(dim="x"), in_memory.isel(x=slice(None, None, 2)) + ) with self.roundtrip(in_memory) as on_disk: self.validate_array_type(on_disk) - expected = in_memory.dropna(dim='x') - actual = on_disk.dropna(dim='x') + expected = in_memory.dropna(dim="x") + actual = on_disk.dropna(dim="x") assert_identical(expected, actual) def test_ondisk_after_print(self): @@ -643,70 +713,71 @@ def test_ondisk_after_print(self): in_memory = create_test_data() with self.roundtrip(in_memory) as on_disk: repr(on_disk) - assert not on_disk['var1']._in_memory + assert not on_disk["var1"]._in_memory class CFEncodedBase(DatasetIOBase): - def test_roundtrip_bytes_with_fill_value(self): - values = np.array([b'ab', b'cdef', np.nan], dtype=object) - encoding = {'_FillValue': b'X', 'dtype': 'S1'} - original = Dataset({'x': ('t', values, {}, encoding)}) + values = np.array([b"ab", b"cdef", np.nan], dtype=object) + encoding = {"_FillValue": b"X", "dtype": "S1"} + original = Dataset({"x": ("t", values, {}, encoding)}) expected = original.copy(deep=True) with self.roundtrip(original) as actual: assert_identical(expected, actual) - original = Dataset({'x': ('t', values, {}, {'_FillValue': b''})}) + original = Dataset({"x": ("t", values, {}, {"_FillValue": b""})}) with self.roundtrip(original) as actual: assert_identical(expected, actual) def test_roundtrip_string_with_fill_value_nchar(self): - values = np.array(['ab', 'cdef', np.nan], dtype=object) - expected = Dataset({'x': ('t', values)}) + values = np.array(["ab", "cdef", np.nan], dtype=object) + expected = Dataset({"x": ("t", values)}) - encoding = {'dtype': 'S1', '_FillValue': b'X'} - original = Dataset({'x': ('t', values, {}, encoding)}) + encoding = {"dtype": "S1", "_FillValue": b"X"} + original = Dataset({"x": ("t", values, {}, encoding)}) # Not supported yet. 
with pytest.raises(NotImplementedError): with self.roundtrip(original) as actual: assert_identical(expected, actual) @pytest.mark.parametrize( - 'decoded_fn, encoded_fn', - [(create_unsigned_masked_scaled_data, - create_encoded_unsigned_masked_scaled_data), - pytest.param(create_bad_unsigned_masked_scaled_data, - create_bad_encoded_unsigned_masked_scaled_data, - marks=pytest.mark.xfail( - reason="Bad _Unsigned attribute.")), - (create_signed_masked_scaled_data, - create_encoded_signed_masked_scaled_data), - (create_masked_and_scaled_data, - create_encoded_masked_and_scaled_data)]) + "decoded_fn, encoded_fn", + [ + ( + create_unsigned_masked_scaled_data, + create_encoded_unsigned_masked_scaled_data, + ), + pytest.param( + create_bad_unsigned_masked_scaled_data, + create_bad_encoded_unsigned_masked_scaled_data, + marks=pytest.mark.xfail(reason="Bad _Unsigned attribute."), + ), + ( + create_signed_masked_scaled_data, + create_encoded_signed_masked_scaled_data, + ), + (create_masked_and_scaled_data, create_encoded_masked_and_scaled_data), + ], + ) def test_roundtrip_mask_and_scale(self, decoded_fn, encoded_fn): decoded = decoded_fn() encoded = encoded_fn() with self.roundtrip(decoded) as actual: for k in decoded.variables: - assert (decoded.variables[k].dtype - == actual.variables[k].dtype) + assert decoded.variables[k].dtype == actual.variables[k].dtype assert_allclose(decoded, actual, decode_bytes=False) - with self.roundtrip(decoded, - open_kwargs=dict(decode_cf=False)) as actual: + with self.roundtrip(decoded, open_kwargs=dict(decode_cf=False)) as actual: # TODO: this assumes that all roundtrips will first # encode. Is that something we want to test for? for k in encoded.variables: - assert (encoded.variables[k].dtype - == actual.variables[k].dtype) + assert encoded.variables[k].dtype == actual.variables[k].dtype assert_allclose(encoded, actual, decode_bytes=False) - with self.roundtrip(encoded, - open_kwargs=dict(decode_cf=False)) as actual: + with self.roundtrip(encoded, open_kwargs=dict(decode_cf=False)) as actual: for k in encoded.variables: - assert (encoded.variables[k].dtype - == actual.variables[k].dtype) + assert encoded.variables[k].dtype == actual.variables[k].dtype assert_allclose(encoded, actual, decode_bytes=False) # make sure roundtrip encoding didn't change the @@ -715,42 +786,47 @@ def test_roundtrip_mask_and_scale(self, decoded_fn, encoded_fn): with self.roundtrip(encoded) as actual: for k in decoded.variables: - assert (decoded.variables[k].dtype == - actual.variables[k].dtype) + assert decoded.variables[k].dtype == actual.variables[k].dtype assert_allclose(decoded, actual, decode_bytes=False) def test_coordinates_encoding(self): def equals_latlon(obj): - return obj == 'lat lon' or obj == 'lon lat' + return obj == "lat lon" or obj == "lon lat" - original = Dataset({'temp': ('x', [0, 1]), 'precip': ('x', [0, -1])}, - {'lat': ('x', [2, 3]), 'lon': ('x', [4, 5])}) + original = Dataset( + {"temp": ("x", [0, 1]), "precip": ("x", [0, -1])}, + {"lat": ("x", [2, 3]), "lon": ("x", [4, 5])}, + ) with self.roundtrip(original) as actual: assert_identical(actual, original) with create_tmp_file() as tmp_file: original.to_netcdf(tmp_file) with open_dataset(tmp_file, decode_coords=False) as ds: - assert equals_latlon(ds['temp'].attrs['coordinates']) - assert equals_latlon(ds['precip'].attrs['coordinates']) - assert 'coordinates' not in ds.attrs - assert 'coordinates' not in ds['lat'].attrs - assert 'coordinates' not in ds['lon'].attrs + assert 
equals_latlon(ds["temp"].attrs["coordinates"]) + assert equals_latlon(ds["precip"].attrs["coordinates"]) + assert "coordinates" not in ds.attrs + assert "coordinates" not in ds["lat"].attrs + assert "coordinates" not in ds["lon"].attrs - modified = original.drop(['temp', 'precip']) + modified = original.drop(["temp", "precip"]) with self.roundtrip(modified) as actual: assert_identical(actual, modified) with create_tmp_file() as tmp_file: modified.to_netcdf(tmp_file) with open_dataset(tmp_file, decode_coords=False) as ds: - assert equals_latlon(ds.attrs['coordinates']) - assert 'coordinates' not in ds['lat'].attrs - assert 'coordinates' not in ds['lon'].attrs + assert equals_latlon(ds.attrs["coordinates"]) + assert "coordinates" not in ds["lat"].attrs + assert "coordinates" not in ds["lon"].attrs def test_roundtrip_endian(self): - ds = Dataset({'x': np.arange(3, 10, dtype='>i2'), - 'y': np.arange(3, 20, dtype=' 1) == (f1a is f1b) + f1b.write("baz") + assert (getattr(file_cache, "maxsize", float("inf")) > 1) == (f1a is f1b) manager1.close() manager2.close() - with open(path1, 'r') as f: - assert f.read() == 'foobaz' - with open(path2, 'r') as f: - assert f.read() == 'bar' + with open(path1, "r") as f: + assert f.read() == "foobaz" + with open(path2, "r") as f: + assert f.read() == "bar" def test_file_manager_write_concurrent(tmpdir, file_cache): - path = str(tmpdir.join('testing.txt')) - manager = CachingFileManager(open, path, mode='w', cache=file_cache) + path = str(tmpdir.join("testing.txt")) + manager = CachingFileManager(open, path, mode="w", cache=file_cache) f1 = manager.acquire() f2 = manager.acquire() f3 = manager.acquire() assert f1 is f2 assert f2 is f3 - f1.write('foo') + f1.write("foo") f1.flush() - f2.write('bar') + f2.write("bar") f2.flush() - f3.write('baz') + f3.write("baz") f3.flush() manager.close() - with open(path, 'r') as f: - assert f.read() == 'foobarbaz' + with open(path, "r") as f: + assert f.read() == "foobarbaz" def test_file_manager_write_pickle(tmpdir, file_cache): - path = str(tmpdir.join('testing.txt')) - manager = CachingFileManager(open, path, mode='w', cache=file_cache) + path = str(tmpdir.join("testing.txt")) + manager = CachingFileManager(open, path, mode="w", cache=file_cache) f = manager.acquire() - f.write('foo') + f.write("foo") f.flush() manager2 = pickle.loads(pickle.dumps(manager)) f2 = manager2.acquire() - f2.write('bar') + f2.write("bar") manager2.close() manager.close() - with open(path, 'r') as f: - assert f.read() == 'foobar' + with open(path, "r") as f: + assert f.read() == "foobar" def test_file_manager_read(tmpdir, file_cache): - path = str(tmpdir.join('testing.txt')) + path = str(tmpdir.join("testing.txt")) - with open(path, 'w') as f: - f.write('foobar') + with open(path, "w") as f: + f.write("foobar") manager = CachingFileManager(open, path, cache=file_cache) f = manager.acquire() - assert f.read() == 'foobar' + assert f.read() == "foobar" manager.close() def test_file_manager_invalid_kwargs(): with pytest.raises(TypeError): - CachingFileManager(open, 'dummy', mode='w', invalid=True) + CachingFileManager(open, "dummy", mode="w", invalid=True) def test_file_manager_acquire_context(tmpdir, file_cache): - path = str(tmpdir.join('testing.txt')) + path = str(tmpdir.join("testing.txt")) - with open(path, 'w') as f: - f.write('foobar') + with open(path, "w") as f: + f.write("foobar") class AcquisitionError(Exception): pass @@ -222,17 +219,17 @@ class AcquisitionError(Exception): manager = CachingFileManager(open, path, cache=file_cache) with 
pytest.raises(AcquisitionError): with manager.acquire_context() as f: - assert f.read() == 'foobar' + assert f.read() == "foobar" raise AcquisitionError assert not file_cache # file was *not* already open with manager.acquire_context() as f: - assert f.read() == 'foobar' + assert f.read() == "foobar" with pytest.raises(AcquisitionError): with manager.acquire_context() as f: f.seek(0) - assert f.read() == 'foobar' + assert f.read() == "foobar" raise AcquisitionError assert file_cache # file *was* already open diff --git a/xarray/tests/test_backends_locks.py b/xarray/tests/test_backends_locks.py index 5f83321802e..f7e48b65d46 100644 --- a/xarray/tests/test_backends_locks.py +++ b/xarray/tests/test_backends_locks.py @@ -4,10 +4,10 @@ def test_threaded_lock(): - lock1 = locks._get_threaded_lock('foo') + lock1 = locks._get_threaded_lock("foo") assert isinstance(lock1, type(threading.Lock())) - lock2 = locks._get_threaded_lock('foo') + lock2 = locks._get_threaded_lock("foo") assert lock1 is lock2 - lock3 = locks._get_threaded_lock('bar') + lock3 = locks._get_threaded_lock("bar") assert lock1 is not lock3 diff --git a/xarray/tests/test_backends_lru_cache.py b/xarray/tests/test_backends_lru_cache.py index aa97f5fb4cb..2aaa8c9e631 100644 --- a/xarray/tests/test_backends_lru_cache.py +++ b/xarray/tests/test_backends_lru_cache.py @@ -7,24 +7,24 @@ def test_simple(): cache = LRUCache(maxsize=2) - cache['x'] = 1 - cache['y'] = 2 + cache["x"] = 1 + cache["y"] = 2 - assert cache['x'] == 1 - assert cache['y'] == 2 + assert cache["x"] == 1 + assert cache["y"] == 2 assert len(cache) == 2 - assert dict(cache) == {'x': 1, 'y': 2} - assert list(cache.keys()) == ['x', 'y'] - assert list(cache.items()) == [('x', 1), ('y', 2)] + assert dict(cache) == {"x": 1, "y": 2} + assert list(cache.keys()) == ["x", "y"] + assert list(cache.items()) == [("x", 1), ("y", 2)] - cache['z'] = 3 + cache["z"] = 3 assert len(cache) == 2 - assert list(cache.items()) == [('y', 2), ('z', 3)] + assert list(cache.items()) == [("y", 2), ("z", 3)] def test_trivial(): cache = LRUCache(maxsize=0) - cache['x'] = 1 + cache["x"] = 1 assert len(cache) == 0 @@ -37,52 +37,52 @@ def test_invalid(): def test_update_priority(): cache = LRUCache(maxsize=2) - cache['x'] = 1 - cache['y'] = 2 - assert list(cache) == ['x', 'y'] - assert 'x' in cache # contains - assert list(cache) == ['y', 'x'] - assert cache['y'] == 2 # getitem - assert list(cache) == ['x', 'y'] - cache['x'] = 3 # setitem - assert list(cache.items()) == [('y', 2), ('x', 3)] + cache["x"] = 1 + cache["y"] = 2 + assert list(cache) == ["x", "y"] + assert "x" in cache # contains + assert list(cache) == ["y", "x"] + assert cache["y"] == 2 # getitem + assert list(cache) == ["x", "y"] + cache["x"] = 3 # setitem + assert list(cache.items()) == [("y", 2), ("x", 3)] def test_del(): cache = LRUCache(maxsize=2) - cache['x'] = 1 - cache['y'] = 2 - del cache['x'] - assert dict(cache) == {'y': 2} + cache["x"] = 1 + cache["y"] = 2 + del cache["x"] + assert dict(cache) == {"y": 2} def test_on_evict(): on_evict = mock.Mock() cache = LRUCache(maxsize=1, on_evict=on_evict) - cache['x'] = 1 - cache['y'] = 2 - on_evict.assert_called_once_with('x', 1) + cache["x"] = 1 + cache["y"] = 2 + on_evict.assert_called_once_with("x", 1) def test_on_evict_trivial(): on_evict = mock.Mock() cache = LRUCache(maxsize=0, on_evict=on_evict) - cache['x'] = 1 - on_evict.assert_called_once_with('x', 1) + cache["x"] = 1 + on_evict.assert_called_once_with("x", 1) def test_resize(): cache = LRUCache(maxsize=2) assert cache.maxsize == 
2 - cache['w'] = 0 - cache['x'] = 1 - cache['y'] = 2 - assert list(cache.items()) == [('x', 1), ('y', 2)] + cache["w"] = 0 + cache["x"] = 1 + cache["y"] = 2 + assert list(cache.items()) == [("x", 1), ("y", 2)] cache.maxsize = 10 - cache['z'] = 3 - assert list(cache.items()) == [('x', 1), ('y', 2), ('z', 3)] + cache["z"] = 3 + assert list(cache.items()) == [("x", 1), ("y", 2), ("z", 3)] cache.maxsize = 1 - assert list(cache.items()) == [('z', 3)] + assert list(cache.items()) == [("z", 3)] with pytest.raises(ValueError): cache.maxsize = -1 diff --git a/xarray/tests/test_cftime_offsets.py b/xarray/tests/test_cftime_offsets.py index b3560fe3039..3be46b68fc4 100644 --- a/xarray/tests/test_cftime_offsets.py +++ b/xarray/tests/test_cftime_offsets.py @@ -6,15 +6,38 @@ from xarray import CFTimeIndex from xarray.coding.cftime_offsets import ( - _MONTH_ABBREVIATIONS, BaseCFTimeOffset, Day, Hour, Minute, MonthBegin, - MonthEnd, QuarterBegin, QuarterEnd, Second, YearBegin, YearEnd, - _days_in_month, cftime_range, get_date_type, to_cftime_datetime, to_offset) + _MONTH_ABBREVIATIONS, + BaseCFTimeOffset, + Day, + Hour, + Minute, + MonthBegin, + MonthEnd, + QuarterBegin, + QuarterEnd, + Second, + YearBegin, + YearEnd, + _days_in_month, + cftime_range, + get_date_type, + to_cftime_datetime, + to_offset, +) -cftime = pytest.importorskip('cftime') +cftime = pytest.importorskip("cftime") -_CFTIME_CALENDARS = ['365_day', '360_day', 'julian', 'all_leap', - '366_day', 'gregorian', 'proleptic_gregorian', 'standard'] +_CFTIME_CALENDARS = [ + "365_day", + "360_day", + "julian", + "all_leap", + "366_day", + "gregorian", + "proleptic_gregorian", + "standard", +] def _id_func(param): @@ -28,31 +51,35 @@ def calendar(request): @pytest.mark.parametrize( - ('offset', 'expected_n'), - [(BaseCFTimeOffset(), 1), - (YearBegin(), 1), - (YearEnd(), 1), - (QuarterBegin(), 1), - (QuarterEnd(), 1), - (BaseCFTimeOffset(n=2), 2), - (YearBegin(n=2), 2), - (YearEnd(n=2), 2), - (QuarterBegin(n=2), 2), - (QuarterEnd(n=2), 2)], - ids=_id_func + ("offset", "expected_n"), + [ + (BaseCFTimeOffset(), 1), + (YearBegin(), 1), + (YearEnd(), 1), + (QuarterBegin(), 1), + (QuarterEnd(), 1), + (BaseCFTimeOffset(n=2), 2), + (YearBegin(n=2), 2), + (YearEnd(n=2), 2), + (QuarterBegin(n=2), 2), + (QuarterEnd(n=2), 2), + ], + ids=_id_func, ) def test_cftime_offset_constructor_valid_n(offset, expected_n): assert offset.n == expected_n @pytest.mark.parametrize( - ('offset', 'invalid_n'), - [(BaseCFTimeOffset, 1.5), - (YearBegin, 1.5), - (YearEnd, 1.5), - (QuarterBegin, 1.5), - (QuarterEnd, 1.5)], - ids=_id_func + ("offset", "invalid_n"), + [ + (BaseCFTimeOffset, 1.5), + (YearBegin, 1.5), + (YearEnd, 1.5), + (QuarterBegin, 1.5), + (QuarterEnd, 1.5), + ], + ids=_id_func, ) def test_cftime_offset_constructor_invalid_n(offset, invalid_n): with pytest.raises(TypeError): @@ -60,61 +87,68 @@ def test_cftime_offset_constructor_invalid_n(offset, invalid_n): @pytest.mark.parametrize( - ('offset', 'expected_month'), - [(YearBegin(), 1), - (YearEnd(), 12), - (YearBegin(month=5), 5), - (YearEnd(month=5), 5), - (QuarterBegin(), 3), - (QuarterEnd(), 3), - (QuarterBegin(month=5), 5), - (QuarterEnd(month=5), 5)], - ids=_id_func + ("offset", "expected_month"), + [ + (YearBegin(), 1), + (YearEnd(), 12), + (YearBegin(month=5), 5), + (YearEnd(month=5), 5), + (QuarterBegin(), 3), + (QuarterEnd(), 3), + (QuarterBegin(month=5), 5), + (QuarterEnd(month=5), 5), + ], + ids=_id_func, ) def test_year_offset_constructor_valid_month(offset, expected_month): assert offset.month == 
expected_month @pytest.mark.parametrize( - ('offset', 'invalid_month', 'exception'), - [(YearBegin, 0, ValueError), - (YearEnd, 0, ValueError), - (YearBegin, 13, ValueError,), - (YearEnd, 13, ValueError), - (YearBegin, 1.5, TypeError), - (YearEnd, 1.5, TypeError), - (QuarterBegin, 0, ValueError), - (QuarterEnd, 0, ValueError), - (QuarterBegin, 1.5, TypeError), - (QuarterEnd, 1.5, TypeError), - (QuarterBegin, 13, ValueError), - (QuarterEnd, 13, ValueError)], - ids=_id_func -) -def test_year_offset_constructor_invalid_month( - offset, invalid_month, exception): + ("offset", "invalid_month", "exception"), + [ + (YearBegin, 0, ValueError), + (YearEnd, 0, ValueError), + (YearBegin, 13, ValueError), + (YearEnd, 13, ValueError), + (YearBegin, 1.5, TypeError), + (YearEnd, 1.5, TypeError), + (QuarterBegin, 0, ValueError), + (QuarterEnd, 0, ValueError), + (QuarterBegin, 1.5, TypeError), + (QuarterEnd, 1.5, TypeError), + (QuarterBegin, 13, ValueError), + (QuarterEnd, 13, ValueError), + ], + ids=_id_func, +) +def test_year_offset_constructor_invalid_month(offset, invalid_month, exception): with pytest.raises(exception): offset(month=invalid_month) @pytest.mark.parametrize( - ('offset', 'expected'), - [(BaseCFTimeOffset(), None), - (MonthBegin(), 'MS'), - (YearBegin(), 'AS-JAN'), - (QuarterBegin(), 'QS-MAR')], - ids=_id_func + ("offset", "expected"), + [ + (BaseCFTimeOffset(), None), + (MonthBegin(), "MS"), + (YearBegin(), "AS-JAN"), + (QuarterBegin(), "QS-MAR"), + ], + ids=_id_func, ) def test_rule_code(offset, expected): assert offset.rule_code() == expected @pytest.mark.parametrize( - ('offset', 'expected'), - [(BaseCFTimeOffset(), ''), - (YearBegin(), ''), - (QuarterBegin(), '')], - ids=_id_func + ("offset", "expected"), + [ + (BaseCFTimeOffset(), ""), + (YearBegin(), ""), + (QuarterBegin(), ""), + ], + ids=_id_func, ) def test_str_and_repr(offset, expected): assert str(offset) == expected @@ -122,53 +156,53 @@ def test_str_and_repr(offset, expected): @pytest.mark.parametrize( - 'offset', + "offset", [BaseCFTimeOffset(), MonthBegin(), QuarterBegin(), YearBegin()], - ids=_id_func + ids=_id_func, ) def test_to_offset_offset_input(offset): assert to_offset(offset) == offset @pytest.mark.parametrize( - ('freq', 'expected'), - [('M', MonthEnd()), - ('2M', MonthEnd(n=2)), - ('MS', MonthBegin()), - ('2MS', MonthBegin(n=2)), - ('D', Day()), - ('2D', Day(n=2)), - ('H', Hour()), - ('2H', Hour(n=2)), - ('T', Minute()), - ('2T', Minute(n=2)), - ('min', Minute()), - ('2min', Minute(n=2)), - ('S', Second()), - ('2S', Second(n=2))], - ids=_id_func + ("freq", "expected"), + [ + ("M", MonthEnd()), + ("2M", MonthEnd(n=2)), + ("MS", MonthBegin()), + ("2MS", MonthBegin(n=2)), + ("D", Day()), + ("2D", Day(n=2)), + ("H", Hour()), + ("2H", Hour(n=2)), + ("T", Minute()), + ("2T", Minute(n=2)), + ("min", Minute()), + ("2min", Minute(n=2)), + ("S", Second()), + ("2S", Second(n=2)), + ], + ids=_id_func, ) def test_to_offset_sub_annual(freq, expected): assert to_offset(freq) == expected -_ANNUAL_OFFSET_TYPES = { - 'A': YearEnd, - 'AS': YearBegin -} +_ANNUAL_OFFSET_TYPES = {"A": YearEnd, "AS": YearBegin} -@pytest.mark.parametrize(('month_int', 'month_label'), - list(_MONTH_ABBREVIATIONS.items()) + [(0, '')]) -@pytest.mark.parametrize('multiple', [None, 2]) -@pytest.mark.parametrize('offset_str', ['AS', 'A']) +@pytest.mark.parametrize( + ("month_int", "month_label"), list(_MONTH_ABBREVIATIONS.items()) + [(0, "")] +) +@pytest.mark.parametrize("multiple", [None, 2]) +@pytest.mark.parametrize("offset_str", ["AS", "A"]) def 
test_to_offset_annual(month_label, month_int, multiple, offset_str): freq = offset_str offset_type = _ANNUAL_OFFSET_TYPES[offset_str] if month_label: - freq = '-'.join([freq, month_label]) + freq = "-".join([freq, month_label]) if multiple: - freq = '{}'.format(multiple) + freq + freq = "{}".format(multiple) + freq result = to_offset(freq) if multiple and month_int: @@ -182,23 +216,21 @@ def test_to_offset_annual(month_label, month_int, multiple, offset_str): assert result == expected -_QUARTER_OFFSET_TYPES = { - 'Q': QuarterEnd, - 'QS': QuarterBegin -} +_QUARTER_OFFSET_TYPES = {"Q": QuarterEnd, "QS": QuarterBegin} -@pytest.mark.parametrize(('month_int', 'month_label'), - list(_MONTH_ABBREVIATIONS.items()) + [(0, '')]) -@pytest.mark.parametrize('multiple', [None, 2]) -@pytest.mark.parametrize('offset_str', ['QS', 'Q']) +@pytest.mark.parametrize( + ("month_int", "month_label"), list(_MONTH_ABBREVIATIONS.items()) + [(0, "")] +) +@pytest.mark.parametrize("multiple", [None, 2]) +@pytest.mark.parametrize("offset_str", ["QS", "Q"]) def test_to_offset_quarter(month_label, month_int, multiple, offset_str): freq = offset_str offset_type = _QUARTER_OFFSET_TYPES[offset_str] if month_label: - freq = '-'.join([freq, month_label]) + freq = "-".join([freq, month_label]) if multiple: - freq = '{}'.format(multiple) + freq + freq = "{}".format(multiple) + freq result = to_offset(freq) if multiple and month_int: @@ -221,18 +253,16 @@ def test_to_offset_quarter(month_label, month_int, multiple, offset_str): assert result == expected -@pytest.mark.parametrize('freq', ['Z', '7min2', 'AM', 'M-', 'AS-', 'QS-', - '1H1min']) +@pytest.mark.parametrize("freq", ["Z", "7min2", "AM", "M-", "AS-", "QS-", "1H1min"]) def test_invalid_to_offset_str(freq): with pytest.raises(ValueError): to_offset(freq) @pytest.mark.parametrize( - ('argument', 'expected_date_args'), - [('2000-01-01', (2000, 1, 1)), - ((2000, 1, 1), (2000, 1, 1))], - ids=_id_func + ("argument", "expected_date_args"), + [("2000-01-01", (2000, 1, 1)), ((2000, 1, 1), (2000, 1, 1))], + ids=_id_func, ) def test_to_cftime_datetime(calendar, argument, expected_date_args): date_type = get_date_type(calendar) @@ -245,7 +275,7 @@ def test_to_cftime_datetime(calendar, argument, expected_date_args): def test_to_cftime_datetime_error_no_calendar(): with pytest.raises(ValueError): - to_cftime_datetime('2000') + to_cftime_datetime("2000") def test_to_cftime_datetime_error_type_error(): @@ -254,39 +284,66 @@ def test_to_cftime_datetime_error_type_error(): _EQ_TESTS_A = [ - BaseCFTimeOffset(), YearBegin(), YearEnd(), YearBegin(month=2), - YearEnd(month=2), QuarterBegin(), QuarterEnd(), QuarterBegin(month=2), - QuarterEnd(month=2), MonthBegin(), MonthEnd(), Day(), Hour(), Minute(), - Second() + BaseCFTimeOffset(), + YearBegin(), + YearEnd(), + YearBegin(month=2), + YearEnd(month=2), + QuarterBegin(), + QuarterEnd(), + QuarterBegin(month=2), + QuarterEnd(month=2), + MonthBegin(), + MonthEnd(), + Day(), + Hour(), + Minute(), + Second(), ] _EQ_TESTS_B = [ - BaseCFTimeOffset(n=2), YearBegin(n=2), YearEnd(n=2), - YearBegin(n=2, month=2), YearEnd(n=2, month=2), QuarterBegin(n=2), - QuarterEnd(n=2), QuarterBegin(n=2, month=2), QuarterEnd(n=2, month=2), - MonthBegin(n=2), MonthEnd(n=2), Day(n=2), Hour(n=2), Minute(n=2), - Second(n=2) + BaseCFTimeOffset(n=2), + YearBegin(n=2), + YearEnd(n=2), + YearBegin(n=2, month=2), + YearEnd(n=2, month=2), + QuarterBegin(n=2), + QuarterEnd(n=2), + QuarterBegin(n=2, month=2), + QuarterEnd(n=2, month=2), + MonthBegin(n=2), + MonthEnd(n=2), + 
Day(n=2), + Hour(n=2), + Minute(n=2), + Second(n=2), ] -@pytest.mark.parametrize( - ('a', 'b'), product(_EQ_TESTS_A, _EQ_TESTS_B), ids=_id_func -) +@pytest.mark.parametrize(("a", "b"), product(_EQ_TESTS_A, _EQ_TESTS_B), ids=_id_func) def test_neq(a, b): assert a != b _EQ_TESTS_B_COPY = [ - BaseCFTimeOffset(n=2), YearBegin(n=2), YearEnd(n=2), - YearBegin(n=2, month=2), YearEnd(n=2, month=2), QuarterBegin(n=2), - QuarterEnd(n=2), QuarterBegin(n=2, month=2), QuarterEnd(n=2, month=2), - MonthBegin(n=2), MonthEnd(n=2), Day(n=2), Hour(n=2), Minute(n=2), - Second(n=2) + BaseCFTimeOffset(n=2), + YearBegin(n=2), + YearEnd(n=2), + YearBegin(n=2, month=2), + YearEnd(n=2, month=2), + QuarterBegin(n=2), + QuarterEnd(n=2), + QuarterBegin(n=2, month=2), + QuarterEnd(n=2, month=2), + MonthBegin(n=2), + MonthEnd(n=2), + Day(n=2), + Hour(n=2), + Minute(n=2), + Second(n=2), ] -@pytest.mark.parametrize( - ('a', 'b'), zip(_EQ_TESTS_B, _EQ_TESTS_B_COPY), ids=_id_func -) +@pytest.mark.parametrize(("a", "b"), zip(_EQ_TESTS_B, _EQ_TESTS_B_COPY), ids=_id_func) def test_eq(a, b): assert a == b @@ -302,34 +359,37 @@ def test_eq(a, b): (Day(), Day(n=3)), (Hour(), Hour(n=3)), (Minute(), Minute(n=3)), - (Second(), Second(n=3)) + (Second(), Second(n=3)), ] -@pytest.mark.parametrize(('offset', 'expected'), _MUL_TESTS, ids=_id_func) +@pytest.mark.parametrize(("offset", "expected"), _MUL_TESTS, ids=_id_func) def test_mul(offset, expected): assert offset * 3 == expected -@pytest.mark.parametrize(('offset', 'expected'), _MUL_TESTS, ids=_id_func) +@pytest.mark.parametrize(("offset", "expected"), _MUL_TESTS, ids=_id_func) def test_rmul(offset, expected): assert 3 * offset == expected @pytest.mark.parametrize( - ('offset', 'expected'), - [(BaseCFTimeOffset(), BaseCFTimeOffset(n=-1)), - (YearEnd(), YearEnd(n=-1)), - (YearBegin(), YearBegin(n=-1)), - (QuarterEnd(), QuarterEnd(n=-1)), - (QuarterBegin(), QuarterBegin(n=-1)), - (MonthEnd(), MonthEnd(n=-1)), - (MonthBegin(), MonthBegin(n=-1)), - (Day(), Day(n=-1)), - (Hour(), Hour(n=-1)), - (Minute(), Minute(n=-1)), - (Second(), Second(n=-1))], - ids=_id_func) + ("offset", "expected"), + [ + (BaseCFTimeOffset(), BaseCFTimeOffset(n=-1)), + (YearEnd(), YearEnd(n=-1)), + (YearBegin(), YearBegin(n=-1)), + (QuarterEnd(), QuarterEnd(n=-1)), + (QuarterBegin(), QuarterBegin(n=-1)), + (MonthEnd(), MonthEnd(n=-1)), + (MonthBegin(), MonthBegin(n=-1)), + (Day(), Day(n=-1)), + (Hour(), Hour(n=-1)), + (Minute(), Minute(n=-1)), + (Second(), Second(n=-1)), + ], + ids=_id_func, +) def test_neg(offset, expected): assert -offset == expected @@ -338,15 +398,11 @@ def test_neg(offset, expected): (Day(n=2), (1, 1, 3)), (Hour(n=2), (1, 1, 1, 2)), (Minute(n=2), (1, 1, 1, 0, 2)), - (Second(n=2), (1, 1, 1, 0, 0, 2)) + (Second(n=2), (1, 1, 1, 0, 0, 2)), ] -@pytest.mark.parametrize( - ('offset', 'expected_date_args'), - _ADD_TESTS, - ids=_id_func -) +@pytest.mark.parametrize(("offset", "expected_date_args"), _ADD_TESTS, ids=_id_func) def test_add_sub_monthly(offset, expected_date_args, calendar): date_type = get_date_type(calendar) initial = date_type(1, 1, 1) @@ -355,11 +411,7 @@ def test_add_sub_monthly(offset, expected_date_args, calendar): assert result == expected -@pytest.mark.parametrize( - ('offset', 'expected_date_args'), - _ADD_TESTS, - ids=_id_func -) +@pytest.mark.parametrize(("offset", "expected_date_args"), _ADD_TESTS, ids=_id_func) def test_radd_sub_monthly(offset, expected_date_args, calendar): date_type = get_date_type(calendar) initial = date_type(1, 1, 1) @@ -369,12 +421,14 @@ def 
test_radd_sub_monthly(offset, expected_date_args, calendar): @pytest.mark.parametrize( - ('offset', 'expected_date_args'), - [(Day(n=2), (1, 1, 1)), - (Hour(n=2), (1, 1, 2, 22)), - (Minute(n=2), (1, 1, 2, 23, 58)), - (Second(n=2), (1, 1, 2, 23, 59, 58))], - ids=_id_func + ("offset", "expected_date_args"), + [ + (Day(n=2), (1, 1, 1)), + (Hour(n=2), (1, 1, 2, 22)), + (Minute(n=2), (1, 1, 2, 23, 58)), + (Second(n=2), (1, 1, 2, 23, 59, 58)), + ], + ids=_id_func, ) def test_rsub_sub_monthly(offset, expected_date_args, calendar): date_type = get_date_type(calendar) @@ -384,7 +438,7 @@ def test_rsub_sub_monthly(offset, expected_date_args, calendar): assert result == expected -@pytest.mark.parametrize('offset', _EQ_TESTS_A, ids=_id_func) +@pytest.mark.parametrize("offset", _EQ_TESTS_A, ids=_id_func) def test_sub_error(offset, calendar): date_type = get_date_type(calendar) initial = date_type(1, 1, 1) @@ -392,11 +446,7 @@ def test_sub_error(offset, calendar): offset - initial -@pytest.mark.parametrize( - ('a', 'b'), - zip(_EQ_TESTS_A, _EQ_TESTS_B), - ids=_id_func -) +@pytest.mark.parametrize(("a", "b"), zip(_EQ_TESTS_A, _EQ_TESTS_B), ids=_id_func) def test_minus_offset(a, b): result = b - a expected = a @@ -404,10 +454,10 @@ def test_minus_offset(a, b): @pytest.mark.parametrize( - ('a', 'b'), - list(zip(np.roll(_EQ_TESTS_A, 1), _EQ_TESTS_B)) + - [(YearEnd(month=1), YearEnd(month=2))], - ids=_id_func + ("a", "b"), + list(zip(np.roll(_EQ_TESTS_A, 1), _EQ_TESTS_B)) + + [(YearEnd(month=1), YearEnd(month=2))], + ids=_id_func, ) def test_minus_offset_error(a, b): with pytest.raises(TypeError): @@ -421,7 +471,7 @@ def test_days_in_month_non_december(calendar): def test_days_in_month_december(calendar): - if calendar == '360_day': + if calendar == "360_day": expected = 30 else: expected = 31 @@ -431,24 +481,25 @@ def test_days_in_month_december(calendar): @pytest.mark.parametrize( - ('initial_date_args', 'offset', 'expected_date_args'), - [((1, 1, 1), MonthBegin(), (1, 2, 1)), - ((1, 1, 1), MonthBegin(n=2), (1, 3, 1)), - ((1, 1, 7), MonthBegin(), (1, 2, 1)), - ((1, 1, 7), MonthBegin(n=2), (1, 3, 1)), - ((1, 3, 1), MonthBegin(n=-1), (1, 2, 1)), - ((1, 3, 1), MonthBegin(n=-2), (1, 1, 1)), - ((1, 3, 3), MonthBegin(n=-1), (1, 3, 1)), - ((1, 3, 3), MonthBegin(n=-2), (1, 2, 1)), - ((1, 2, 1), MonthBegin(n=14), (2, 4, 1)), - ((2, 4, 1), MonthBegin(n=-14), (1, 2, 1)), - ((1, 1, 1, 5, 5, 5, 5), MonthBegin(), (1, 2, 1, 5, 5, 5, 5)), - ((1, 1, 3, 5, 5, 5, 5), MonthBegin(), (1, 2, 1, 5, 5, 5, 5)), - ((1, 1, 3, 5, 5, 5, 5), MonthBegin(n=-1), (1, 1, 1, 5, 5, 5, 5))], - ids=_id_func -) -def test_add_month_begin( - calendar, initial_date_args, offset, expected_date_args): + ("initial_date_args", "offset", "expected_date_args"), + [ + ((1, 1, 1), MonthBegin(), (1, 2, 1)), + ((1, 1, 1), MonthBegin(n=2), (1, 3, 1)), + ((1, 1, 7), MonthBegin(), (1, 2, 1)), + ((1, 1, 7), MonthBegin(n=2), (1, 3, 1)), + ((1, 3, 1), MonthBegin(n=-1), (1, 2, 1)), + ((1, 3, 1), MonthBegin(n=-2), (1, 1, 1)), + ((1, 3, 3), MonthBegin(n=-1), (1, 3, 1)), + ((1, 3, 3), MonthBegin(n=-2), (1, 2, 1)), + ((1, 2, 1), MonthBegin(n=14), (2, 4, 1)), + ((2, 4, 1), MonthBegin(n=-14), (1, 2, 1)), + ((1, 1, 1, 5, 5, 5, 5), MonthBegin(), (1, 2, 1, 5, 5, 5, 5)), + ((1, 1, 3, 5, 5, 5, 5), MonthBegin(), (1, 2, 1, 5, 5, 5, 5)), + ((1, 1, 3, 5, 5, 5, 5), MonthBegin(n=-1), (1, 1, 1, 5, 5, 5, 5)), + ], + ids=_id_func, +) +def test_add_month_begin(calendar, initial_date_args, offset, expected_date_args): date_type = get_date_type(calendar) initial = 
date_type(*initial_date_args) result = initial + offset @@ -457,21 +508,21 @@ def test_add_month_begin( @pytest.mark.parametrize( - ('initial_date_args', 'offset', 'expected_year_month', - 'expected_sub_day'), - [((1, 1, 1), MonthEnd(), (1, 1), ()), - ((1, 1, 1), MonthEnd(n=2), (1, 2), ()), - ((1, 3, 1), MonthEnd(n=-1), (1, 2), ()), - ((1, 3, 1), MonthEnd(n=-2), (1, 1), ()), - ((1, 2, 1), MonthEnd(n=14), (2, 3), ()), - ((2, 4, 1), MonthEnd(n=-14), (1, 2), ()), - ((1, 1, 1, 5, 5, 5, 5), MonthEnd(), (1, 1), (5, 5, 5, 5)), - ((1, 2, 1, 5, 5, 5, 5), MonthEnd(n=-1), (1, 1), (5, 5, 5, 5))], - ids=_id_func + ("initial_date_args", "offset", "expected_year_month", "expected_sub_day"), + [ + ((1, 1, 1), MonthEnd(), (1, 1), ()), + ((1, 1, 1), MonthEnd(n=2), (1, 2), ()), + ((1, 3, 1), MonthEnd(n=-1), (1, 2), ()), + ((1, 3, 1), MonthEnd(n=-2), (1, 1), ()), + ((1, 2, 1), MonthEnd(n=14), (2, 3), ()), + ((2, 4, 1), MonthEnd(n=-14), (1, 2), ()), + ((1, 1, 1, 5, 5, 5, 5), MonthEnd(), (1, 1), (5, 5, 5, 5)), + ((1, 2, 1, 5, 5, 5, 5), MonthEnd(n=-1), (1, 1), (5, 5, 5, 5)), + ], + ids=_id_func, ) def test_add_month_end( - calendar, initial_date_args, offset, expected_year_month, - expected_sub_day + calendar, initial_date_args, offset, expected_year_month, expected_sub_day ): date_type = get_date_type(calendar) initial = date_type(*initial_date_args) @@ -480,60 +531,75 @@ def test_add_month_end( reference = date_type(*reference_args) # Here the days at the end of each month varies based on the calendar used - expected_date_args = (expected_year_month + - (_days_in_month(reference),) + expected_sub_day) + expected_date_args = ( + expected_year_month + (_days_in_month(reference),) + expected_sub_day + ) expected = date_type(*expected_date_args) assert result == expected @pytest.mark.parametrize( - ('initial_year_month', 'initial_sub_day', 'offset', 'expected_year_month', - 'expected_sub_day'), - [((1, 1), (), MonthEnd(), (1, 2), ()), - ((1, 1), (), MonthEnd(n=2), (1, 3), ()), - ((1, 3), (), MonthEnd(n=-1), (1, 2), ()), - ((1, 3), (), MonthEnd(n=-2), (1, 1), ()), - ((1, 2), (), MonthEnd(n=14), (2, 4), ()), - ((2, 4), (), MonthEnd(n=-14), (1, 2), ()), - ((1, 1), (5, 5, 5, 5), MonthEnd(), (1, 2), (5, 5, 5, 5)), - ((1, 2), (5, 5, 5, 5), MonthEnd(n=-1), (1, 1), (5, 5, 5, 5))], - ids=_id_func + ( + "initial_year_month", + "initial_sub_day", + "offset", + "expected_year_month", + "expected_sub_day", + ), + [ + ((1, 1), (), MonthEnd(), (1, 2), ()), + ((1, 1), (), MonthEnd(n=2), (1, 3), ()), + ((1, 3), (), MonthEnd(n=-1), (1, 2), ()), + ((1, 3), (), MonthEnd(n=-2), (1, 1), ()), + ((1, 2), (), MonthEnd(n=14), (2, 4), ()), + ((2, 4), (), MonthEnd(n=-14), (1, 2), ()), + ((1, 1), (5, 5, 5, 5), MonthEnd(), (1, 2), (5, 5, 5, 5)), + ((1, 2), (5, 5, 5, 5), MonthEnd(n=-1), (1, 1), (5, 5, 5, 5)), + ], + ids=_id_func, ) def test_add_month_end_onOffset( - calendar, initial_year_month, initial_sub_day, offset, expected_year_month, - expected_sub_day + calendar, + initial_year_month, + initial_sub_day, + offset, + expected_year_month, + expected_sub_day, ): date_type = get_date_type(calendar) reference_args = initial_year_month + (1,) reference = date_type(*reference_args) - initial_date_args = (initial_year_month + (_days_in_month(reference),) + - initial_sub_day) + initial_date_args = ( + initial_year_month + (_days_in_month(reference),) + initial_sub_day + ) initial = date_type(*initial_date_args) result = initial + offset reference_args = expected_year_month + (1,) reference = date_type(*reference_args) # Here the days at the end 
of each month varies based on the calendar used - expected_date_args = (expected_year_month + - (_days_in_month(reference),) + expected_sub_day) + expected_date_args = ( + expected_year_month + (_days_in_month(reference),) + expected_sub_day + ) expected = date_type(*expected_date_args) assert result == expected @pytest.mark.parametrize( - ('initial_date_args', 'offset', 'expected_date_args'), - [((1, 1, 1), YearBegin(), (2, 1, 1)), - ((1, 1, 1), YearBegin(n=2), (3, 1, 1)), - ((1, 1, 1), YearBegin(month=2), (1, 2, 1)), - ((1, 1, 7), YearBegin(n=2), (3, 1, 1)), - ((2, 2, 1), YearBegin(n=-1), (2, 1, 1)), - ((1, 1, 2), YearBegin(n=-1), (1, 1, 1)), - ((1, 1, 1, 5, 5, 5, 5), YearBegin(), (2, 1, 1, 5, 5, 5, 5)), - ((2, 1, 1, 5, 5, 5, 5), YearBegin(n=-1), (1, 1, 1, 5, 5, 5, 5))], - ids=_id_func -) -def test_add_year_begin(calendar, initial_date_args, offset, - expected_date_args): + ("initial_date_args", "offset", "expected_date_args"), + [ + ((1, 1, 1), YearBegin(), (2, 1, 1)), + ((1, 1, 1), YearBegin(n=2), (3, 1, 1)), + ((1, 1, 1), YearBegin(month=2), (1, 2, 1)), + ((1, 1, 7), YearBegin(n=2), (3, 1, 1)), + ((2, 2, 1), YearBegin(n=-1), (2, 1, 1)), + ((1, 1, 2), YearBegin(n=-1), (1, 1, 1)), + ((1, 1, 1, 5, 5, 5, 5), YearBegin(), (2, 1, 1, 5, 5, 5, 5)), + ((2, 1, 1, 5, 5, 5, 5), YearBegin(n=-1), (1, 1, 1, 5, 5, 5, 5)), + ], + ids=_id_func, +) +def test_add_year_begin(calendar, initial_date_args, offset, expected_date_args): date_type = get_date_type(calendar) initial = date_type(*initial_date_args) result = initial + offset @@ -542,20 +608,20 @@ def test_add_year_begin(calendar, initial_date_args, offset, @pytest.mark.parametrize( - ('initial_date_args', 'offset', 'expected_year_month', - 'expected_sub_day'), - [((1, 1, 1), YearEnd(), (1, 12), ()), - ((1, 1, 1), YearEnd(n=2), (2, 12), ()), - ((1, 1, 1), YearEnd(month=1), (1, 1), ()), - ((2, 3, 1), YearEnd(n=-1), (1, 12), ()), - ((1, 3, 1), YearEnd(n=-1, month=2), (1, 2), ()), - ((1, 1, 1, 5, 5, 5, 5), YearEnd(), (1, 12), (5, 5, 5, 5)), - ((1, 1, 1, 5, 5, 5, 5), YearEnd(n=2), (2, 12), (5, 5, 5, 5))], - ids=_id_func + ("initial_date_args", "offset", "expected_year_month", "expected_sub_day"), + [ + ((1, 1, 1), YearEnd(), (1, 12), ()), + ((1, 1, 1), YearEnd(n=2), (2, 12), ()), + ((1, 1, 1), YearEnd(month=1), (1, 1), ()), + ((2, 3, 1), YearEnd(n=-1), (1, 12), ()), + ((1, 3, 1), YearEnd(n=-1, month=2), (1, 2), ()), + ((1, 1, 1, 5, 5, 5, 5), YearEnd(), (1, 12), (5, 5, 5, 5)), + ((1, 1, 1, 5, 5, 5, 5), YearEnd(n=2), (2, 12), (5, 5, 5, 5)), + ], + ids=_id_func, ) def test_add_year_end( - calendar, initial_date_args, offset, expected_year_month, - expected_sub_day + calendar, initial_date_args, offset, expected_year_month, expected_sub_day ): date_type = get_date_type(calendar) initial = date_type(*initial_date_args) @@ -564,59 +630,74 @@ def test_add_year_end( reference = date_type(*reference_args) # Here the days at the end of each month varies based on the calendar used - expected_date_args = (expected_year_month + - (_days_in_month(reference),) + expected_sub_day) + expected_date_args = ( + expected_year_month + (_days_in_month(reference),) + expected_sub_day + ) expected = date_type(*expected_date_args) assert result == expected @pytest.mark.parametrize( - ('initial_year_month', 'initial_sub_day', 'offset', 'expected_year_month', - 'expected_sub_day'), - [((1, 12), (), YearEnd(), (2, 12), ()), - ((1, 12), (), YearEnd(n=2), (3, 12), ()), - ((2, 12), (), YearEnd(n=-1), (1, 12), ()), - ((3, 12), (), YearEnd(n=-2), (1, 12), ()), - ((1, 1), (), 
YearEnd(month=2), (1, 2), ()), - ((1, 12), (5, 5, 5, 5), YearEnd(), (2, 12), (5, 5, 5, 5)), - ((2, 12), (5, 5, 5, 5), YearEnd(n=-1), (1, 12), (5, 5, 5, 5))], - ids=_id_func + ( + "initial_year_month", + "initial_sub_day", + "offset", + "expected_year_month", + "expected_sub_day", + ), + [ + ((1, 12), (), YearEnd(), (2, 12), ()), + ((1, 12), (), YearEnd(n=2), (3, 12), ()), + ((2, 12), (), YearEnd(n=-1), (1, 12), ()), + ((3, 12), (), YearEnd(n=-2), (1, 12), ()), + ((1, 1), (), YearEnd(month=2), (1, 2), ()), + ((1, 12), (5, 5, 5, 5), YearEnd(), (2, 12), (5, 5, 5, 5)), + ((2, 12), (5, 5, 5, 5), YearEnd(n=-1), (1, 12), (5, 5, 5, 5)), + ], + ids=_id_func, ) def test_add_year_end_onOffset( - calendar, initial_year_month, initial_sub_day, offset, expected_year_month, - expected_sub_day + calendar, + initial_year_month, + initial_sub_day, + offset, + expected_year_month, + expected_sub_day, ): date_type = get_date_type(calendar) reference_args = initial_year_month + (1,) reference = date_type(*reference_args) - initial_date_args = (initial_year_month + (_days_in_month(reference),) + - initial_sub_day) + initial_date_args = ( + initial_year_month + (_days_in_month(reference),) + initial_sub_day + ) initial = date_type(*initial_date_args) result = initial + offset reference_args = expected_year_month + (1,) reference = date_type(*reference_args) # Here the days at the end of each month varies based on the calendar used - expected_date_args = (expected_year_month + - (_days_in_month(reference),) + expected_sub_day) + expected_date_args = ( + expected_year_month + (_days_in_month(reference),) + expected_sub_day + ) expected = date_type(*expected_date_args) assert result == expected @pytest.mark.parametrize( - ('initial_date_args', 'offset', 'expected_date_args'), - [((1, 1, 1), QuarterBegin(), (1, 3, 1)), - ((1, 1, 1), QuarterBegin(n=2), (1, 6, 1)), - ((1, 1, 1), QuarterBegin(month=2), (1, 2, 1)), - ((1, 1, 7), QuarterBegin(n=2), (1, 6, 1)), - ((2, 2, 1), QuarterBegin(n=-1), (1, 12, 1)), - ((1, 3, 2), QuarterBegin(n=-1), (1, 3, 1)), - ((1, 1, 1, 5, 5, 5, 5), QuarterBegin(), (1, 3, 1, 5, 5, 5, 5)), - ((2, 1, 1, 5, 5, 5, 5), QuarterBegin(n=-1), (1, 12, 1, 5, 5, 5, 5))], - ids=_id_func -) -def test_add_quarter_begin(calendar, initial_date_args, offset, - expected_date_args): + ("initial_date_args", "offset", "expected_date_args"), + [ + ((1, 1, 1), QuarterBegin(), (1, 3, 1)), + ((1, 1, 1), QuarterBegin(n=2), (1, 6, 1)), + ((1, 1, 1), QuarterBegin(month=2), (1, 2, 1)), + ((1, 1, 7), QuarterBegin(n=2), (1, 6, 1)), + ((2, 2, 1), QuarterBegin(n=-1), (1, 12, 1)), + ((1, 3, 2), QuarterBegin(n=-1), (1, 3, 1)), + ((1, 1, 1, 5, 5, 5, 5), QuarterBegin(), (1, 3, 1, 5, 5, 5, 5)), + ((2, 1, 1, 5, 5, 5, 5), QuarterBegin(n=-1), (1, 12, 1, 5, 5, 5, 5)), + ], + ids=_id_func, +) +def test_add_quarter_begin(calendar, initial_date_args, offset, expected_date_args): date_type = get_date_type(calendar) initial = date_type(*initial_date_args) result = initial + offset @@ -625,20 +706,20 @@ def test_add_quarter_begin(calendar, initial_date_args, offset, @pytest.mark.parametrize( - ('initial_date_args', 'offset', 'expected_year_month', - 'expected_sub_day'), - [((1, 1, 1), QuarterEnd(), (1, 3), ()), - ((1, 1, 1), QuarterEnd(n=2), (1, 6), ()), - ((1, 1, 1), QuarterEnd(month=1), (1, 1), ()), - ((2, 3, 1), QuarterEnd(n=-1), (1, 12), ()), - ((1, 3, 1), QuarterEnd(n=-1, month=2), (1, 2), ()), - ((1, 1, 1, 5, 5, 5, 5), QuarterEnd(), (1, 3), (5, 5, 5, 5)), - ((1, 1, 1, 5, 5, 5, 5), QuarterEnd(n=2), (1, 6), (5, 5, 5, 5))], - ids=_id_func 
+ ("initial_date_args", "offset", "expected_year_month", "expected_sub_day"), + [ + ((1, 1, 1), QuarterEnd(), (1, 3), ()), + ((1, 1, 1), QuarterEnd(n=2), (1, 6), ()), + ((1, 1, 1), QuarterEnd(month=1), (1, 1), ()), + ((2, 3, 1), QuarterEnd(n=-1), (1, 12), ()), + ((1, 3, 1), QuarterEnd(n=-1, month=2), (1, 2), ()), + ((1, 1, 1, 5, 5, 5, 5), QuarterEnd(), (1, 3), (5, 5, 5, 5)), + ((1, 1, 1, 5, 5, 5, 5), QuarterEnd(n=2), (1, 6), (5, 5, 5, 5)), + ], + ids=_id_func, ) def test_add_quarter_end( - calendar, initial_date_args, offset, expected_year_month, - expected_sub_day + calendar, initial_date_args, offset, expected_year_month, expected_sub_day ): date_type = get_date_type(calendar) initial = date_type(*initial_date_args) @@ -647,66 +728,82 @@ def test_add_quarter_end( reference = date_type(*reference_args) # Here the days at the end of each month varies based on the calendar used - expected_date_args = (expected_year_month + - (_days_in_month(reference),) + expected_sub_day) + expected_date_args = ( + expected_year_month + (_days_in_month(reference),) + expected_sub_day + ) expected = date_type(*expected_date_args) assert result == expected @pytest.mark.parametrize( - ('initial_year_month', 'initial_sub_day', 'offset', 'expected_year_month', - 'expected_sub_day'), - [((1, 12), (), QuarterEnd(), (2, 3), ()), - ((1, 12), (), QuarterEnd(n=2), (2, 6), ()), - ((1, 12), (), QuarterEnd(n=-1), (1, 9), ()), - ((1, 12), (), QuarterEnd(n=-2), (1, 6), ()), - ((1, 1), (), QuarterEnd(month=2), (1, 2), ()), - ((1, 12), (5, 5, 5, 5), QuarterEnd(), (2, 3), (5, 5, 5, 5)), - ((1, 12), (5, 5, 5, 5), QuarterEnd(n=-1), (1, 9), (5, 5, 5, 5))], - ids=_id_func + ( + "initial_year_month", + "initial_sub_day", + "offset", + "expected_year_month", + "expected_sub_day", + ), + [ + ((1, 12), (), QuarterEnd(), (2, 3), ()), + ((1, 12), (), QuarterEnd(n=2), (2, 6), ()), + ((1, 12), (), QuarterEnd(n=-1), (1, 9), ()), + ((1, 12), (), QuarterEnd(n=-2), (1, 6), ()), + ((1, 1), (), QuarterEnd(month=2), (1, 2), ()), + ((1, 12), (5, 5, 5, 5), QuarterEnd(), (2, 3), (5, 5, 5, 5)), + ((1, 12), (5, 5, 5, 5), QuarterEnd(n=-1), (1, 9), (5, 5, 5, 5)), + ], + ids=_id_func, ) def test_add_quarter_end_onOffset( - calendar, initial_year_month, initial_sub_day, offset, expected_year_month, - expected_sub_day + calendar, + initial_year_month, + initial_sub_day, + offset, + expected_year_month, + expected_sub_day, ): date_type = get_date_type(calendar) reference_args = initial_year_month + (1,) reference = date_type(*reference_args) - initial_date_args = (initial_year_month + (_days_in_month(reference),) + - initial_sub_day) + initial_date_args = ( + initial_year_month + (_days_in_month(reference),) + initial_sub_day + ) initial = date_type(*initial_date_args) result = initial + offset reference_args = expected_year_month + (1,) reference = date_type(*reference_args) # Here the days at the end of each month varies based on the calendar used - expected_date_args = (expected_year_month + - (_days_in_month(reference),) + expected_sub_day) + expected_date_args = ( + expected_year_month + (_days_in_month(reference),) + expected_sub_day + ) expected = date_type(*expected_date_args) assert result == expected # Note for all sub-monthly offsets, pandas always returns True for onOffset @pytest.mark.parametrize( - ('date_args', 'offset', 'expected'), - [((1, 1, 1), MonthBegin(), True), - ((1, 1, 1, 1), MonthBegin(), True), - ((1, 1, 5), MonthBegin(), False), - ((1, 1, 5), MonthEnd(), False), - ((1, 3, 1), QuarterBegin(), True), - ((1, 3, 1, 1), 
QuarterBegin(), True), - ((1, 3, 5), QuarterBegin(), False), - ((1, 12, 1), QuarterEnd(), False), - ((1, 1, 1), YearBegin(), True), - ((1, 1, 1, 1), YearBegin(), True), - ((1, 1, 5), YearBegin(), False), - ((1, 12, 1), YearEnd(), False), - ((1, 1, 1), Day(), True), - ((1, 1, 1, 1), Day(), True), - ((1, 1, 1), Hour(), True), - ((1, 1, 1), Minute(), True), - ((1, 1, 1), Second(), True)], - ids=_id_func + ("date_args", "offset", "expected"), + [ + ((1, 1, 1), MonthBegin(), True), + ((1, 1, 1, 1), MonthBegin(), True), + ((1, 1, 5), MonthBegin(), False), + ((1, 1, 5), MonthEnd(), False), + ((1, 3, 1), QuarterBegin(), True), + ((1, 3, 1, 1), QuarterBegin(), True), + ((1, 3, 5), QuarterBegin(), False), + ((1, 12, 1), QuarterEnd(), False), + ((1, 1, 1), YearBegin(), True), + ((1, 1, 1, 1), YearBegin(), True), + ((1, 1, 5), YearBegin(), False), + ((1, 12, 1), YearEnd(), False), + ((1, 1, 1), Day(), True), + ((1, 1, 1, 1), Day(), True), + ((1, 1, 1), Hour(), True), + ((1, 1, 1), Minute(), True), + ((1, 1, 1), Second(), True), + ], + ids=_id_func, ) def test_onOffset(calendar, date_args, offset, expected): date_type = get_date_type(calendar) @@ -716,59 +813,62 @@ def test_onOffset(calendar, date_args, offset, expected): @pytest.mark.parametrize( - ('year_month_args', 'sub_day_args', 'offset'), - [((1, 1), (), MonthEnd()), - ((1, 1), (1,), MonthEnd()), - ((1, 12), (), QuarterEnd()), - ((1, 1), (), QuarterEnd(month=1)), - ((1, 12), (), YearEnd()), - ((1, 1), (), YearEnd(month=1))], - ids=_id_func + ("year_month_args", "sub_day_args", "offset"), + [ + ((1, 1), (), MonthEnd()), + ((1, 1), (1,), MonthEnd()), + ((1, 12), (), QuarterEnd()), + ((1, 1), (), QuarterEnd(month=1)), + ((1, 12), (), YearEnd()), + ((1, 1), (), YearEnd(month=1)), + ], + ids=_id_func, ) def test_onOffset_month_or_quarter_or_year_end( - calendar, year_month_args, sub_day_args, offset): + calendar, year_month_args, sub_day_args, offset +): date_type = get_date_type(calendar) reference_args = year_month_args + (1,) reference = date_type(*reference_args) - date_args = (year_month_args + (_days_in_month(reference),) + - sub_day_args) + date_args = year_month_args + (_days_in_month(reference),) + sub_day_args date = date_type(*date_args) result = offset.onOffset(date) assert result @pytest.mark.parametrize( - ('offset', 'initial_date_args', 'partial_expected_date_args'), - [(YearBegin(), (1, 3, 1), (2, 1)), - (YearBegin(), (1, 1, 1), (1, 1)), - (YearBegin(n=2), (1, 3, 1), (2, 1)), - (YearBegin(n=2, month=2), (1, 3, 1), (2, 2)), - (YearEnd(), (1, 3, 1), (1, 12)), - (YearEnd(n=2), (1, 3, 1), (1, 12)), - (YearEnd(n=2, month=2), (1, 3, 1), (2, 2)), - (YearEnd(n=2, month=4), (1, 4, 30), (1, 4)), - (QuarterBegin(), (1, 3, 2), (1, 6)), - (QuarterBegin(), (1, 4, 1), (1, 6)), - (QuarterBegin(n=2), (1, 4, 1), (1, 6)), - (QuarterBegin(n=2, month=2), (1, 4, 1), (1, 5)), - (QuarterEnd(), (1, 3, 1), (1, 3)), - (QuarterEnd(n=2), (1, 3, 1), (1, 3)), - (QuarterEnd(n=2, month=2), (1, 3, 1), (1, 5)), - (QuarterEnd(n=2, month=4), (1, 4, 30), (1, 4)), - (MonthBegin(), (1, 3, 2), (1, 4)), - (MonthBegin(), (1, 3, 1), (1, 3)), - (MonthBegin(n=2), (1, 3, 2), (1, 4)), - (MonthEnd(), (1, 3, 2), (1, 3)), - (MonthEnd(), (1, 4, 30), (1, 4)), - (MonthEnd(n=2), (1, 3, 2), (1, 3)), - (Day(), (1, 3, 2, 1), (1, 3, 2, 1)), - (Hour(), (1, 3, 2, 1, 1), (1, 3, 2, 1, 1)), - (Minute(), (1, 3, 2, 1, 1, 1), (1, 3, 2, 1, 1, 1)), - (Second(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1))], - ids=_id_func -) -def test_rollforward(calendar, offset, initial_date_args, - 
partial_expected_date_args): + ("offset", "initial_date_args", "partial_expected_date_args"), + [ + (YearBegin(), (1, 3, 1), (2, 1)), + (YearBegin(), (1, 1, 1), (1, 1)), + (YearBegin(n=2), (1, 3, 1), (2, 1)), + (YearBegin(n=2, month=2), (1, 3, 1), (2, 2)), + (YearEnd(), (1, 3, 1), (1, 12)), + (YearEnd(n=2), (1, 3, 1), (1, 12)), + (YearEnd(n=2, month=2), (1, 3, 1), (2, 2)), + (YearEnd(n=2, month=4), (1, 4, 30), (1, 4)), + (QuarterBegin(), (1, 3, 2), (1, 6)), + (QuarterBegin(), (1, 4, 1), (1, 6)), + (QuarterBegin(n=2), (1, 4, 1), (1, 6)), + (QuarterBegin(n=2, month=2), (1, 4, 1), (1, 5)), + (QuarterEnd(), (1, 3, 1), (1, 3)), + (QuarterEnd(n=2), (1, 3, 1), (1, 3)), + (QuarterEnd(n=2, month=2), (1, 3, 1), (1, 5)), + (QuarterEnd(n=2, month=4), (1, 4, 30), (1, 4)), + (MonthBegin(), (1, 3, 2), (1, 4)), + (MonthBegin(), (1, 3, 1), (1, 3)), + (MonthBegin(n=2), (1, 3, 2), (1, 4)), + (MonthEnd(), (1, 3, 2), (1, 3)), + (MonthEnd(), (1, 4, 30), (1, 4)), + (MonthEnd(n=2), (1, 3, 2), (1, 3)), + (Day(), (1, 3, 2, 1), (1, 3, 2, 1)), + (Hour(), (1, 3, 2, 1, 1), (1, 3, 2, 1, 1)), + (Minute(), (1, 3, 2, 1, 1, 1), (1, 3, 2, 1, 1, 1)), + (Second(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1)), + ], + ids=_id_func, +) +def test_rollforward(calendar, offset, initial_date_args, partial_expected_date_args): date_type = get_date_type(calendar) initial = date_type(*initial_date_args) if isinstance(offset, (MonthBegin, QuarterBegin, YearBegin)): @@ -776,8 +876,7 @@ def test_rollforward(calendar, offset, initial_date_args, elif isinstance(offset, (MonthEnd, QuarterEnd, YearEnd)): reference_args = partial_expected_date_args + (1,) reference = date_type(*reference_args) - expected_date_args = (partial_expected_date_args + - (_days_in_month(reference),)) + expected_date_args = partial_expected_date_args + (_days_in_month(reference),) else: expected_date_args = partial_expected_date_args expected = date_type(*expected_date_args) @@ -786,38 +885,39 @@ def test_rollforward(calendar, offset, initial_date_args, @pytest.mark.parametrize( - ('offset', 'initial_date_args', 'partial_expected_date_args'), - [(YearBegin(), (1, 3, 1), (1, 1)), - (YearBegin(n=2), (1, 3, 1), (1, 1)), - (YearBegin(n=2, month=2), (1, 3, 1), (1, 2)), - (YearBegin(), (1, 1, 1), (1, 1)), - (YearBegin(n=2, month=2), (1, 2, 1), (1, 2)), - (YearEnd(), (2, 3, 1), (1, 12)), - (YearEnd(n=2), (2, 3, 1), (1, 12)), - (YearEnd(n=2, month=2), (2, 3, 1), (2, 2)), - (YearEnd(month=4), (1, 4, 30), (1, 4)), - (QuarterBegin(), (1, 3, 2), (1, 3)), - (QuarterBegin(), (1, 4, 1), (1, 3)), - (QuarterBegin(n=2), (1, 4, 1), (1, 3)), - (QuarterBegin(n=2, month=2), (1, 4, 1), (1, 2)), - (QuarterEnd(), (2, 3, 1), (1, 12)), - (QuarterEnd(n=2), (2, 3, 1), (1, 12)), - (QuarterEnd(n=2, month=2), (2, 3, 1), (2, 2)), - (QuarterEnd(n=2, month=4), (1, 4, 30), (1, 4)), - (MonthBegin(), (1, 3, 2), (1, 3)), - (MonthBegin(n=2), (1, 3, 2), (1, 3)), - (MonthBegin(), (1, 3, 1), (1, 3)), - (MonthEnd(), (1, 3, 2), (1, 2)), - (MonthEnd(n=2), (1, 3, 2), (1, 2)), - (MonthEnd(), (1, 4, 30), (1, 4)), - (Day(), (1, 3, 2, 1), (1, 3, 2, 1)), - (Hour(), (1, 3, 2, 1, 1), (1, 3, 2, 1, 1)), - (Minute(), (1, 3, 2, 1, 1, 1), (1, 3, 2, 1, 1, 1)), - (Second(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1))], - ids=_id_func -) -def test_rollback(calendar, offset, initial_date_args, - partial_expected_date_args): + ("offset", "initial_date_args", "partial_expected_date_args"), + [ + (YearBegin(), (1, 3, 1), (1, 1)), + (YearBegin(n=2), (1, 3, 1), (1, 1)), + (YearBegin(n=2, month=2), (1, 3, 1), (1, 2)), + (YearBegin(), 
(1, 1, 1), (1, 1)), + (YearBegin(n=2, month=2), (1, 2, 1), (1, 2)), + (YearEnd(), (2, 3, 1), (1, 12)), + (YearEnd(n=2), (2, 3, 1), (1, 12)), + (YearEnd(n=2, month=2), (2, 3, 1), (2, 2)), + (YearEnd(month=4), (1, 4, 30), (1, 4)), + (QuarterBegin(), (1, 3, 2), (1, 3)), + (QuarterBegin(), (1, 4, 1), (1, 3)), + (QuarterBegin(n=2), (1, 4, 1), (1, 3)), + (QuarterBegin(n=2, month=2), (1, 4, 1), (1, 2)), + (QuarterEnd(), (2, 3, 1), (1, 12)), + (QuarterEnd(n=2), (2, 3, 1), (1, 12)), + (QuarterEnd(n=2, month=2), (2, 3, 1), (2, 2)), + (QuarterEnd(n=2, month=4), (1, 4, 30), (1, 4)), + (MonthBegin(), (1, 3, 2), (1, 3)), + (MonthBegin(n=2), (1, 3, 2), (1, 3)), + (MonthBegin(), (1, 3, 1), (1, 3)), + (MonthEnd(), (1, 3, 2), (1, 2)), + (MonthEnd(n=2), (1, 3, 2), (1, 2)), + (MonthEnd(), (1, 4, 30), (1, 4)), + (Day(), (1, 3, 2, 1), (1, 3, 2, 1)), + (Hour(), (1, 3, 2, 1, 1), (1, 3, 2, 1, 1)), + (Minute(), (1, 3, 2, 1, 1, 1), (1, 3, 2, 1, 1, 1)), + (Second(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1)), + ], + ids=_id_func, +) +def test_rollback(calendar, offset, initial_date_args, partial_expected_date_args): date_type = get_date_type(calendar) initial = date_type(*initial_date_args) if isinstance(offset, (MonthBegin, QuarterBegin, YearBegin)): @@ -825,8 +925,7 @@ def test_rollback(calendar, offset, initial_date_args, elif isinstance(offset, (MonthEnd, QuarterEnd, YearEnd)): reference_args = partial_expected_date_args + (1,) reference = date_type(*reference_args) - expected_date_args = (partial_expected_date_args + - (_days_in_month(reference),)) + expected_date_args = partial_expected_date_args + (_days_in_month(reference),) else: expected_date_args = partial_expected_date_args expected = date_type(*expected_date_args) @@ -835,45 +934,135 @@ def test_rollback(calendar, offset, initial_date_args, _CFTIME_RANGE_TESTS = [ - ('0001-01-01', '0001-01-04', None, 'D', None, False, - [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), - ('0001-01-01', '0001-01-04', None, 'D', 'left', False, - [(1, 1, 1), (1, 1, 2), (1, 1, 3)]), - ('0001-01-01', '0001-01-04', None, 'D', 'right', False, - [(1, 1, 2), (1, 1, 3), (1, 1, 4)]), - ('0001-01-01T01:00:00', '0001-01-04', None, 'D', None, False, - [(1, 1, 1, 1), (1, 1, 2, 1), (1, 1, 3, 1)]), - ('0001-01-01T01:00:00', '0001-01-04', None, 'D', None, True, - [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), - ('0001-01-01', None, 4, 'D', None, False, - [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), - (None, '0001-01-04', 4, 'D', None, False, - [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), - ((1, 1, 1), '0001-01-04', None, 'D', None, False, - [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), - ((1, 1, 1), (1, 1, 4), None, 'D', None, False, - [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), - ('0001-01-30', '0011-02-01', None, '3AS-JUN', None, False, - [(1, 6, 1), (4, 6, 1), (7, 6, 1), (10, 6, 1)]), - ('0001-01-04', '0001-01-01', None, 'D', None, False, - []), - ('0010', None, 4, YearBegin(n=-2), None, False, - [(10, 1, 1), (8, 1, 1), (6, 1, 1), (4, 1, 1)]), - ('0001-01-01', '0001-01-04', 4, None, None, False, - [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), - ('0001-06-01', None, 4, '3QS-JUN', None, False, - [(1, 6, 1), (2, 3, 1), (2, 12, 1), (3, 9, 1)]) + ( + "0001-01-01", + "0001-01-04", + None, + "D", + None, + False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)], + ), + ( + "0001-01-01", + "0001-01-04", + None, + "D", + "left", + False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3)], + ), + ( + "0001-01-01", + "0001-01-04", + None, + "D", + "right", + False, + [(1, 1, 2), (1, 1, 3), (1, 1, 4)], + 
), + ( + "0001-01-01T01:00:00", + "0001-01-04", + None, + "D", + None, + False, + [(1, 1, 1, 1), (1, 1, 2, 1), (1, 1, 3, 1)], + ), + ( + "0001-01-01T01:00:00", + "0001-01-04", + None, + "D", + None, + True, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)], + ), + ( + "0001-01-01", + None, + 4, + "D", + None, + False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)], + ), + ( + None, + "0001-01-04", + 4, + "D", + None, + False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)], + ), + ( + (1, 1, 1), + "0001-01-04", + None, + "D", + None, + False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)], + ), + ( + (1, 1, 1), + (1, 1, 4), + None, + "D", + None, + False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)], + ), + ( + "0001-01-30", + "0011-02-01", + None, + "3AS-JUN", + None, + False, + [(1, 6, 1), (4, 6, 1), (7, 6, 1), (10, 6, 1)], + ), + ("0001-01-04", "0001-01-01", None, "D", None, False, []), + ( + "0010", + None, + 4, + YearBegin(n=-2), + None, + False, + [(10, 1, 1), (8, 1, 1), (6, 1, 1), (4, 1, 1)], + ), + ( + "0001-01-01", + "0001-01-04", + 4, + None, + None, + False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)], + ), + ( + "0001-06-01", + None, + 4, + "3QS-JUN", + None, + False, + [(1, 6, 1), (2, 3, 1), (2, 12, 1), (3, 9, 1)], + ), ] @pytest.mark.parametrize( - ('start', 'end', 'periods', 'freq', 'closed', 'normalize', - 'expected_date_args'), - _CFTIME_RANGE_TESTS, ids=_id_func + ("start", "end", "periods", "freq", "closed", "normalize", "expected_date_args"), + _CFTIME_RANGE_TESTS, + ids=_id_func, ) def test_cftime_range( - start, end, periods, freq, closed, normalize, calendar, - expected_date_args): + start, end, periods, freq, closed, normalize, calendar, expected_date_args +): date_type = get_date_type(calendar) expected_dates = [date_type(*args) for args in expected_date_args] @@ -883,8 +1072,14 @@ def test_cftime_range( end = date_type(*end) result = cftime_range( - start=start, end=end, periods=periods, freq=freq, closed=closed, - normalize=normalize, calendar=calendar) + start=start, + end=end, + periods=periods, + freq=freq, + closed=closed, + normalize=normalize, + calendar=calendar, + ) resulting_dates = result.values assert isinstance(result, CFTimeIndex) @@ -902,22 +1097,24 @@ def test_cftime_range( def test_cftime_range_name(): - result = cftime_range(start='2000', periods=4, name='foo') - assert result.name == 'foo' + result = cftime_range(start="2000", periods=4, name="foo") + assert result.name == "foo" - result = cftime_range(start='2000', periods=4) + result = cftime_range(start="2000", periods=4) assert result.name is None @pytest.mark.parametrize( - ('start', 'end', 'periods', 'freq', 'closed'), - [(None, None, 5, 'A', None), - ('2000', None, None, 'A', None), - (None, '2000', None, 'A', None), - ('2000', '2001', None, None, None), - (None, None, None, None, None), - ('2000', '2001', None, 'A', 'up'), - ('2000', '2001', 5, 'A', None)] + ("start", "end", "periods", "freq", "closed"), + [ + (None, None, 5, "A", None), + ("2000", None, None, "A", None), + (None, "2000", None, "A", None), + ("2000", "2001", None, None, None), + (None, None, None, None, None), + ("2000", "2001", None, "A", "up"), + ("2000", "2001", 5, "A", None), + ], ) def test_invalid_cftime_range_inputs(start, end, periods, freq, closed): with pytest.raises(ValueError): @@ -925,73 +1122,70 @@ def test_invalid_cftime_range_inputs(start, end, periods, freq, closed): _CALENDAR_SPECIFIC_MONTH_END_TESTS = [ - ('2M', 'noleap', - [(2, 28), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), - ('2M', 
'all_leap', - [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), - ('2M', '360_day', - [(2, 30), (4, 30), (6, 30), (8, 30), (10, 30), (12, 30)]), - ('2M', 'standard', - [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), - ('2M', 'gregorian', - [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), - ('2M', 'julian', - [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]) + ("2M", "noleap", [(2, 28), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ("2M", "all_leap", [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ("2M", "360_day", [(2, 30), (4, 30), (6, 30), (8, 30), (10, 30), (12, 30)]), + ("2M", "standard", [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ("2M", "gregorian", [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ("2M", "julian", [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), ] @pytest.mark.parametrize( - ('freq', 'calendar', 'expected_month_day'), - _CALENDAR_SPECIFIC_MONTH_END_TESTS, ids=_id_func + ("freq", "calendar", "expected_month_day"), + _CALENDAR_SPECIFIC_MONTH_END_TESTS, + ids=_id_func, ) def test_calendar_specific_month_end(freq, calendar, expected_month_day): year = 2000 # Use a leap-year to highlight calendar differences result = cftime_range( - start='2000-02', end='2001', freq=freq, calendar=calendar).values + start="2000-02", end="2001", freq=freq, calendar=calendar + ).values date_type = get_date_type(calendar) expected = [date_type(year, *args) for args in expected_month_day] np.testing.assert_equal(result, expected) @pytest.mark.parametrize( - ('calendar', 'start', 'end', 'expected_number_of_days'), - [('noleap', '2000', '2001', 365), - ('all_leap', '2000', '2001', 366), - ('360_day', '2000', '2001', 360), - ('standard', '2000', '2001', 366), - ('gregorian', '2000', '2001', 366), - ('julian', '2000', '2001', 366), - ('noleap', '2001', '2002', 365), - ('all_leap', '2001', '2002', 366), - ('360_day', '2001', '2002', 360), - ('standard', '2001', '2002', 365), - ('gregorian', '2001', '2002', 365), - ('julian', '2001', '2002', 365)] -) -def test_calendar_year_length( - calendar, start, end, expected_number_of_days): - result = cftime_range(start, end, freq='D', closed='left', - calendar=calendar) + ("calendar", "start", "end", "expected_number_of_days"), + [ + ("noleap", "2000", "2001", 365), + ("all_leap", "2000", "2001", 366), + ("360_day", "2000", "2001", 360), + ("standard", "2000", "2001", 366), + ("gregorian", "2000", "2001", 366), + ("julian", "2000", "2001", 366), + ("noleap", "2001", "2002", 365), + ("all_leap", "2001", "2002", 366), + ("360_day", "2001", "2002", 360), + ("standard", "2001", "2002", 365), + ("gregorian", "2001", "2002", 365), + ("julian", "2001", "2002", 365), + ], +) +def test_calendar_year_length(calendar, start, end, expected_number_of_days): + result = cftime_range(start, end, freq="D", closed="left", calendar=calendar) assert len(result) == expected_number_of_days -@pytest.mark.parametrize('freq', ['A', 'M', 'D']) +@pytest.mark.parametrize("freq", ["A", "M", "D"]) def test_dayofweek_after_cftime_range(freq): - pytest.importorskip('cftime', minversion='1.0.2.1') - result = cftime_range('2000-02-01', periods=3, freq=freq).dayofweek - expected = pd.date_range('2000-02-01', periods=3, freq=freq).dayofweek + pytest.importorskip("cftime", minversion="1.0.2.1") + result = cftime_range("2000-02-01", periods=3, freq=freq).dayofweek + expected = pd.date_range("2000-02-01", periods=3, freq=freq).dayofweek np.testing.assert_array_equal(result, expected) 
-@pytest.mark.parametrize('freq', ['A', 'M', 'D']) +@pytest.mark.parametrize("freq", ["A", "M", "D"]) def test_dayofyear_after_cftime_range(freq): - pytest.importorskip('cftime', minversion='1.0.2.1') - result = cftime_range('2000-02-01', periods=3, freq=freq).dayofyear - expected = pd.date_range('2000-02-01', periods=3, freq=freq).dayofyear + pytest.importorskip("cftime", minversion="1.0.2.1") + result = cftime_range("2000-02-01", periods=3, freq=freq).dayofyear + expected = pd.date_range("2000-02-01", periods=3, freq=freq).dayofyear np.testing.assert_array_equal(result, expected) def test_cftime_range_standard_calendar_refers_to_gregorian(): from cftime import DatetimeGregorian - result, = cftime_range('2000', periods=1) + + result, = cftime_range("2000", periods=1) assert isinstance(result, DatetimeGregorian) diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index 56c01fbdc28..fcc9acf75bb 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -6,60 +6,92 @@ import xarray as xr from xarray.coding.cftimeindex import ( - CFTimeIndex, _parse_array_of_cftime_strings, _parse_iso8601_with_reso, - _parsed_string_to_bounds, assert_all_valid_date_type, parse_iso8601) + CFTimeIndex, + _parse_array_of_cftime_strings, + _parse_iso8601_with_reso, + _parsed_string_to_bounds, + assert_all_valid_date_type, + parse_iso8601, +) from xarray.tests import assert_array_equal, assert_identical from . import ( - has_cftime, has_cftime_1_0_2_1, has_cftime_or_netCDF4, raises_regex, - requires_cftime) + has_cftime, + has_cftime_1_0_2_1, + has_cftime_or_netCDF4, + raises_regex, + requires_cftime, +) from .test_coding_times import ( - _ALL_CALENDARS, _NON_STANDARD_CALENDARS, _all_cftime_date_types) + _ALL_CALENDARS, + _NON_STANDARD_CALENDARS, + _all_cftime_date_types, +) -def date_dict(year=None, month=None, day=None, - hour=None, minute=None, second=None): - return dict(year=year, month=month, day=day, hour=hour, - minute=minute, second=second) +def date_dict(year=None, month=None, day=None, hour=None, minute=None, second=None): + return dict( + year=year, month=month, day=day, hour=hour, minute=minute, second=second + ) ISO8601_STRING_TESTS = { - 'year': ('1999', date_dict(year='1999')), - 'month': ('199901', date_dict(year='1999', month='01')), - 'month-dash': ('1999-01', date_dict(year='1999', month='01')), - 'day': ('19990101', date_dict(year='1999', month='01', day='01')), - 'day-dash': ('1999-01-01', date_dict(year='1999', month='01', day='01')), - 'hour': ('19990101T12', date_dict( - year='1999', month='01', day='01', hour='12')), - 'hour-dash': ('1999-01-01T12', date_dict( - year='1999', month='01', day='01', hour='12')), - 'minute': ('19990101T1234', date_dict( - year='1999', month='01', day='01', hour='12', minute='34')), - 'minute-dash': ('1999-01-01T12:34', date_dict( - year='1999', month='01', day='01', hour='12', minute='34')), - 'second': ('19990101T123456', date_dict( - year='1999', month='01', day='01', hour='12', minute='34', - second='56')), - 'second-dash': ('1999-01-01T12:34:56', date_dict( - year='1999', month='01', day='01', hour='12', minute='34', - second='56')) + "year": ("1999", date_dict(year="1999")), + "month": ("199901", date_dict(year="1999", month="01")), + "month-dash": ("1999-01", date_dict(year="1999", month="01")), + "day": ("19990101", date_dict(year="1999", month="01", day="01")), + "day-dash": ("1999-01-01", date_dict(year="1999", month="01", day="01")), + "hour": ("19990101T12", date_dict(year="1999", 
month="01", day="01", hour="12")), + "hour-dash": ( + "1999-01-01T12", + date_dict(year="1999", month="01", day="01", hour="12"), + ), + "minute": ( + "19990101T1234", + date_dict(year="1999", month="01", day="01", hour="12", minute="34"), + ), + "minute-dash": ( + "1999-01-01T12:34", + date_dict(year="1999", month="01", day="01", hour="12", minute="34"), + ), + "second": ( + "19990101T123456", + date_dict( + year="1999", month="01", day="01", hour="12", minute="34", second="56" + ), + ), + "second-dash": ( + "1999-01-01T12:34:56", + date_dict( + year="1999", month="01", day="01", hour="12", minute="34", second="56" + ), + ), } -@pytest.mark.parametrize(('string', 'expected'), - list(ISO8601_STRING_TESTS.values()), - ids=list(ISO8601_STRING_TESTS.keys())) +@pytest.mark.parametrize( + ("string", "expected"), + list(ISO8601_STRING_TESTS.values()), + ids=list(ISO8601_STRING_TESTS.keys()), +) def test_parse_iso8601(string, expected): result = parse_iso8601(string) assert result == expected with pytest.raises(ValueError): - parse_iso8601(string + '3') - parse_iso8601(string + '.3') + parse_iso8601(string + "3") + parse_iso8601(string + ".3") -_CFTIME_CALENDARS = ['365_day', '360_day', 'julian', 'all_leap', - '366_day', 'gregorian', 'proleptic_gregorian'] +_CFTIME_CALENDARS = [ + "365_day", + "360_day", + "julian", + "all_leap", + "366_day", + "gregorian", + "proleptic_gregorian", +] @pytest.fixture(params=_CFTIME_CALENDARS) @@ -69,15 +101,23 @@ def date_type(request): @pytest.fixture def index(date_type): - dates = [date_type(1, 1, 1), date_type(1, 2, 1), - date_type(2, 1, 1), date_type(2, 2, 1)] + dates = [ + date_type(1, 1, 1), + date_type(1, 2, 1), + date_type(2, 1, 1), + date_type(2, 2, 1), + ] return CFTimeIndex(dates) @pytest.fixture def monotonic_decreasing_index(date_type): - dates = [date_type(2, 2, 1), date_type(2, 1, 1), - date_type(1, 2, 1), date_type(1, 1, 1)] + dates = [ + date_type(2, 2, 1), + date_type(2, 1, 1), + date_type(1, 2, 1), + date_type(1, 1, 1), + ] return CFTimeIndex(dates) @@ -89,8 +129,7 @@ def length_one_index(date_type): @pytest.fixture def da(index): - return xr.DataArray([1, 2, 3, 4], coords=[index], - dims=['time']) + return xr.DataArray([1, 2, 3, 4], coords=[index], dims=["time"]) @pytest.fixture @@ -106,6 +145,7 @@ def df(index): @pytest.fixture def feb_days(date_type): import cftime + if date_type is cftime.DatetimeAllLeap: return 29 elif date_type is cftime.Datetime360Day: @@ -117,6 +157,7 @@ def feb_days(date_type): @pytest.fixture def dec_days(date_type): import cftime + if date_type is cftime.Datetime360Day: return 30 else: @@ -125,80 +166,87 @@ def dec_days(date_type): @pytest.fixture def index_with_name(date_type): - dates = [date_type(1, 1, 1), date_type(1, 2, 1), - date_type(2, 1, 1), date_type(2, 2, 1)] - return CFTimeIndex(dates, name='foo') + dates = [ + date_type(1, 1, 1), + date_type(1, 2, 1), + date_type(2, 1, 1), + date_type(2, 2, 1), + ] + return CFTimeIndex(dates, name="foo") -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize( - ('name', 'expected_name'), - [('bar', 'bar'), - (None, 'foo')]) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize(("name", "expected_name"), [("bar", "bar"), (None, "foo")]) def test_constructor_with_name(index_with_name, name, expected_name): result = CFTimeIndex(index_with_name, name=name).name assert result == expected_name -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, 
reason="cftime not installed") def test_assert_all_valid_date_type(date_type, index): import cftime + if date_type is cftime.DatetimeNoLeap: mixed_date_types = np.array( - [date_type(1, 1, 1), - cftime.DatetimeAllLeap(1, 2, 1)]) + [date_type(1, 1, 1), cftime.DatetimeAllLeap(1, 2, 1)] + ) else: mixed_date_types = np.array( - [date_type(1, 1, 1), - cftime.DatetimeNoLeap(1, 2, 1)]) + [date_type(1, 1, 1), cftime.DatetimeNoLeap(1, 2, 1)] + ) with pytest.raises(TypeError): assert_all_valid_date_type(mixed_date_types) with pytest.raises(TypeError): assert_all_valid_date_type(np.array([1, date_type(1, 1, 1)])) - assert_all_valid_date_type( - np.array([date_type(1, 1, 1), date_type(1, 2, 1)])) + assert_all_valid_date_type(np.array([date_type(1, 1, 1), date_type(1, 2, 1)])) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize(('field', 'expected'), [ - ('year', [1, 1, 2, 2]), - ('month', [1, 2, 1, 2]), - ('day', [1, 1, 1, 1]), - ('hour', [0, 0, 0, 0]), - ('minute', [0, 0, 0, 0]), - ('second', [0, 0, 0, 0]), - ('microsecond', [0, 0, 0, 0])]) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize( + ("field", "expected"), + [ + ("year", [1, 1, 2, 2]), + ("month", [1, 2, 1, 2]), + ("day", [1, 1, 1, 1]), + ("hour", [0, 0, 0, 0]), + ("minute", [0, 0, 0, 0]), + ("second", [0, 0, 0, 0]), + ("microsecond", [0, 0, 0, 0]), + ], +) def test_cftimeindex_field_accessors(index, field, expected): result = getattr(index, field) assert_array_equal(result, expected) -@pytest.mark.skipif(not has_cftime_1_0_2_1, - reason='cftime not installed') +@pytest.mark.skipif(not has_cftime_1_0_2_1, reason="cftime not installed") def test_cftimeindex_dayofyear_accessor(index): result = index.dayofyear expected = [date.dayofyr for date in index] assert_array_equal(result, expected) -@pytest.mark.skipif(not has_cftime_1_0_2_1, - reason='cftime not installed') +@pytest.mark.skipif(not has_cftime_1_0_2_1, reason="cftime not installed") def test_cftimeindex_dayofweek_accessor(index): result = index.dayofweek expected = [date.dayofwk for date in index] assert_array_equal(result, expected) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize(('string', 'date_args', 'reso'), [ - ('1999', (1999, 1, 1), 'year'), - ('199902', (1999, 2, 1), 'month'), - ('19990202', (1999, 2, 2), 'day'), - ('19990202T01', (1999, 2, 2, 1), 'hour'), - ('19990202T0101', (1999, 2, 2, 1, 1), 'minute'), - ('19990202T010156', (1999, 2, 2, 1, 1, 56), 'second')]) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize( + ("string", "date_args", "reso"), + [ + ("1999", (1999, 1, 1), "year"), + ("199902", (1999, 2, 1), "month"), + ("19990202", (1999, 2, 2), "day"), + ("19990202T01", (1999, 2, 2, 1), "hour"), + ("19990202T0101", (1999, 2, 2, 1, 1), "minute"), + ("19990202T010156", (1999, 2, 2, 1, 1, 56), "second"), + ], +) def test_parse_iso8601_with_reso(date_type, string, date_args, reso): expected_date = date_type(*date_args) expected_reso = reso @@ -207,193 +255,193 @@ def test_parse_iso8601_with_reso(date_type, string, date_args, reso): assert result_reso == expected_reso -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_parse_string_to_bounds_year(date_type, dec_days): parsed = date_type(2, 2, 10, 6, 2, 8, 1) expected_start = date_type(2, 1, 1) expected_end = date_type(2, 12, dec_days, 23, 59, 59, 999999) - 
result_start, result_end = _parsed_string_to_bounds( - date_type, 'year', parsed) + result_start, result_end = _parsed_string_to_bounds(date_type, "year", parsed) assert result_start == expected_start assert result_end == expected_end -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_parse_string_to_bounds_month_feb(date_type, feb_days): parsed = date_type(2, 2, 10, 6, 2, 8, 1) expected_start = date_type(2, 2, 1) expected_end = date_type(2, 2, feb_days, 23, 59, 59, 999999) - result_start, result_end = _parsed_string_to_bounds( - date_type, 'month', parsed) + result_start, result_end = _parsed_string_to_bounds(date_type, "month", parsed) assert result_start == expected_start assert result_end == expected_end -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_parse_string_to_bounds_month_dec(date_type, dec_days): parsed = date_type(2, 12, 1) expected_start = date_type(2, 12, 1) expected_end = date_type(2, 12, dec_days, 23, 59, 59, 999999) - result_start, result_end = _parsed_string_to_bounds( - date_type, 'month', parsed) + result_start, result_end = _parsed_string_to_bounds(date_type, "month", parsed) assert result_start == expected_start assert result_end == expected_end -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize(('reso', 'ex_start_args', 'ex_end_args'), [ - ('day', (2, 2, 10), (2, 2, 10, 23, 59, 59, 999999)), - ('hour', (2, 2, 10, 6), (2, 2, 10, 6, 59, 59, 999999)), - ('minute', (2, 2, 10, 6, 2), (2, 2, 10, 6, 2, 59, 999999)), - ('second', (2, 2, 10, 6, 2, 8), (2, 2, 10, 6, 2, 8, 999999))]) -def test_parsed_string_to_bounds_sub_monthly(date_type, reso, - ex_start_args, ex_end_args): +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize( + ("reso", "ex_start_args", "ex_end_args"), + [ + ("day", (2, 2, 10), (2, 2, 10, 23, 59, 59, 999999)), + ("hour", (2, 2, 10, 6), (2, 2, 10, 6, 59, 59, 999999)), + ("minute", (2, 2, 10, 6, 2), (2, 2, 10, 6, 2, 59, 999999)), + ("second", (2, 2, 10, 6, 2, 8), (2, 2, 10, 6, 2, 8, 999999)), + ], +) +def test_parsed_string_to_bounds_sub_monthly( + date_type, reso, ex_start_args, ex_end_args +): parsed = date_type(2, 2, 10, 6, 2, 8, 123456) expected_start = date_type(*ex_start_args) expected_end = date_type(*ex_end_args) - result_start, result_end = _parsed_string_to_bounds( - date_type, reso, parsed) + result_start, result_end = _parsed_string_to_bounds(date_type, reso, parsed) assert result_start == expected_start assert result_end == expected_end -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_parsed_string_to_bounds_raises(date_type): with pytest.raises(KeyError): - _parsed_string_to_bounds(date_type, 'a', date_type(1, 1, 1)) + _parsed_string_to_bounds(date_type, "a", date_type(1, 1, 1)) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_get_loc(date_type, index): - result = index.get_loc('0001') + result = index.get_loc("0001") assert result == slice(0, 2) result = index.get_loc(date_type(1, 2, 1)) assert result == 1 - result = index.get_loc('0001-02-01') + result = index.get_loc("0001-02-01") assert result == slice(1, 2) - with raises_regex(KeyError, '1234'): - index.get_loc('1234') + with 
raises_regex(KeyError, "1234"): + index.get_loc("1234") -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('kind', ['loc', 'getitem']) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize("kind", ["loc", "getitem"]) def test_get_slice_bound(date_type, index, kind): - result = index.get_slice_bound('0001', 'left', kind) + result = index.get_slice_bound("0001", "left", kind) expected = 0 assert result == expected - result = index.get_slice_bound('0001', 'right', kind) + result = index.get_slice_bound("0001", "right", kind) expected = 2 assert result == expected - result = index.get_slice_bound( - date_type(1, 3, 1), 'left', kind) + result = index.get_slice_bound(date_type(1, 3, 1), "left", kind) expected = 2 assert result == expected - result = index.get_slice_bound( - date_type(1, 3, 1), 'right', kind) + result = index.get_slice_bound(date_type(1, 3, 1), "right", kind) expected = 2 assert result == expected -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('kind', ['loc', 'getitem']) -def test_get_slice_bound_decreasing_index( - date_type, monotonic_decreasing_index, kind): - result = monotonic_decreasing_index.get_slice_bound('0001', 'left', kind) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize("kind", ["loc", "getitem"]) +def test_get_slice_bound_decreasing_index(date_type, monotonic_decreasing_index, kind): + result = monotonic_decreasing_index.get_slice_bound("0001", "left", kind) expected = 2 assert result == expected - result = monotonic_decreasing_index.get_slice_bound('0001', 'right', kind) + result = monotonic_decreasing_index.get_slice_bound("0001", "right", kind) expected = 4 assert result == expected result = monotonic_decreasing_index.get_slice_bound( - date_type(1, 3, 1), 'left', kind) + date_type(1, 3, 1), "left", kind + ) expected = 2 assert result == expected result = monotonic_decreasing_index.get_slice_bound( - date_type(1, 3, 1), 'right', kind) + date_type(1, 3, 1), "right", kind + ) expected = 2 assert result == expected -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('kind', ['loc', 'getitem']) -def test_get_slice_bound_length_one_index( - date_type, length_one_index, kind): - result = length_one_index.get_slice_bound('0001', 'left', kind) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize("kind", ["loc", "getitem"]) +def test_get_slice_bound_length_one_index(date_type, length_one_index, kind): + result = length_one_index.get_slice_bound("0001", "left", kind) expected = 0 assert result == expected - result = length_one_index.get_slice_bound('0001', 'right', kind) + result = length_one_index.get_slice_bound("0001", "right", kind) expected = 1 assert result == expected - result = length_one_index.get_slice_bound( - date_type(1, 3, 1), 'left', kind) + result = length_one_index.get_slice_bound(date_type(1, 3, 1), "left", kind) expected = 1 assert result == expected - result = length_one_index.get_slice_bound( - date_type(1, 3, 1), 'right', kind) + result = length_one_index.get_slice_bound(date_type(1, 3, 1), "right", kind) expected = 1 assert result == expected -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_string_slice_length_one_index(length_one_index): - da = xr.DataArray([1], coords=[length_one_index], 
dims=['time']) - result = da.sel(time=slice('0001', '0001')) + da = xr.DataArray([1], coords=[length_one_index], dims=["time"]) + result = da.sel(time=slice("0001", "0001")) assert_identical(result, da) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_date_type_property(date_type, index): assert index.date_type is date_type -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_contains(date_type, index): - assert '0001-01-01' in index - assert '0001' in index - assert '0003' not in index + assert "0001-01-01" in index + assert "0001" in index + assert "0003" not in index assert date_type(1, 1, 1) in index assert date_type(3, 1, 1) not in index -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_groupby(da): - result = da.groupby('time.month').sum('time') - expected = xr.DataArray([4, 6], coords=[[1, 2]], dims=['month']) + result = da.groupby("time.month").sum("time") + expected = xr.DataArray([4, 6], coords=[[1, 2]], dims=["month"]) assert_identical(result, expected) SEL_STRING_OR_LIST_TESTS = { - 'string': '0001', - 'string-slice': slice('0001-01-01', '0001-12-30'), # type: ignore - 'bool-list': [True, True, False, False] + "string": "0001", + "string-slice": slice("0001-01-01", "0001-12-30"), # type: ignore + "bool-list": [True, True, False, False], } -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('sel_arg', list(SEL_STRING_OR_LIST_TESTS.values()), - ids=list(SEL_STRING_OR_LIST_TESTS.keys())) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize( + "sel_arg", + list(SEL_STRING_OR_LIST_TESTS.values()), + ids=list(SEL_STRING_OR_LIST_TESTS.keys()), +) def test_sel_string_or_list(da, index, sel_arg): - expected = xr.DataArray([1, 2], coords=[index[:2]], dims=['time']) + expected = xr.DataArray([1, 2], coords=[index[:2]], dims=["time"]) result = da.sel(time=sel_arg) assert_identical(result, expected) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_sel_date_slice_or_list(da, index, date_type): - expected = xr.DataArray([1, 2], coords=[index[:2]], dims=['time']) + expected = xr.DataArray([1, 2], coords=[index[:2]], dims=["time"]) result = da.sel(time=slice(date_type(1, 1, 1), date_type(1, 12, 30))) assert_identical(result, expected) @@ -401,18 +449,18 @@ def test_sel_date_slice_or_list(da, index, date_type): assert_identical(result, expected) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_sel_date_scalar(da, date_type, index): expected = xr.DataArray(1).assign_coords(time=index[0]) result = da.sel(time=date_type(1, 1, 1)) assert_identical(result, expected) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('sel_kwargs', [ - {'method': 'nearest'}, - {'method': 'nearest', 'tolerance': timedelta(days=70)} -]) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize( + "sel_kwargs", + [{"method": "nearest"}, {"method": "nearest", "tolerance": timedelta(days=70)}], +) def test_sel_date_scalar_nearest(da, date_type, index, sel_kwargs): expected = 
xr.DataArray(2).assign_coords(time=index[1]) result = da.sel(time=date_type(1, 4, 1), **sel_kwargs) @@ -423,11 +471,11 @@ def test_sel_date_scalar_nearest(da, date_type, index, sel_kwargs): assert_identical(result, expected) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('sel_kwargs', [ - {'method': 'pad'}, - {'method': 'pad', 'tolerance': timedelta(days=365)} -]) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize( + "sel_kwargs", + [{"method": "pad"}, {"method": "pad", "tolerance": timedelta(days=365)}], +) def test_sel_date_scalar_pad(da, date_type, index, sel_kwargs): expected = xr.DataArray(2).assign_coords(time=index[1]) result = da.sel(time=date_type(1, 4, 1), **sel_kwargs) @@ -438,11 +486,11 @@ def test_sel_date_scalar_pad(da, date_type, index, sel_kwargs): assert_identical(result, expected) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('sel_kwargs', [ - {'method': 'backfill'}, - {'method': 'backfill', 'tolerance': timedelta(days=365)} -]) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize( + "sel_kwargs", + [{"method": "backfill"}, {"method": "backfill", "tolerance": timedelta(days=365)}], +) def test_sel_date_scalar_backfill(da, date_type, index, sel_kwargs): expected = xr.DataArray(3).assign_coords(time=index[2]) result = da.sel(time=date_type(1, 4, 1), **sel_kwargs) @@ -453,86 +501,82 @@ def test_sel_date_scalar_backfill(da, date_type, index, sel_kwargs): assert_identical(result, expected) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('sel_kwargs', [ - {'method': 'pad', 'tolerance': timedelta(days=20)}, - {'method': 'backfill', 'tolerance': timedelta(days=20)}, - {'method': 'nearest', 'tolerance': timedelta(days=20)}, -]) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize( + "sel_kwargs", + [ + {"method": "pad", "tolerance": timedelta(days=20)}, + {"method": "backfill", "tolerance": timedelta(days=20)}, + {"method": "nearest", "tolerance": timedelta(days=20)}, + ], +) def test_sel_date_scalar_tolerance_raises(da, date_type, sel_kwargs): with pytest.raises(KeyError): da.sel(time=date_type(1, 5, 1), **sel_kwargs) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('sel_kwargs', [ - {'method': 'nearest'}, - {'method': 'nearest', 'tolerance': timedelta(days=70)} -]) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize( + "sel_kwargs", + [{"method": "nearest"}, {"method": "nearest", "tolerance": timedelta(days=70)}], +) def test_sel_date_list_nearest(da, date_type, index, sel_kwargs): - expected = xr.DataArray( - [2, 2], coords=[[index[1], index[1]]], dims=['time']) - result = da.sel( - time=[date_type(1, 3, 1), date_type(1, 4, 1)], **sel_kwargs) + expected = xr.DataArray([2, 2], coords=[[index[1], index[1]]], dims=["time"]) + result = da.sel(time=[date_type(1, 3, 1), date_type(1, 4, 1)], **sel_kwargs) assert_identical(result, expected) - expected = xr.DataArray( - [2, 3], coords=[[index[1], index[2]]], dims=['time']) - result = da.sel( - time=[date_type(1, 3, 1), date_type(1, 12, 1)], **sel_kwargs) + expected = xr.DataArray([2, 3], coords=[[index[1], index[2]]], dims=["time"]) + result = da.sel(time=[date_type(1, 3, 1), date_type(1, 12, 1)], **sel_kwargs) assert_identical(result, expected) - expected = xr.DataArray( - [3, 3], 
coords=[[index[2], index[2]]], dims=['time']) - result = da.sel( - time=[date_type(1, 11, 1), date_type(1, 12, 1)], **sel_kwargs) + expected = xr.DataArray([3, 3], coords=[[index[2], index[2]]], dims=["time"]) + result = da.sel(time=[date_type(1, 11, 1), date_type(1, 12, 1)], **sel_kwargs) assert_identical(result, expected) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('sel_kwargs', [ - {'method': 'pad'}, - {'method': 'pad', 'tolerance': timedelta(days=365)} -]) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize( + "sel_kwargs", + [{"method": "pad"}, {"method": "pad", "tolerance": timedelta(days=365)}], +) def test_sel_date_list_pad(da, date_type, index, sel_kwargs): - expected = xr.DataArray( - [2, 2], coords=[[index[1], index[1]]], dims=['time']) - result = da.sel( - time=[date_type(1, 3, 1), date_type(1, 4, 1)], **sel_kwargs) + expected = xr.DataArray([2, 2], coords=[[index[1], index[1]]], dims=["time"]) + result = da.sel(time=[date_type(1, 3, 1), date_type(1, 4, 1)], **sel_kwargs) assert_identical(result, expected) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('sel_kwargs', [ - {'method': 'backfill'}, - {'method': 'backfill', 'tolerance': timedelta(days=365)} -]) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize( + "sel_kwargs", + [{"method": "backfill"}, {"method": "backfill", "tolerance": timedelta(days=365)}], +) def test_sel_date_list_backfill(da, date_type, index, sel_kwargs): - expected = xr.DataArray( - [3, 3], coords=[[index[2], index[2]]], dims=['time']) - result = da.sel( - time=[date_type(1, 3, 1), date_type(1, 4, 1)], **sel_kwargs) + expected = xr.DataArray([3, 3], coords=[[index[2], index[2]]], dims=["time"]) + result = da.sel(time=[date_type(1, 3, 1), date_type(1, 4, 1)], **sel_kwargs) assert_identical(result, expected) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('sel_kwargs', [ - {'method': 'pad', 'tolerance': timedelta(days=20)}, - {'method': 'backfill', 'tolerance': timedelta(days=20)}, - {'method': 'nearest', 'tolerance': timedelta(days=20)}, -]) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize( + "sel_kwargs", + [ + {"method": "pad", "tolerance": timedelta(days=20)}, + {"method": "backfill", "tolerance": timedelta(days=20)}, + {"method": "nearest", "tolerance": timedelta(days=20)}, + ], +) def test_sel_date_list_tolerance_raises(da, date_type, sel_kwargs): with pytest.raises(KeyError): da.sel(time=[date_type(1, 2, 1), date_type(1, 5, 1)], **sel_kwargs) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_isel(da, index): expected = xr.DataArray(1).assign_coords(time=index[0]) result = da.isel(time=0) assert_identical(result, expected) - expected = xr.DataArray([1, 2], coords=[index[:2]], dims=['time']) + expected = xr.DataArray([1, 2], coords=[index[:2]], dims=["time"]) result = da.isel(time=[0, 1]) assert_identical(result, expected) @@ -544,13 +588,16 @@ def scalar_args(date_type): @pytest.fixture def range_args(date_type): - return ['0001', slice('0001-01-01', '0001-12-30'), - slice(None, '0001-12-30'), - slice(date_type(1, 1, 1), date_type(1, 12, 30)), - slice(None, date_type(1, 12, 30))] + return [ + "0001", + slice("0001-01-01", "0001-12-30"), + slice(None, "0001-12-30"), + slice(date_type(1, 1, 1), 
date_type(1, 12, 30)), + slice(None, date_type(1, 12, 30)), + ] -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_indexing_in_series_getitem(series, index, scalar_args, range_args): for arg in scalar_args: assert series[arg] == 1 @@ -560,7 +607,7 @@ def test_indexing_in_series_getitem(series, index, scalar_args, range_args): assert series[arg].equals(expected) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_indexing_in_series_loc(series, index, scalar_args, range_args): for arg in scalar_args: assert series.loc[arg] == 1 @@ -570,7 +617,7 @@ def test_indexing_in_series_loc(series, index, scalar_args, range_args): assert series.loc[arg].equals(expected) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_indexing_in_series_iloc(series, index): expected = 1 assert series.iloc[0] == expected @@ -579,15 +626,15 @@ def test_indexing_in_series_iloc(series, index): assert series.iloc[:2].equals(expected) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_series_dropna(index): - series = pd.Series([0., 1., np.nan, np.nan], index=index) + series = pd.Series([0.0, 1.0, np.nan, np.nan], index=index) expected = series.iloc[:2] result = series.dropna() assert result.equals(expected) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_indexing_in_dataframe_loc(df, index, scalar_args, range_args): expected = pd.Series([1], name=index[0]) for arg in scalar_args: @@ -600,7 +647,7 @@ def test_indexing_in_dataframe_loc(df, index, scalar_args, range_args): assert result.equals(expected) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_indexing_in_dataframe_iloc(df, index): expected = pd.Series([1], name=index[0]) result = df.iloc[0] @@ -612,78 +659,90 @@ def test_indexing_in_dataframe_iloc(df, index): assert result.equals(expected) -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime not installed") def test_concat_cftimeindex(date_type): da1 = xr.DataArray( - [1., 2.], coords=[[date_type(1, 1, 1), date_type(1, 2, 1)]], - dims=['time']) + [1.0, 2.0], coords=[[date_type(1, 1, 1), date_type(1, 2, 1)]], dims=["time"] + ) da2 = xr.DataArray( - [3., 4.], coords=[[date_type(1, 3, 1), date_type(1, 4, 1)]], - dims=['time']) - da = xr.concat([da1, da2], dim='time') + [3.0, 4.0], coords=[[date_type(1, 3, 1), date_type(1, 4, 1)]], dims=["time"] + ) + da = xr.concat([da1, da2], dim="time") if has_cftime: - assert isinstance(da.indexes['time'], CFTimeIndex) + assert isinstance(da.indexes["time"], CFTimeIndex) else: - assert isinstance(da.indexes['time'], pd.Index) - assert not isinstance(da.indexes['time'], CFTimeIndex) + assert isinstance(da.indexes["time"], pd.Index) + assert not isinstance(da.indexes["time"], CFTimeIndex) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_empty_cftimeindex(): index = CFTimeIndex([]) assert index.date_type is None -@pytest.mark.skipif(not has_cftime, 
reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_cftimeindex_add(index): date_type = index.date_type - expected_dates = [date_type(1, 1, 2), date_type(1, 2, 2), - date_type(2, 1, 2), date_type(2, 2, 2)] + expected_dates = [ + date_type(1, 1, 2), + date_type(1, 2, 2), + date_type(2, 1, 2), + date_type(2, 2, 2), + ] expected = CFTimeIndex(expected_dates) result = index + timedelta(days=1) assert result.equals(expected) assert isinstance(result, CFTimeIndex) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('calendar', _CFTIME_CALENDARS) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) def test_cftimeindex_add_timedeltaindex(calendar): - a = xr.cftime_range('2000', periods=5, calendar=calendar) + a = xr.cftime_range("2000", periods=5, calendar=calendar) deltas = pd.TimedeltaIndex([timedelta(days=2) for _ in range(5)]) result = a + deltas - expected = a.shift(2, 'D') + expected = a.shift(2, "D") assert result.equals(expected) assert isinstance(result, CFTimeIndex) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_cftimeindex_radd(index): date_type = index.date_type - expected_dates = [date_type(1, 1, 2), date_type(1, 2, 2), - date_type(2, 1, 2), date_type(2, 2, 2)] + expected_dates = [ + date_type(1, 1, 2), + date_type(1, 2, 2), + date_type(2, 1, 2), + date_type(2, 2, 2), + ] expected = CFTimeIndex(expected_dates) result = timedelta(days=1) + index assert result.equals(expected) assert isinstance(result, CFTimeIndex) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('calendar', _CFTIME_CALENDARS) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) def test_timedeltaindex_add_cftimeindex(calendar): - a = xr.cftime_range('2000', periods=5, calendar=calendar) + a = xr.cftime_range("2000", periods=5, calendar=calendar) deltas = pd.TimedeltaIndex([timedelta(days=2) for _ in range(5)]) result = deltas + a - expected = a.shift(2, 'D') + expected = a.shift(2, "D") assert result.equals(expected) assert isinstance(result, CFTimeIndex) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_cftimeindex_sub(index): date_type = index.date_type - expected_dates = [date_type(1, 1, 2), date_type(1, 2, 2), - date_type(2, 1, 2), date_type(2, 2, 2)] + expected_dates = [ + date_type(1, 1, 2), + date_type(1, 2, 2), + date_type(2, 1, 2), + date_type(2, 2, 2), + ] expected = CFTimeIndex(expected_dates) result = index + timedelta(days=2) result = result - timedelta(days=1) @@ -691,76 +750,80 @@ def test_cftimeindex_sub(index): assert isinstance(result, CFTimeIndex) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('calendar', _CFTIME_CALENDARS) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) def test_cftimeindex_sub_cftimeindex(calendar): - a = xr.cftime_range('2000', periods=5, calendar=calendar) - b = a.shift(2, 'D') + a = xr.cftime_range("2000", periods=5, calendar=calendar) + b = a.shift(2, "D") result = b - a expected = pd.TimedeltaIndex([timedelta(days=2) for _ in range(5)]) assert result.equals(expected) assert 
isinstance(result, pd.TimedeltaIndex) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('calendar', _CFTIME_CALENDARS) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) def test_cftimeindex_sub_cftime_datetime(calendar): - a = xr.cftime_range('2000', periods=5, calendar=calendar) + a = xr.cftime_range("2000", periods=5, calendar=calendar) result = a - a[0] expected = pd.TimedeltaIndex([timedelta(days=i) for i in range(5)]) assert result.equals(expected) assert isinstance(result, pd.TimedeltaIndex) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('calendar', _CFTIME_CALENDARS) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) def test_cftime_datetime_sub_cftimeindex(calendar): - a = xr.cftime_range('2000', periods=5, calendar=calendar) + a = xr.cftime_range("2000", periods=5, calendar=calendar) result = a[0] - a expected = pd.TimedeltaIndex([timedelta(days=-i) for i in range(5)]) assert result.equals(expected) assert isinstance(result, pd.TimedeltaIndex) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('calendar', _CFTIME_CALENDARS) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) def test_cftimeindex_sub_timedeltaindex(calendar): - a = xr.cftime_range('2000', periods=5, calendar=calendar) + a = xr.cftime_range("2000", periods=5, calendar=calendar) deltas = pd.TimedeltaIndex([timedelta(days=2) for _ in range(5)]) result = a - deltas - expected = a.shift(-2, 'D') + expected = a.shift(-2, "D") assert result.equals(expected) assert isinstance(result, CFTimeIndex) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_cftimeindex_rsub(index): with pytest.raises(TypeError): timedelta(days=1) - index -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('freq', ['D', timedelta(days=1)]) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize("freq", ["D", timedelta(days=1)]) def test_cftimeindex_shift(index, freq): date_type = index.date_type - expected_dates = [date_type(1, 1, 3), date_type(1, 2, 3), - date_type(2, 1, 3), date_type(2, 2, 3)] + expected_dates = [ + date_type(1, 1, 3), + date_type(1, 2, 3), + date_type(2, 1, 3), + date_type(2, 2, 3), + ] expected = CFTimeIndex(expected_dates) result = index.shift(2, freq) assert result.equals(expected) assert isinstance(result, CFTimeIndex) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_cftimeindex_shift_invalid_n(): - index = xr.cftime_range('2000', periods=3) + index = xr.cftime_range("2000", periods=3) with pytest.raises(TypeError): - index.shift('a', 'D') + index.shift("a", "D") -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_cftimeindex_shift_invalid_freq(): - index = xr.cftime_range('2000', periods=3) + index = xr.cftime_range("2000", periods=3) with pytest.raises(TypeError): index.shift(1, 1) @@ -769,42 +832,44 @@ def test_cftimeindex_shift_invalid_freq(): def test_parse_array_of_cftime_strings(): from cftime import DatetimeNoLeap 
- strings = np.array([['2000-01-01', '2000-01-02'], - ['2000-01-03', '2000-01-04']]) + strings = np.array([["2000-01-01", "2000-01-02"], ["2000-01-03", "2000-01-04"]]) expected = np.array( - [[DatetimeNoLeap(2000, 1, 1), DatetimeNoLeap(2000, 1, 2)], - [DatetimeNoLeap(2000, 1, 3), DatetimeNoLeap(2000, 1, 4)]]) + [ + [DatetimeNoLeap(2000, 1, 1), DatetimeNoLeap(2000, 1, 2)], + [DatetimeNoLeap(2000, 1, 3), DatetimeNoLeap(2000, 1, 4)], + ] + ) result = _parse_array_of_cftime_strings(strings, DatetimeNoLeap) np.testing.assert_array_equal(result, expected) # Test scalar array case - strings = np.array('2000-01-01') + strings = np.array("2000-01-01") expected = np.array(DatetimeNoLeap(2000, 1, 1)) result = _parse_array_of_cftime_strings(strings, DatetimeNoLeap) np.testing.assert_array_equal(result, expected) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('calendar', _ALL_CALENDARS) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize("calendar", _ALL_CALENDARS) def test_strftime_of_cftime_array(calendar): - date_format = '%Y%m%d%H%M' - cf_values = xr.cftime_range('2000', periods=5, calendar=calendar) - dt_values = pd.date_range('2000', periods=5) + date_format = "%Y%m%d%H%M" + cf_values = xr.cftime_range("2000", periods=5, calendar=calendar) + dt_values = pd.date_range("2000", periods=5) expected = dt_values.strftime(date_format) result = cf_values.strftime(date_format) assert result.equals(expected) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('calendar', _ALL_CALENDARS) -@pytest.mark.parametrize('unsafe', [False, True]) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize("calendar", _ALL_CALENDARS) +@pytest.mark.parametrize("unsafe", [False, True]) def test_to_datetimeindex(calendar, unsafe): - index = xr.cftime_range('2000', periods=5, calendar=calendar) - expected = pd.date_range('2000', periods=5) + index = xr.cftime_range("2000", periods=5, calendar=calendar) + expected = pd.date_range("2000", periods=5) if calendar in _NON_STANDARD_CALENDARS and not unsafe: - with pytest.warns(RuntimeWarning, match='non-standard'): + with pytest.warns(RuntimeWarning, match="non-standard"): result = index.to_datetimeindex() else: result = index.to_datetimeindex(unsafe=unsafe) @@ -814,25 +879,25 @@ def test_to_datetimeindex(calendar, unsafe): assert isinstance(result, pd.DatetimeIndex) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('calendar', _ALL_CALENDARS) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize("calendar", _ALL_CALENDARS) def test_to_datetimeindex_out_of_range(calendar): - index = xr.cftime_range('0001', periods=5, calendar=calendar) - with pytest.raises(ValueError, match='0001'): + index = xr.cftime_range("0001", periods=5, calendar=calendar) + with pytest.raises(ValueError, match="0001"): index.to_datetimeindex() -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.parametrize('calendar', ['all_leap', '360_day']) +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.parametrize("calendar", ["all_leap", "360_day"]) def test_to_datetimeindex_feb_29(calendar): - index = xr.cftime_range('2001-02-28', periods=2, calendar=calendar) - with pytest.raises(ValueError, match='29'): + index = xr.cftime_range("2001-02-28", periods=2, calendar=calendar) + with pytest.raises(ValueError, 
match="29"): index.to_datetimeindex() -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') -@pytest.mark.xfail(reason='https://github.com/pandas-dev/pandas/issues/24263') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") +@pytest.mark.xfail(reason="https://github.com/pandas-dev/pandas/issues/24263") def test_multiindex(): - index = xr.cftime_range('2001-01-01', periods=100, calendar='360_day') + index = xr.cftime_range("2001-01-01", periods=100, calendar="360_day") mindex = pd.MultiIndex.from_arrays([index]) - assert mindex.get_loc('2001-01') == slice(0, 30) + assert mindex.get_loc("2001-01") == slice(0, 30) diff --git a/xarray/tests/test_cftimeindex_resample.py b/xarray/tests/test_cftimeindex_resample.py index 108b303e0c0..bbc8dd82c95 100644 --- a/xarray/tests/test_cftimeindex_resample.py +++ b/xarray/tests/test_cftimeindex_resample.py @@ -7,8 +7,8 @@ import xarray as xr from xarray.core.resample_cftime import CFTimeGrouper -pytest.importorskip('cftime') -pytest.importorskip('pandas', minversion='0.24') +pytest.importorskip("cftime") +pytest.importorskip("pandas", minversion="0.24") # Create a list of pairs of similar-length initial and resample frequencies @@ -19,96 +19,132 @@ # These are used to test the cftime version of resample against pandas # with a standard calendar. FREQS = [ - ('8003D', '4001D'), - ('8003D', '16006D'), - ('8003D', '21AS'), - ('6H', '3H'), - ('6H', '12H'), - ('6H', '400T'), - ('3D', 'D'), - ('3D', '6D'), - ('11D', 'MS'), - ('3MS', 'MS'), - ('3MS', '6MS'), - ('3MS', '85D'), - ('7M', '3M'), - ('7M', '14M'), - ('7M', '2QS-APR'), - ('43QS-AUG', '21QS-AUG'), - ('43QS-AUG', '86QS-AUG'), - ('43QS-AUG', '11A-JUN'), - ('11Q-JUN', '5Q-JUN'), - ('11Q-JUN', '22Q-JUN'), - ('11Q-JUN', '51MS'), - ('3AS-MAR', 'AS-MAR'), - ('3AS-MAR', '6AS-MAR'), - ('3AS-MAR', '14Q-FEB'), - ('7A-MAY', '3A-MAY'), - ('7A-MAY', '14A-MAY'), - ('7A-MAY', '85M') + ("8003D", "4001D"), + ("8003D", "16006D"), + ("8003D", "21AS"), + ("6H", "3H"), + ("6H", "12H"), + ("6H", "400T"), + ("3D", "D"), + ("3D", "6D"), + ("11D", "MS"), + ("3MS", "MS"), + ("3MS", "6MS"), + ("3MS", "85D"), + ("7M", "3M"), + ("7M", "14M"), + ("7M", "2QS-APR"), + ("43QS-AUG", "21QS-AUG"), + ("43QS-AUG", "86QS-AUG"), + ("43QS-AUG", "11A-JUN"), + ("11Q-JUN", "5Q-JUN"), + ("11Q-JUN", "22Q-JUN"), + ("11Q-JUN", "51MS"), + ("3AS-MAR", "AS-MAR"), + ("3AS-MAR", "6AS-MAR"), + ("3AS-MAR", "14Q-FEB"), + ("7A-MAY", "3A-MAY"), + ("7A-MAY", "14A-MAY"), + ("7A-MAY", "85M"), ] def da(index): - return xr.DataArray(np.arange(100., 100. 
+ index.size), - coords=[index], dims=['time']) + return xr.DataArray( + np.arange(100.0, 100.0 + index.size), coords=[index], dims=["time"] + ) -@pytest.mark.parametrize('freqs', FREQS, ids=lambda x: '{}->{}'.format(*x)) -@pytest.mark.parametrize('closed', [None, 'left', 'right']) -@pytest.mark.parametrize('label', [None, 'left', 'right']) -@pytest.mark.parametrize('base', [24, 31]) +@pytest.mark.parametrize("freqs", FREQS, ids=lambda x: "{}->{}".format(*x)) +@pytest.mark.parametrize("closed", [None, "left", "right"]) +@pytest.mark.parametrize("label", [None, "left", "right"]) +@pytest.mark.parametrize("base", [24, 31]) def test_resample(freqs, closed, label, base): initial_freq, resample_freq = freqs - start = '2000-01-01T12:07:01' + start = "2000-01-01T12:07:01" index_kwargs = dict(start=start, periods=5, freq=initial_freq) datetime_index = pd.date_range(**index_kwargs) cftime_index = xr.cftime_range(**index_kwargs) - loffset = '12H' + loffset = "12H" try: - da_datetime = da(datetime_index).resample( - time=resample_freq, closed=closed, label=label, base=base, - loffset=loffset).mean() + da_datetime = ( + da(datetime_index) + .resample( + time=resample_freq, + closed=closed, + label=label, + base=base, + loffset=loffset, + ) + .mean() + ) except ValueError: with pytest.raises(ValueError): da(cftime_index).resample( - time=resample_freq, closed=closed, label=label, base=base, - loffset=loffset).mean() + time=resample_freq, + closed=closed, + label=label, + base=base, + loffset=loffset, + ).mean() else: - da_cftime = da(cftime_index).resample( - time=resample_freq, closed=closed, - label=label, base=base, loffset=loffset).mean() - da_cftime['time'] = da_cftime.indexes['time'].to_datetimeindex() + da_cftime = ( + da(cftime_index) + .resample( + time=resample_freq, + closed=closed, + label=label, + base=base, + loffset=loffset, + ) + .mean() + ) + da_cftime["time"] = da_cftime.indexes["time"].to_datetimeindex() xr.testing.assert_identical(da_cftime, da_datetime) @pytest.mark.parametrize( - ('freq', 'expected'), - [('S', 'left'), ('T', 'left'), ('H', 'left'), ('D', 'left'), - ('M', 'right'), ('MS', 'left'), ('Q', 'right'), ('QS', 'left'), - ('A', 'right'), ('AS', 'left')]) + ("freq", "expected"), + [ + ("S", "left"), + ("T", "left"), + ("H", "left"), + ("D", "left"), + ("M", "right"), + ("MS", "left"), + ("Q", "right"), + ("QS", "left"), + ("A", "right"), + ("AS", "left"), + ], +) def test_closed_label_defaults(freq, expected): assert CFTimeGrouper(freq=freq).closed == expected assert CFTimeGrouper(freq=freq).label == expected -@pytest.mark.filterwarnings('ignore:Converting a CFTimeIndex') -@pytest.mark.parametrize('calendar', ['gregorian', 'noleap', 'all_leap', - '360_day', 'julian']) +@pytest.mark.filterwarnings("ignore:Converting a CFTimeIndex") +@pytest.mark.parametrize( + "calendar", ["gregorian", "noleap", "all_leap", "360_day", "julian"] +) def test_calendars(calendar): # Limited testing for non-standard calendars - freq, closed, label, base = '8001T', None, None, 17 + freq, closed, label, base = "8001T", None, None, 17 loffset = datetime.timedelta(hours=12) - xr_index = xr.cftime_range(start='2004-01-01T12:07:01', periods=7, - freq='3D', calendar=calendar) - pd_index = pd.date_range(start='2004-01-01T12:07:01', periods=7, - freq='3D') - da_cftime = da(xr_index).resample( - time=freq, closed=closed, label=label, base=base, loffset=loffset - ).mean() - da_datetime = da(pd_index).resample( - time=freq, closed=closed, label=label, base=base, loffset=loffset - ).mean() - da_cftime['time'] 
= da_cftime.indexes['time'].to_datetimeindex() + xr_index = xr.cftime_range( + start="2004-01-01T12:07:01", periods=7, freq="3D", calendar=calendar + ) + pd_index = pd.date_range(start="2004-01-01T12:07:01", periods=7, freq="3D") + da_cftime = ( + da(xr_index) + .resample(time=freq, closed=closed, label=label, base=base, loffset=loffset) + .mean() + ) + da_datetime = ( + da(pd_index) + .resample(time=freq, closed=closed, label=label, base=base, loffset=loffset) + .mean() + ) + da_cftime["time"] = da_cftime.indexes["time"].to_datetimeindex() xr.testing.assert_identical(da_cftime, da_datetime) diff --git a/xarray/tests/test_coding.py b/xarray/tests/test_coding.py index 9f937ac7f5e..6cd584daa96 100644 --- a/xarray/tests/test_coding.py +++ b/xarray/tests/test_coding.py @@ -13,54 +13,57 @@ def test_CFMaskCoder_decode(): - original = xr.Variable(('x',), [0, -1, 1], {'_FillValue': -1}) - expected = xr.Variable(('x',), [0, np.nan, 1]) + original = xr.Variable(("x",), [0, -1, 1], {"_FillValue": -1}) + expected = xr.Variable(("x",), [0, np.nan, 1]) coder = variables.CFMaskCoder() encoded = coder.decode(original) assert_identical(expected, encoded) def test_CFMaskCoder_missing_value(): - expected = xr.DataArray(np.array([[26915, 27755, -9999, 27705], - [25595, -9999, 28315, -9999]]), - dims=['npts', 'ntimes'], - name='tmpk') - expected.attrs['missing_value'] = -9999 + expected = xr.DataArray( + np.array([[26915, 27755, -9999, 27705], [25595, -9999, 28315, -9999]]), + dims=["npts", "ntimes"], + name="tmpk", + ) + expected.attrs["missing_value"] = -9999 decoded = xr.decode_cf(expected.to_dataset()) encoded, _ = xr.conventions.cf_encoder(decoded, decoded.attrs) - assert_equal(encoded['tmpk'], expected.variable) + assert_equal(encoded["tmpk"], expected.variable) - decoded.tmpk.encoding['_FillValue'] = -9940 + decoded.tmpk.encoding["_FillValue"] = -9940 with pytest.raises(ValueError): encoded, _ = xr.conventions.cf_encoder(decoded, decoded.attrs) @requires_dask def test_CFMaskCoder_decode_dask(): - original = xr.Variable(('x',), [0, -1, 1], {'_FillValue': -1}).chunk() - expected = xr.Variable(('x',), [0, np.nan, 1]) + original = xr.Variable(("x",), [0, -1, 1], {"_FillValue": -1}).chunk() + expected = xr.Variable(("x",), [0, np.nan, 1]) coder = variables.CFMaskCoder() encoded = coder.decode(original) assert isinstance(encoded.data, da.Array) assert_identical(expected, encoded) + # TODO(shoyer): port other fill-value tests # TODO(shoyer): parameterize when we have more coders def test_coder_roundtrip(): - original = xr.Variable(('x',), [0.0, np.nan, 1.0]) + original = xr.Variable(("x",), [0.0, np.nan, 1.0]) coder = variables.CFMaskCoder() roundtripped = coder.decode(coder.encode(original)) assert_identical(original, roundtripped) -@pytest.mark.parametrize('dtype', 'u1 u2 i1 i2 f2 f4'.split()) +@pytest.mark.parametrize("dtype", "u1 u2 i1 i2 f2 f4".split()) def test_scaling_converts_to_float32(dtype): - original = xr.Variable(('x',), np.arange(10, dtype=dtype), - encoding=dict(scale_factor=10)) + original = xr.Variable( + ("x",), np.arange(10, dtype=dtype), encoding=dict(scale_factor=10) + ) coder = variables.CFScaleOffsetCoder() encoded = coder.encode(original) assert encoded.dtype == np.float32 diff --git a/xarray/tests/test_coding_strings.py b/xarray/tests/test_coding_strings.py index 13c0983212e..10cdd03459c 100644 --- a/xarray/tests/test_coding_strings.py +++ b/xarray/tests/test_coding_strings.py @@ -9,8 +9,12 @@ from xarray.core import indexing from . 
import ( - IndexerMaker, assert_array_equal, assert_identical, raises_regex, - requires_dask) + IndexerMaker, + assert_array_equal, + assert_identical, + raises_regex, + requires_dask, +) with suppress(ImportError): import dask.array as da @@ -18,13 +22,13 @@ def test_vlen_dtype(): dtype = strings.create_vlen_dtype(str) - assert dtype.metadata['element_type'] == str + assert dtype.metadata["element_type"] == str assert strings.is_unicode_dtype(dtype) assert not strings.is_bytes_dtype(dtype) assert strings.check_vlen_dtype(dtype) is str dtype = strings.create_vlen_dtype(bytes) - assert dtype.metadata['element_type'] == bytes + assert dtype.metadata["element_type"] == bytes assert not strings.is_unicode_dtype(dtype) assert strings.is_bytes_dtype(dtype) assert strings.check_vlen_dtype(dtype) is bytes @@ -35,12 +39,11 @@ def test_vlen_dtype(): def test_EncodedStringCoder_decode(): coder = strings.EncodedStringCoder() - raw_data = np.array([b'abc', 'ß∂µ∆'.encode()]) - raw = Variable(('x',), raw_data, {'_Encoding': 'utf-8'}) + raw_data = np.array([b"abc", "ß∂µ∆".encode()]) + raw = Variable(("x",), raw_data, {"_Encoding": "utf-8"}) actual = coder.decode(raw) - expected = Variable( - ('x',), np.array(['abc', 'ß∂µ∆'], dtype=object)) + expected = Variable(("x",), np.array(["abc", "ß∂µ∆"], dtype=object)) assert_identical(actual, expected) assert_identical(coder.decode(actual[0]), expected[0]) @@ -50,12 +53,12 @@ def test_EncodedStringCoder_decode(): def test_EncodedStringCoder_decode_dask(): coder = strings.EncodedStringCoder() - raw_data = np.array([b'abc', 'ß∂µ∆'.encode()]) - raw = Variable(('x',), raw_data, {'_Encoding': 'utf-8'}).chunk() + raw_data = np.array([b"abc", "ß∂µ∆".encode()]) + raw = Variable(("x",), raw_data, {"_Encoding": "utf-8"}).chunk() actual = coder.decode(raw) assert isinstance(actual.data, da.Array) - expected = Variable(('x',), np.array(['abc', 'ß∂µ∆'], dtype=object)) + expected = Variable(("x",), np.array(["abc", "ß∂µ∆"], dtype=object)) assert_identical(actual, expected) actual_indexed = coder.decode(actual[0]) @@ -65,70 +68,72 @@ def test_EncodedStringCoder_decode_dask(): def test_EncodedStringCoder_encode(): dtype = strings.create_vlen_dtype(str) - raw_data = np.array(['abc', 'ß∂µ∆'], dtype=dtype) - expected_data = np.array([r.encode('utf-8') for r in raw_data], - dtype=object) + raw_data = np.array(["abc", "ß∂µ∆"], dtype=dtype) + expected_data = np.array([r.encode("utf-8") for r in raw_data], dtype=object) coder = strings.EncodedStringCoder(allows_unicode=True) - raw = Variable(('x',), raw_data, encoding={'dtype': 'S1'}) + raw = Variable(("x",), raw_data, encoding={"dtype": "S1"}) actual = coder.encode(raw) - expected = Variable(('x',), expected_data, attrs={'_Encoding': 'utf-8'}) + expected = Variable(("x",), expected_data, attrs={"_Encoding": "utf-8"}) assert_identical(actual, expected) - raw = Variable(('x',), raw_data) + raw = Variable(("x",), raw_data) assert_identical(coder.encode(raw), raw) coder = strings.EncodedStringCoder(allows_unicode=False) assert_identical(coder.encode(raw), expected) -@pytest.mark.parametrize('original', [ - Variable(('x',), [b'ab', b'cdef']), - Variable((), b'ab'), - Variable(('x',), [b'a', b'b']), - Variable((), b'a'), -]) +@pytest.mark.parametrize( + "original", + [ + Variable(("x",), [b"ab", b"cdef"]), + Variable((), b"ab"), + Variable(("x",), [b"a", b"b"]), + Variable((), b"a"), + ], +) def test_CharacterArrayCoder_roundtrip(original): coder = strings.CharacterArrayCoder() roundtripped = coder.decode(coder.encode(original)) 
assert_identical(original, roundtripped) -@pytest.mark.parametrize('data', [ - np.array([b'a', b'bc']), - np.array([b'a', b'bc'], dtype=strings.create_vlen_dtype(bytes)), -]) +@pytest.mark.parametrize( + "data", + [ + np.array([b"a", b"bc"]), + np.array([b"a", b"bc"], dtype=strings.create_vlen_dtype(bytes)), + ], +) def test_CharacterArrayCoder_encode(data): coder = strings.CharacterArrayCoder() - raw = Variable(('x',), data) + raw = Variable(("x",), data) actual = coder.encode(raw) - expected = Variable(('x', 'string2'), - np.array([[b'a', b''], [b'b', b'c']])) + expected = Variable(("x", "string2"), np.array([[b"a", b""], [b"b", b"c"]])) assert_identical(actual, expected) @pytest.mark.parametrize( - ['original', 'expected_char_dim_name'], + ["original", "expected_char_dim_name"], [ - (Variable(('x',), [b'ab', b'cdef']), - 'string4'), - (Variable(('x',), [b'ab', b'cdef'], encoding={'char_dim_name': 'foo'}), - 'foo') - ] + (Variable(("x",), [b"ab", b"cdef"]), "string4"), + (Variable(("x",), [b"ab", b"cdef"], encoding={"char_dim_name": "foo"}), "foo"), + ], ) def test_CharacterArrayCoder_char_dim_name(original, expected_char_dim_name): coder = strings.CharacterArrayCoder() encoded = coder.encode(original) roundtripped = coder.decode(encoded) assert encoded.dims[-1] == expected_char_dim_name - assert roundtripped.encoding['char_dim_name'] == expected_char_dim_name + assert roundtripped.encoding["char_dim_name"] == expected_char_dim_name assert roundtripped.dims[-1] == original.dims[-1] def test_StackedBytesArray(): - array = np.array([[b'a', b'b', b'c'], [b'd', b'e', b'f']], dtype='S') + array = np.array([[b"a", b"b", b"c"], [b"d", b"e", b"f"]], dtype="S") actual = strings.StackedBytesArray(array) - expected = np.array([b'abc', b'def'], dtype='S') + expected = np.array([b"abc", b"def"], dtype="S") assert actual.dtype == expected.dtype assert actual.shape == expected.shape assert actual.size == expected.size @@ -143,10 +148,10 @@ def test_StackedBytesArray(): def test_StackedBytesArray_scalar(): - array = np.array([b'a', b'b', b'c'], dtype='S') + array = np.array([b"a", b"b", b"c"], dtype="S") actual = strings.StackedBytesArray(array) - expected = np.array(b'abc') + expected = np.array(b"abc") assert actual.dtype == expected.dtype assert actual.shape == expected.shape assert actual.size == expected.size @@ -161,9 +166,9 @@ def test_StackedBytesArray_scalar(): def test_StackedBytesArray_vectorized_indexing(): - array = np.array([[b'a', b'b', b'c'], [b'd', b'e', b'f']], dtype='S') + array = np.array([[b"a", b"b", b"c"], [b"d", b"e", b"f"]], dtype="S") stacked = strings.StackedBytesArray(array) - expected = np.array([[b'abc', b'def'], [b'def', b'abc']]) + expected = np.array([[b"abc", b"def"], [b"def", b"abc"]]) V = IndexerMaker(indexing.VectorizedIndexer) indexer = V[np.array([[0, 1], [1, 0]])] @@ -172,64 +177,62 @@ def test_StackedBytesArray_vectorized_indexing(): def test_char_to_bytes(): - array = np.array([[b'a', b'b', b'c'], [b'd', b'e', b'f']]) - expected = np.array([b'abc', b'def']) + array = np.array([[b"a", b"b", b"c"], [b"d", b"e", b"f"]]) + expected = np.array([b"abc", b"def"]) actual = strings.char_to_bytes(array) assert_array_equal(actual, expected) - expected = np.array([b'ad', b'be', b'cf']) + expected = np.array([b"ad", b"be", b"cf"]) actual = strings.char_to_bytes(array.T) # non-contiguous assert_array_equal(actual, expected) def test_char_to_bytes_ndim_zero(): - expected = np.array(b'a') + expected = np.array(b"a") actual = strings.char_to_bytes(expected) 
assert_array_equal(actual, expected) def test_char_to_bytes_size_zero(): - array = np.zeros((3, 0), dtype='S1') - expected = np.array([b'', b'', b'']) + array = np.zeros((3, 0), dtype="S1") + expected = np.array([b"", b"", b""]) actual = strings.char_to_bytes(array) assert_array_equal(actual, expected) @requires_dask def test_char_to_bytes_dask(): - numpy_array = np.array([[b'a', b'b', b'c'], [b'd', b'e', b'f']]) + numpy_array = np.array([[b"a", b"b", b"c"], [b"d", b"e", b"f"]]) array = da.from_array(numpy_array, ((2,), (3,))) - expected = np.array([b'abc', b'def']) + expected = np.array([b"abc", b"def"]) actual = strings.char_to_bytes(array) assert isinstance(actual, da.Array) assert actual.chunks == ((2,),) - assert actual.dtype == 'S3' + assert actual.dtype == "S3" assert_array_equal(np.array(actual), expected) - with raises_regex(ValueError, 'stacked dask character array'): + with raises_regex(ValueError, "stacked dask character array"): strings.char_to_bytes(array.rechunk(1)) def test_bytes_to_char(): - array = np.array([[b'ab', b'cd'], [b'ef', b'gh']]) - expected = np.array([[[b'a', b'b'], [b'c', b'd']], - [[b'e', b'f'], [b'g', b'h']]]) + array = np.array([[b"ab", b"cd"], [b"ef", b"gh"]]) + expected = np.array([[[b"a", b"b"], [b"c", b"d"]], [[b"e", b"f"], [b"g", b"h"]]]) actual = strings.bytes_to_char(array) assert_array_equal(actual, expected) - expected = np.array([[[b'a', b'b'], [b'e', b'f']], - [[b'c', b'd'], [b'g', b'h']]]) + expected = np.array([[[b"a", b"b"], [b"e", b"f"]], [[b"c", b"d"], [b"g", b"h"]]]) actual = strings.bytes_to_char(array.T) # non-contiguous assert_array_equal(actual, expected) @requires_dask def test_bytes_to_char_dask(): - numpy_array = np.array([b'ab', b'cd']) + numpy_array = np.array([b"ab", b"cd"]) array = da.from_array(numpy_array, ((1, 1),)) - expected = np.array([[b'a', b'b'], [b'c', b'd']]) + expected = np.array([[b"a", b"b"], [b"c", b"d"]]) actual = strings.bytes_to_char(array) assert isinstance(actual, da.Array) assert actual.chunks == ((1, 1), ((2,))) - assert actual.dtype == 'S1' + assert actual.dtype == "S1" assert_array_equal(np.array(actual), expected) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 82afeab7aba..ab5ed20d531 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -7,16 +7,26 @@ from xarray import DataArray, Dataset, Variable, coding, decode_cf from xarray.coding.times import ( - _import_cftime, cftime_to_nptime, decode_cf_datetime, encode_cf_datetime, - to_timedelta_unboxed) + _import_cftime, + cftime_to_nptime, + decode_cf_datetime, + encode_cf_datetime, + to_timedelta_unboxed, +) from xarray.coding.variables import SerializationWarning from xarray.conventions import _update_bounds_attributes, cf_encoder from xarray.core.common import contains_cftime_datetimes from xarray.testing import assert_equal from . 
import ( - assert_array_equal, has_cftime, has_cftime_or_netCDF4, has_dask, - requires_cftime, requires_cftime_or_netCDF4, arm_xfail) + assert_array_equal, + has_cftime, + has_cftime_or_netCDF4, + has_dask, + requires_cftime, + requires_cftime_or_netCDF4, + arm_xfail, +) try: from pandas.errors import OutOfBoundsDatetime @@ -24,43 +34,53 @@ # pandas < 0.20 from pandas.tslib import OutOfBoundsDatetime -_NON_STANDARD_CALENDARS_SET = {'noleap', '365_day', '360_day', - 'julian', 'all_leap', '366_day'} -_ALL_CALENDARS = sorted(_NON_STANDARD_CALENDARS_SET.union( - coding.times._STANDARD_CALENDARS)) +_NON_STANDARD_CALENDARS_SET = { + "noleap", + "365_day", + "360_day", + "julian", + "all_leap", + "366_day", +} +_ALL_CALENDARS = sorted( + _NON_STANDARD_CALENDARS_SET.union(coding.times._STANDARD_CALENDARS) +) _NON_STANDARD_CALENDARS = sorted(_NON_STANDARD_CALENDARS_SET) _STANDARD_CALENDARS = sorted(coding.times._STANDARD_CALENDARS) _CF_DATETIME_NUM_DATES_UNITS = [ - (np.arange(10), 'days since 2000-01-01'), - (np.arange(10).astype('float64'), 'days since 2000-01-01'), - (np.arange(10).astype('float32'), 'days since 2000-01-01'), - (np.arange(10).reshape(2, 5), 'days since 2000-01-01'), - (12300 + np.arange(5), 'hours since 1680-01-01 00:00:00'), + (np.arange(10), "days since 2000-01-01"), + (np.arange(10).astype("float64"), "days since 2000-01-01"), + (np.arange(10).astype("float32"), "days since 2000-01-01"), + (np.arange(10).reshape(2, 5), "days since 2000-01-01"), + (12300 + np.arange(5), "hours since 1680-01-01 00:00:00"), # here we add a couple minor formatting errors to test # the robustness of the parsing algorithm. - (12300 + np.arange(5), 'hour since 1680-01-01 00:00:00'), - (12300 + np.arange(5), 'Hour since 1680-01-01 00:00:00'), - (12300 + np.arange(5), ' Hour since 1680-01-01 00:00:00 '), - (10, 'days since 2000-01-01'), - ([10], 'daYs since 2000-01-01'), - ([[10]], 'days since 2000-01-01'), - ([10, 10], 'days since 2000-01-01'), - (np.array(10), 'days since 2000-01-01'), - (0, 'days since 1000-01-01'), - ([0], 'days since 1000-01-01'), - ([[0]], 'days since 1000-01-01'), - (np.arange(2), 'days since 1000-01-01'), - (np.arange(0, 100000, 20000), 'days since 1900-01-01'), - (17093352.0, 'hours since 1-1-1 00:00:0.0'), - ([0.5, 1.5], 'hours since 1900-01-01T00:00:00'), - (0, 'milliseconds since 2000-01-01T00:00:00'), - (0, 'microseconds since 2000-01-01T00:00:00'), - (np.int32(788961600), 'seconds since 1981-01-01'), # GH2002 - (12300 + np.arange(5), 'hour since 1680-01-01 00:00:00.500000') + (12300 + np.arange(5), "hour since 1680-01-01 00:00:00"), + (12300 + np.arange(5), "Hour since 1680-01-01 00:00:00"), + (12300 + np.arange(5), " Hour since 1680-01-01 00:00:00 "), + (10, "days since 2000-01-01"), + ([10], "daYs since 2000-01-01"), + ([[10]], "days since 2000-01-01"), + ([10, 10], "days since 2000-01-01"), + (np.array(10), "days since 2000-01-01"), + (0, "days since 1000-01-01"), + ([0], "days since 1000-01-01"), + ([[0]], "days since 1000-01-01"), + (np.arange(2), "days since 1000-01-01"), + (np.arange(0, 100000, 20000), "days since 1900-01-01"), + (17093352.0, "hours since 1-1-1 00:00:0.0"), + ([0.5, 1.5], "hours since 1900-01-01T00:00:00"), + (0, "milliseconds since 2000-01-01T00:00:00"), + (0, "microseconds since 2000-01-01T00:00:00"), + (np.int32(788961600), "seconds since 1981-01-01"), # GH2002 + (12300 + np.arange(5), "hour since 1680-01-01 00:00:00.500000"), +] +_CF_DATETIME_TESTS = [ + num_dates_units + (calendar,) + for num_dates_units, calendar in product( + 
_CF_DATETIME_NUM_DATES_UNITS, _STANDARD_CALENDARS + ) ] -_CF_DATETIME_TESTS = [num_dates_units + (calendar,) for num_dates_units, - calendar in product(_CF_DATETIME_NUM_DATES_UNITS, - _STANDARD_CALENDARS)] def _all_cftime_date_types(): @@ -68,24 +88,26 @@ def _all_cftime_date_types(): import cftime except ImportError: import netcdftime as cftime - return {'noleap': cftime.DatetimeNoLeap, - '365_day': cftime.DatetimeNoLeap, - '360_day': cftime.Datetime360Day, - 'julian': cftime.DatetimeJulian, - 'all_leap': cftime.DatetimeAllLeap, - '366_day': cftime.DatetimeAllLeap, - 'gregorian': cftime.DatetimeGregorian, - 'proleptic_gregorian': cftime.DatetimeProlepticGregorian} - - -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -@pytest.mark.parametrize(['num_dates', 'units', 'calendar'], - _CF_DATETIME_TESTS) + return { + "noleap": cftime.DatetimeNoLeap, + "365_day": cftime.DatetimeNoLeap, + "360_day": cftime.Datetime360Day, + "julian": cftime.DatetimeJulian, + "all_leap": cftime.DatetimeAllLeap, + "366_day": cftime.DatetimeAllLeap, + "gregorian": cftime.DatetimeGregorian, + "proleptic_gregorian": cftime.DatetimeProlepticGregorian, + } + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime not installed") +@pytest.mark.parametrize(["num_dates", "units", "calendar"], _CF_DATETIME_TESTS) def test_cf_datetime(num_dates, units, calendar): cftime = _import_cftime() - if cftime.__name__ == 'cftime': - expected = cftime.num2date(num_dates, units, calendar, - only_use_cftime_datetimes=True) + if cftime.__name__ == "cftime": + expected = cftime.num2date( + num_dates, units, calendar, only_use_cftime_datetimes=True + ) else: expected = cftime.num2date(num_dates, units, calendar) min_y = np.ravel(np.atleast_1d(expected))[np.nanargmin(num_dates)].year @@ -94,31 +116,28 @@ def test_cf_datetime(num_dates, units, calendar): expected = cftime_to_nptime(expected) with warnings.catch_warnings(): - warnings.filterwarnings('ignore', - 'Unable to decode time axis') - actual = coding.times.decode_cf_datetime(num_dates, units, - calendar) + warnings.filterwarnings("ignore", "Unable to decode time axis") + actual = coding.times.decode_cf_datetime(num_dates, units, calendar) abs_diff = np.atleast_1d(abs(actual - expected)).astype(np.timedelta64) # once we no longer support versions of netCDF4 older than 1.1.5, # we could do this check with near microsecond accuracy: # https://github.com/Unidata/netcdf4-python/issues/355 - assert (abs_diff <= np.timedelta64(1, 's')).all() - encoded, _, _ = coding.times.encode_cf_datetime(actual, units, - calendar) - if '1-1-1' not in units: + assert (abs_diff <= np.timedelta64(1, "s")).all() + encoded, _, _ = coding.times.encode_cf_datetime(actual, units, calendar) + if "1-1-1" not in units: # pandas parses this date very strangely, so the original # units/encoding cannot be preserved in this case: # (Pdb) pd.to_datetime('1-1-1 00:00:0.0') # Timestamp('2001-01-01 00:00:00') assert_array_equal(num_dates, np.around(encoded, 1)) - if (hasattr(num_dates, 'ndim') and num_dates.ndim == 1 and - '1000' not in units): + if hasattr(num_dates, "ndim") and num_dates.ndim == 1 and "1000" not in units: # verify that wrapping with a pandas.Index works # note that it *does not* currently work to even put # non-datetime64 compatible dates into a pandas.Index encoded, _, _ = coding.times.encode_cf_datetime( - pd.Index(actual), units, calendar) + pd.Index(actual), units, calendar + ) assert_array_equal(num_dates, np.around(encoded, 1)) @@ -133,7 +152,7 @@ def 
test_decode_cf_datetime_overflow(): from netcdftime import DatetimeGregorian datetime = DatetimeGregorian - units = 'days since 2000-01-01 00:00:00' + units = "days since 2000-01-01 00:00:00" # date after 2262 and before 1678 days = (-117608, 95795) @@ -141,16 +160,16 @@ def test_decode_cf_datetime_overflow(): for i, day in enumerate(days): with warnings.catch_warnings(): - warnings.filterwarnings('ignore', 'Unable to decode time axis') + warnings.filterwarnings("ignore", "Unable to decode time axis") result = coding.times.decode_cf_datetime(day, units) assert result == expected[i] def test_decode_cf_datetime_non_standard_units(): - expected = pd.date_range(periods=100, start='1970-01-01', freq='h') + expected = pd.date_range(periods=100, start="1970-01-01", freq="h") # netCDFs from madis.noaa.gov use this format for their time units # they cannot be parsed by cftime, but pd.Timestamp works - units = 'hours since 1-1-1970' + units = "hours since 1-1-1970" actual = coding.times.decode_cf_datetime(np.arange(100), units) assert_array_equal(actual, expected) @@ -159,29 +178,31 @@ def test_decode_cf_datetime_non_standard_units(): def test_decode_cf_datetime_non_iso_strings(): # datetime strings that are _almost_ ISO compliant but not quite, # but which cftime.num2date can still parse correctly - expected = pd.date_range(periods=100, start='2000-01-01', freq='h') - cases = [(np.arange(100), 'hours since 2000-01-01 0'), - (np.arange(100), 'hours since 2000-1-1 0'), - (np.arange(100), 'hours since 2000-01-01 0:00')] + expected = pd.date_range(periods=100, start="2000-01-01", freq="h") + cases = [ + (np.arange(100), "hours since 2000-01-01 0"), + (np.arange(100), "hours since 2000-1-1 0"), + (np.arange(100), "hours since 2000-01-01 0:00"), + ] for num_dates, units in cases: actual = coding.times.decode_cf_datetime(num_dates, units) abs_diff = abs(actual - expected.values) # once we no longer support versions of netCDF4 older than 1.1.5, # we could do this check with near microsecond accuracy: # https://github.com/Unidata/netcdf4-python/issues/355 - assert (abs_diff <= np.timedelta64(1, 's')).all() + assert (abs_diff <= np.timedelta64(1, "s")).all() -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -@pytest.mark.parametrize('calendar', _STANDARD_CALENDARS) +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime not installed") +@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) def test_decode_standard_calendar_inside_timestamp_range(calendar): cftime = _import_cftime() - units = 'days since 0001-01-01' - times = pd.date_range('2001-04-01-00', end='2001-04-30-23', freq='H') + units = "days since 0001-01-01" + times = pd.date_range("2001-04-01-00", end="2001-04-30-23", freq="H") time = cftime.date2num(times.to_pydatetime(), units, calendar=calendar) expected = times.values - expected_dtype = np.dtype('M8[ns]') + expected_dtype = np.dtype("M8[ns]") actual = coding.times.decode_cf_datetime(time, units, calendar=calendar) assert actual.dtype == expected_dtype @@ -189,184 +210,170 @@ def test_decode_standard_calendar_inside_timestamp_range(calendar): # once we no longer support versions of netCDF4 older than 1.1.5, # we could do this check with near microsecond accuracy: # https://github.com/Unidata/netcdf4-python/issues/355 - assert (abs_diff <= np.timedelta64(1, 's')).all() + assert (abs_diff <= np.timedelta64(1, "s")).all() -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -@pytest.mark.parametrize('calendar', 
_NON_STANDARD_CALENDARS) -def test_decode_non_standard_calendar_inside_timestamp_range( - calendar): +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime not installed") +@pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS) +def test_decode_non_standard_calendar_inside_timestamp_range(calendar): cftime = _import_cftime() - units = 'days since 0001-01-01' - times = pd.date_range('2001-04-01-00', end='2001-04-30-23', - freq='H') - non_standard_time = cftime.date2num( - times.to_pydatetime(), units, calendar=calendar) + units = "days since 0001-01-01" + times = pd.date_range("2001-04-01-00", end="2001-04-30-23", freq="H") + non_standard_time = cftime.date2num(times.to_pydatetime(), units, calendar=calendar) - if cftime.__name__ == 'cftime': + if cftime.__name__ == "cftime": expected = cftime.num2date( - non_standard_time, units, calendar=calendar, - only_use_cftime_datetimes=True) + non_standard_time, units, calendar=calendar, only_use_cftime_datetimes=True + ) else: - expected = cftime.num2date(non_standard_time, units, - calendar=calendar) + expected = cftime.num2date(non_standard_time, units, calendar=calendar) - expected_dtype = np.dtype('O') + expected_dtype = np.dtype("O") actual = coding.times.decode_cf_datetime( - non_standard_time, units, calendar=calendar) + non_standard_time, units, calendar=calendar + ) assert actual.dtype == expected_dtype abs_diff = abs(actual - expected) # once we no longer support versions of netCDF4 older than 1.1.5, # we could do this check with near microsecond accuracy: # https://github.com/Unidata/netcdf4-python/issues/355 - assert (abs_diff <= np.timedelta64(1, 's')).all() + assert (abs_diff <= np.timedelta64(1, "s")).all() -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -@pytest.mark.parametrize('calendar', _ALL_CALENDARS) +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime not installed") +@pytest.mark.parametrize("calendar", _ALL_CALENDARS) def test_decode_dates_outside_timestamp_range(calendar): from datetime import datetime + cftime = _import_cftime() - units = 'days since 0001-01-01' + units = "days since 0001-01-01" times = [datetime(1, 4, 1, h) for h in range(1, 5)] time = cftime.date2num(times, units, calendar=calendar) - if cftime.__name__ == 'cftime': - expected = cftime.num2date(time, units, calendar=calendar, - only_use_cftime_datetimes=True) + if cftime.__name__ == "cftime": + expected = cftime.num2date( + time, units, calendar=calendar, only_use_cftime_datetimes=True + ) else: expected = cftime.num2date(time, units, calendar=calendar) expected_date_type = type(expected[0]) with warnings.catch_warnings(): - warnings.filterwarnings('ignore', 'Unable to decode time axis') - actual = coding.times.decode_cf_datetime( - time, units, calendar=calendar) + warnings.filterwarnings("ignore", "Unable to decode time axis") + actual = coding.times.decode_cf_datetime(time, units, calendar=calendar) assert all(isinstance(value, expected_date_type) for value in actual) abs_diff = abs(actual - expected) # once we no longer support versions of netCDF4 older than 1.1.5, # we could do this check with near microsecond accuracy: # https://github.com/Unidata/netcdf4-python/issues/355 - assert (abs_diff <= np.timedelta64(1, 's')).all() + assert (abs_diff <= np.timedelta64(1, "s")).all() -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -@pytest.mark.parametrize('calendar', _STANDARD_CALENDARS) -def 
test_decode_standard_calendar_single_element_inside_timestamp_range( - calendar): - units = 'days since 0001-01-01' +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime not installed") +@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) +def test_decode_standard_calendar_single_element_inside_timestamp_range(calendar): + units = "days since 0001-01-01" for num_time in [735368, [735368], [[735368]]]: with warnings.catch_warnings(): - warnings.filterwarnings('ignore', - 'Unable to decode time axis') - actual = coding.times.decode_cf_datetime( - num_time, units, calendar=calendar) - assert actual.dtype == np.dtype('M8[ns]') + warnings.filterwarnings("ignore", "Unable to decode time axis") + actual = coding.times.decode_cf_datetime(num_time, units, calendar=calendar) + assert actual.dtype == np.dtype("M8[ns]") -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -@pytest.mark.parametrize('calendar', _NON_STANDARD_CALENDARS) -def test_decode_non_standard_calendar_single_element_inside_timestamp_range( - calendar): - units = 'days since 0001-01-01' +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime not installed") +@pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS) +def test_decode_non_standard_calendar_single_element_inside_timestamp_range(calendar): + units = "days since 0001-01-01" for num_time in [735368, [735368], [[735368]]]: with warnings.catch_warnings(): - warnings.filterwarnings('ignore', - 'Unable to decode time axis') - actual = coding.times.decode_cf_datetime( - num_time, units, calendar=calendar) - assert actual.dtype == np.dtype('O') + warnings.filterwarnings("ignore", "Unable to decode time axis") + actual = coding.times.decode_cf_datetime(num_time, units, calendar=calendar) + assert actual.dtype == np.dtype("O") -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -@pytest.mark.parametrize('calendar', _NON_STANDARD_CALENDARS) -def test_decode_single_element_outside_timestamp_range( - calendar): +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime not installed") +@pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS) +def test_decode_single_element_outside_timestamp_range(calendar): cftime = _import_cftime() - units = 'days since 0001-01-01' + units = "days since 0001-01-01" for days in [1, 1470376]: for num_time in [days, [days], [[days]]]: with warnings.catch_warnings(): - warnings.filterwarnings('ignore', - 'Unable to decode time axis') + warnings.filterwarnings("ignore", "Unable to decode time axis") actual = coding.times.decode_cf_datetime( - num_time, units, calendar=calendar) + num_time, units, calendar=calendar + ) - if cftime.__name__ == 'cftime': - expected = cftime.num2date(days, units, calendar, - only_use_cftime_datetimes=True) + if cftime.__name__ == "cftime": + expected = cftime.num2date( + days, units, calendar, only_use_cftime_datetimes=True + ) else: expected = cftime.num2date(days, units, calendar) assert isinstance(actual.item(), type(expected)) -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -@pytest.mark.parametrize('calendar', _STANDARD_CALENDARS) -def test_decode_standard_calendar_multidim_time_inside_timestamp_range( - calendar): +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime not installed") +@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) +def test_decode_standard_calendar_multidim_time_inside_timestamp_range(calendar): cftime = _import_cftime() - units = 'days since 0001-01-01' - times1 = 
pd.date_range('2001-04-01', end='2001-04-05', freq='D') - times2 = pd.date_range('2001-05-01', end='2001-05-05', freq='D') - time1 = cftime.date2num(times1.to_pydatetime(), - units, calendar=calendar) - time2 = cftime.date2num(times2.to_pydatetime(), - units, calendar=calendar) - mdim_time = np.empty((len(time1), 2), ) + units = "days since 0001-01-01" + times1 = pd.date_range("2001-04-01", end="2001-04-05", freq="D") + times2 = pd.date_range("2001-05-01", end="2001-05-05", freq="D") + time1 = cftime.date2num(times1.to_pydatetime(), units, calendar=calendar) + time2 = cftime.date2num(times2.to_pydatetime(), units, calendar=calendar) + mdim_time = np.empty((len(time1), 2)) mdim_time[:, 0] = time1 mdim_time[:, 1] = time2 expected1 = times1.values expected2 = times2.values - actual = coding.times.decode_cf_datetime( - mdim_time, units, calendar=calendar) - assert actual.dtype == np.dtype('M8[ns]') + actual = coding.times.decode_cf_datetime(mdim_time, units, calendar=calendar) + assert actual.dtype == np.dtype("M8[ns]") abs_diff1 = abs(actual[:, 0] - expected1) abs_diff2 = abs(actual[:, 1] - expected2) # once we no longer support versions of netCDF4 older than 1.1.5, # we could do this check with near microsecond accuracy: # https://github.com/Unidata/netcdf4-python/issues/355 - assert (abs_diff1 <= np.timedelta64(1, 's')).all() - assert (abs_diff2 <= np.timedelta64(1, 's')).all() + assert (abs_diff1 <= np.timedelta64(1, "s")).all() + assert (abs_diff2 <= np.timedelta64(1, "s")).all() -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -@pytest.mark.parametrize('calendar', _NON_STANDARD_CALENDARS) -def test_decode_nonstandard_calendar_multidim_time_inside_timestamp_range( - calendar): +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime not installed") +@pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS) +def test_decode_nonstandard_calendar_multidim_time_inside_timestamp_range(calendar): cftime = _import_cftime() - units = 'days since 0001-01-01' - times1 = pd.date_range('2001-04-01', end='2001-04-05', freq='D') - times2 = pd.date_range('2001-05-01', end='2001-05-05', freq='D') - time1 = cftime.date2num(times1.to_pydatetime(), - units, calendar=calendar) - time2 = cftime.date2num(times2.to_pydatetime(), - units, calendar=calendar) - mdim_time = np.empty((len(time1), 2), ) + units = "days since 0001-01-01" + times1 = pd.date_range("2001-04-01", end="2001-04-05", freq="D") + times2 = pd.date_range("2001-05-01", end="2001-05-05", freq="D") + time1 = cftime.date2num(times1.to_pydatetime(), units, calendar=calendar) + time2 = cftime.date2num(times2.to_pydatetime(), units, calendar=calendar) + mdim_time = np.empty((len(time1), 2)) mdim_time[:, 0] = time1 mdim_time[:, 1] = time2 - if cftime.__name__ == 'cftime': - expected1 = cftime.num2date(time1, units, calendar, - only_use_cftime_datetimes=True) - expected2 = cftime.num2date(time2, units, calendar, - only_use_cftime_datetimes=True) + if cftime.__name__ == "cftime": + expected1 = cftime.num2date( + time1, units, calendar, only_use_cftime_datetimes=True + ) + expected2 = cftime.num2date( + time2, units, calendar, only_use_cftime_datetimes=True + ) else: expected1 = cftime.num2date(time1, units, calendar) expected2 = cftime.num2date(time2, units, calendar) - expected_dtype = np.dtype('O') + expected_dtype = np.dtype("O") - actual = coding.times.decode_cf_datetime( - mdim_time, units, calendar=calendar) + actual = coding.times.decode_cf_datetime(mdim_time, units, calendar=calendar) assert 
actual.dtype == expected_dtype abs_diff1 = abs(actual[:, 0] - expected1) @@ -374,57 +381,57 @@ def test_decode_nonstandard_calendar_multidim_time_inside_timestamp_range( # once we no longer support versions of netCDF4 older than 1.1.5, # we could do this check with near microsecond accuracy: # https://github.com/Unidata/netcdf4-python/issues/355 - assert (abs_diff1 <= np.timedelta64(1, 's')).all() - assert (abs_diff2 <= np.timedelta64(1, 's')).all() + assert (abs_diff1 <= np.timedelta64(1, "s")).all() + assert (abs_diff2 <= np.timedelta64(1, "s")).all() -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -@pytest.mark.parametrize('calendar', _ALL_CALENDARS) -def test_decode_multidim_time_outside_timestamp_range( - calendar): +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime not installed") +@pytest.mark.parametrize("calendar", _ALL_CALENDARS) +def test_decode_multidim_time_outside_timestamp_range(calendar): from datetime import datetime + cftime = _import_cftime() - units = 'days since 0001-01-01' + units = "days since 0001-01-01" times1 = [datetime(1, 4, day) for day in range(1, 6)] times2 = [datetime(1, 5, day) for day in range(1, 6)] time1 = cftime.date2num(times1, units, calendar=calendar) time2 = cftime.date2num(times2, units, calendar=calendar) - mdim_time = np.empty((len(time1), 2), ) + mdim_time = np.empty((len(time1), 2)) mdim_time[:, 0] = time1 mdim_time[:, 1] = time2 - if cftime.__name__ == 'cftime': - expected1 = cftime.num2date(time1, units, calendar, - only_use_cftime_datetimes=True) - expected2 = cftime.num2date(time2, units, calendar, - only_use_cftime_datetimes=True) + if cftime.__name__ == "cftime": + expected1 = cftime.num2date( + time1, units, calendar, only_use_cftime_datetimes=True + ) + expected2 = cftime.num2date( + time2, units, calendar, only_use_cftime_datetimes=True + ) else: expected1 = cftime.num2date(time1, units, calendar) expected2 = cftime.num2date(time2, units, calendar) with warnings.catch_warnings(): - warnings.filterwarnings('ignore', 'Unable to decode time axis') - actual = coding.times.decode_cf_datetime( - mdim_time, units, calendar=calendar) + warnings.filterwarnings("ignore", "Unable to decode time axis") + actual = coding.times.decode_cf_datetime(mdim_time, units, calendar=calendar) - assert actual.dtype == np.dtype('O') + assert actual.dtype == np.dtype("O") abs_diff1 = abs(actual[:, 0] - expected1) abs_diff2 = abs(actual[:, 1] - expected2) # once we no longer support versions of netCDF4 older than 1.1.5, # we could do this check with near microsecond accuracy: # https://github.com/Unidata/netcdf4-python/issues/355 - assert (abs_diff1 <= np.timedelta64(1, 's')).all() - assert (abs_diff2 <= np.timedelta64(1, 's')).all() + assert (abs_diff1 <= np.timedelta64(1, "s")).all() + assert (abs_diff2 <= np.timedelta64(1, "s")).all() -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -@pytest.mark.parametrize('calendar', ['360_day', 'all_leap', '366_day']) -def test_decode_non_standard_calendar_single_element( - calendar): +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime not installed") +@pytest.mark.parametrize("calendar", ["360_day", "all_leap", "366_day"]) +def test_decode_non_standard_calendar_single_element(calendar): cftime = _import_cftime() - units = 'days since 0001-01-01' + units = "days since 0001-01-01" try: dt = cftime.netcdftime.datetime(2001, 2, 29) @@ -433,55 +440,62 @@ def test_decode_non_standard_calendar_single_element( dt = 
cftime.datetime(2001, 2, 29) num_time = cftime.date2num(dt, units, calendar) - actual = coding.times.decode_cf_datetime( - num_time, units, calendar=calendar) + actual = coding.times.decode_cf_datetime(num_time, units, calendar=calendar) - if cftime.__name__ == 'cftime': - expected = np.asarray(cftime.num2date( - num_time, units, calendar, only_use_cftime_datetimes=True)) + if cftime.__name__ == "cftime": + expected = np.asarray( + cftime.num2date(num_time, units, calendar, only_use_cftime_datetimes=True) + ) else: expected = np.asarray(cftime.num2date(num_time, units, calendar)) - assert actual.dtype == np.dtype('O') + assert actual.dtype == np.dtype("O") assert expected == actual -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime not installed") def test_decode_360_day_calendar(): cftime = _import_cftime() - calendar = '360_day' + calendar = "360_day" # ensure leap year doesn't matter for year in [2010, 2011, 2012, 2013, 2014]: - units = 'days since {}-01-01'.format(year) + units = "days since {}-01-01".format(year) num_times = np.arange(100) - if cftime.__name__ == 'cftime': - expected = cftime.num2date(num_times, units, calendar, - only_use_cftime_datetimes=True) + if cftime.__name__ == "cftime": + expected = cftime.num2date( + num_times, units, calendar, only_use_cftime_datetimes=True + ) else: expected = cftime.num2date(num_times, units, calendar) with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + warnings.simplefilter("always") actual = coding.times.decode_cf_datetime( - num_times, units, calendar=calendar) + num_times, units, calendar=calendar + ) assert len(w) == 0 - assert actual.dtype == np.dtype('O') + assert actual.dtype == np.dtype("O") assert_array_equal(actual, expected) @arm_xfail -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime not installed") @pytest.mark.parametrize( - ['num_dates', 'units', 'expected_list'], - [([np.nan], 'days since 2000-01-01', ['NaT']), - ([np.nan, 0], 'days since 2000-01-01', - ['NaT', '2000-01-01T00:00:00Z']), - ([np.nan, 0, 1], 'days since 2000-01-01', - ['NaT', '2000-01-01T00:00:00Z', '2000-01-02T00:00:00Z'])]) + ["num_dates", "units", "expected_list"], + [ + ([np.nan], "days since 2000-01-01", ["NaT"]), + ([np.nan, 0], "days since 2000-01-01", ["NaT", "2000-01-01T00:00:00Z"]), + ( + [np.nan, 0, 1], + "days since 2000-01-01", + ["NaT", "2000-01-01T00:00:00Z", "2000-01-02T00:00:00Z"], + ), + ], +) def test_cf_datetime_nan(num_dates, units, expected_list): with warnings.catch_warnings(): - warnings.filterwarnings('ignore', 'All-NaN') + warnings.filterwarnings("ignore", "All-NaN") actual = coding.times.decode_cf_datetime(num_dates, units) # use pandas because numpy will deprecate timezone-aware conversions expected = pd.to_datetime(expected_list) @@ -491,49 +505,57 @@ def test_cf_datetime_nan(num_dates, units, expected_list): @requires_cftime_or_netCDF4 def test_decoded_cf_datetime_array_2d(): # regression test for GH1229 - variable = Variable(('x', 'y'), np.array([[0, 1], [2, 3]]), - {'units': 'days since 2000-01-01'}) + variable = Variable( + ("x", "y"), np.array([[0, 1], [2, 3]]), {"units": "days since 2000-01-01"} + ) result = coding.times.CFDatetimeCoder().decode(variable) - assert result.dtype == 'datetime64[ns]' - expected = pd.date_range('2000-01-01', periods=4).values.reshape(2, 2) + assert result.dtype == "datetime64[ns]" + 
expected = pd.date_range("2000-01-01", periods=4).values.reshape(2, 2) assert_array_equal(np.asarray(result), expected) @pytest.mark.parametrize( - ['dates', 'expected'], - [(pd.date_range('1900-01-01', periods=5), - 'days since 1900-01-01 00:00:00'), - (pd.date_range('1900-01-01 12:00:00', freq='H', - periods=2), - 'hours since 1900-01-01 12:00:00'), - (pd.to_datetime( - ['1900-01-01', '1900-01-02', 'NaT']), - 'days since 1900-01-01 00:00:00'), - (pd.to_datetime(['1900-01-01', - '1900-01-02T00:00:00.005']), - 'seconds since 1900-01-01 00:00:00'), - (pd.to_datetime(['NaT', '1900-01-01']), - 'days since 1900-01-01 00:00:00'), - (pd.to_datetime(['NaT']), - 'days since 1970-01-01 00:00:00')]) + ["dates", "expected"], + [ + (pd.date_range("1900-01-01", periods=5), "days since 1900-01-01 00:00:00"), + ( + pd.date_range("1900-01-01 12:00:00", freq="H", periods=2), + "hours since 1900-01-01 12:00:00", + ), + ( + pd.to_datetime(["1900-01-01", "1900-01-02", "NaT"]), + "days since 1900-01-01 00:00:00", + ), + ( + pd.to_datetime(["1900-01-01", "1900-01-02T00:00:00.005"]), + "seconds since 1900-01-01 00:00:00", + ), + (pd.to_datetime(["NaT", "1900-01-01"]), "days since 1900-01-01 00:00:00"), + (pd.to_datetime(["NaT"]), "days since 1970-01-01 00:00:00"), + ], +) def test_infer_datetime_units(dates, expected): assert expected == coding.times.infer_datetime_units(dates) _CFTIME_DATETIME_UNITS_TESTS = [ - ([(1900, 1, 1), (1900, 1, 1)], 'days since 1900-01-01 00:00:00.000000'), - ([(1900, 1, 1), (1900, 1, 2), (1900, 1, 2, 0, 0, 1)], - 'seconds since 1900-01-01 00:00:00.000000'), - ([(1900, 1, 1), (1900, 1, 8), (1900, 1, 16)], - 'days since 1900-01-01 00:00:00.000000') + ([(1900, 1, 1), (1900, 1, 1)], "days since 1900-01-01 00:00:00.000000"), + ( + [(1900, 1, 1), (1900, 1, 2), (1900, 1, 2, 0, 0, 1)], + "seconds since 1900-01-01 00:00:00.000000", + ), + ( + [(1900, 1, 1), (1900, 1, 8), (1900, 1, 16)], + "days since 1900-01-01 00:00:00.000000", + ), ] -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime not installed") @pytest.mark.parametrize( - 'calendar', _NON_STANDARD_CALENDARS + ['gregorian', 'proleptic_gregorian']) -@pytest.mark.parametrize(('date_args', 'expected'), - _CFTIME_DATETIME_UNITS_TESTS) + "calendar", _NON_STANDARD_CALENDARS + ["gregorian", "proleptic_gregorian"] +) +@pytest.mark.parametrize(("date_args", "expected"), _CFTIME_DATETIME_UNITS_TESTS) def test_infer_cftime_datetime_units(calendar, date_args, expected): date_type = _all_cftime_date_types()[calendar] dates = [date_type(*args) for args in date_args] @@ -541,21 +563,22 @@ def test_infer_cftime_datetime_units(calendar, date_args, expected): @pytest.mark.parametrize( - ['timedeltas', 'units', 'numbers'], + ["timedeltas", "units", "numbers"], [ - ('1D', 'days', np.int64(1)), - (['1D', '2D', '3D'], 'days', np.array([1, 2, 3], 'int64')), - ('1h', 'hours', np.int64(1)), - ('1ms', 'milliseconds', np.int64(1)), - ('1us', 'microseconds', np.int64(1)), - (['NaT', '0s', '1s'], None, [np.nan, 0, 1]), - (['30m', '60m'], 'hours', [0.5, 1.0]), - ('NaT', 'days', np.nan), - (['NaT', 'NaT'], 'days', [np.nan, np.nan]), - ]) + ("1D", "days", np.int64(1)), + (["1D", "2D", "3D"], "days", np.array([1, 2, 3], "int64")), + ("1h", "hours", np.int64(1)), + ("1ms", "milliseconds", np.int64(1)), + ("1us", "microseconds", np.int64(1)), + (["NaT", "0s", "1s"], None, [np.nan, 0, 1]), + (["30m", "60m"], "hours", [0.5, 1.0]), + ("NaT", "days", np.nan), + (["NaT", "NaT"], "days", 
[np.nan, np.nan]), + ], +) def test_cf_timedelta(timedeltas, units, numbers): - if timedeltas == 'NaT': - timedeltas = np.timedelta64('NaT', 'ns') + if timedeltas == "NaT": + timedeltas = np.timedelta64("NaT", "ns") else: timedeltas = to_timedelta_unboxed(timedeltas) numbers = np.array(numbers) @@ -571,14 +594,14 @@ def test_cf_timedelta(timedeltas, units, numbers): assert_array_equal(expected, actual) assert expected.dtype == actual.dtype - expected = np.timedelta64('NaT', 'ns') - actual = coding.times.decode_cf_timedelta(np.array(np.nan), 'days') + expected = np.timedelta64("NaT", "ns") + actual = coding.times.decode_cf_timedelta(np.array(np.nan), "days") assert_array_equal(expected, actual) def test_cf_timedelta_2d(): - timedeltas = ['1D', '2D', '3D'] - units = 'days' + timedeltas = ["1D", "2D", "3D"] + units = "days" numbers = np.atleast_2d([1, 2, 3]) timedeltas = np.atleast_2d(to_timedelta_unboxed(timedeltas)) @@ -590,25 +613,28 @@ def test_cf_timedelta_2d(): @pytest.mark.parametrize( - ['deltas', 'expected'], - [(pd.to_timedelta(['1 day', '2 days']), 'days'), - (pd.to_timedelta(['1h', '1 day 1 hour']), 'hours'), - (pd.to_timedelta(['1m', '2m', np.nan]), 'minutes'), - (pd.to_timedelta(['1m3s', '1m4s']), 'seconds')]) + ["deltas", "expected"], + [ + (pd.to_timedelta(["1 day", "2 days"]), "days"), + (pd.to_timedelta(["1h", "1 day 1 hour"]), "hours"), + (pd.to_timedelta(["1m", "2m", np.nan]), "minutes"), + (pd.to_timedelta(["1m3s", "1m4s"]), "seconds"), + ], +) def test_infer_timedelta_units(deltas, expected): assert expected == coding.times.infer_timedelta_units(deltas) -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -@pytest.mark.parametrize(['date_args', 'expected'], - [((1, 2, 3, 4, 5, 6), - '0001-02-03 04:05:06.000000'), - ((10, 2, 3, 4, 5, 6), - '0010-02-03 04:05:06.000000'), - ((100, 2, 3, 4, 5, 6), - '0100-02-03 04:05:06.000000'), - ((1000, 2, 3, 4, 5, 6), - '1000-02-03 04:05:06.000000')]) +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime not installed") +@pytest.mark.parametrize( + ["date_args", "expected"], + [ + ((1, 2, 3, 4, 5, 6), "0001-02-03 04:05:06.000000"), + ((10, 2, 3, 4, 5, 6), "0010-02-03 04:05:06.000000"), + ((100, 2, 3, 4, 5, 6), "0100-02-03 04:05:06.000000"), + ((1000, 2, 3, 4, 5, 6), "1000-02-03 04:05:06.000000"), + ], +) def test_format_cftime_datetime(date_args, expected): date_types = _all_cftime_date_types() for date_type in date_types.values(): @@ -616,15 +642,15 @@ def test_format_cftime_datetime(date_args, expected): assert result == expected -@pytest.mark.parametrize('calendar', _ALL_CALENDARS) +@pytest.mark.parametrize("calendar", _ALL_CALENDARS) def test_decode_cf(calendar): - days = [1., 2., 3.] 
- da = DataArray(days, coords=[days], dims=['time'], name='test') + days = [1.0, 2.0, 3.0] + da = DataArray(days, coords=[days], dims=["time"], name="test") ds = da.to_dataset() - for v in ['test', 'time']: - ds[v].attrs['units'] = 'days since 2001-01-01' - ds[v].attrs['calendar'] = calendar + for v in ["test", "time"]: + ds[v].attrs["units"] = "days since 2001-01-01" + ds[v].attrs["calendar"] = calendar if not has_cftime_or_netCDF4 and calendar not in _STANDARD_CALENDARS: with pytest.raises(ValueError): @@ -633,85 +659,88 @@ def test_decode_cf(calendar): ds = decode_cf(ds) if calendar not in _STANDARD_CALENDARS: - assert ds.test.dtype == np.dtype('O') + assert ds.test.dtype == np.dtype("O") else: - assert ds.test.dtype == np.dtype('M8[ns]') + assert ds.test.dtype == np.dtype("M8[ns]") def test_decode_cf_time_bounds(): - da = DataArray(np.arange(6, dtype='int64').reshape((3, 2)), - coords={'time': [1, 2, 3]}, - dims=('time', 'nbnd'), name='time_bnds') + da = DataArray( + np.arange(6, dtype="int64").reshape((3, 2)), + coords={"time": [1, 2, 3]}, + dims=("time", "nbnd"), + name="time_bnds", + ) - attrs = {'units': 'days since 2001-01', - 'calendar': 'standard', - 'bounds': 'time_bnds'} + attrs = { + "units": "days since 2001-01", + "calendar": "standard", + "bounds": "time_bnds", + } ds = da.to_dataset() - ds['time'].attrs.update(attrs) + ds["time"].attrs.update(attrs) _update_bounds_attributes(ds.variables) - assert ds.variables['time_bnds'].attrs == {'units': 'days since 2001-01', - 'calendar': 'standard'} + assert ds.variables["time_bnds"].attrs == { + "units": "days since 2001-01", + "calendar": "standard", + } dsc = decode_cf(ds) - assert dsc.time_bnds.dtype == np.dtype('M8[ns]') + assert dsc.time_bnds.dtype == np.dtype("M8[ns]") dsc = decode_cf(ds, decode_times=False) - assert dsc.time_bnds.dtype == np.dtype('int64') + assert dsc.time_bnds.dtype == np.dtype("int64") # Do not overwrite existing attrs ds = da.to_dataset() - ds['time'].attrs.update(attrs) - bnd_attr = {'units': 'hours since 2001-01', 'calendar': 'noleap'} - ds['time_bnds'].attrs.update(bnd_attr) + ds["time"].attrs.update(attrs) + bnd_attr = {"units": "hours since 2001-01", "calendar": "noleap"} + ds["time_bnds"].attrs.update(bnd_attr) _update_bounds_attributes(ds.variables) - assert ds.variables['time_bnds'].attrs == bnd_attr + assert ds.variables["time_bnds"].attrs == bnd_attr # If bounds variable not available do not complain ds = da.to_dataset() - ds['time'].attrs.update(attrs) - ds['time'].attrs['bounds'] = 'fake_var' + ds["time"].attrs.update(attrs) + ds["time"].attrs["bounds"] = "fake_var" _update_bounds_attributes(ds.variables) @requires_cftime_or_netCDF4 def test_encode_time_bounds(): - time = pd.date_range('2000-01-16', periods=1) - time_bounds = pd.date_range('2000-01-01', periods=2, freq='MS') + time = pd.date_range("2000-01-16", periods=1) + time_bounds = pd.date_range("2000-01-01", periods=2, freq="MS") ds = Dataset(dict(time=time, time_bounds=time_bounds)) - ds.time.attrs = {'bounds': 'time_bounds'} - ds.time.encoding = {'calendar': 'noleap', - 'units': 'days since 2000-01-01'} + ds.time.attrs = {"bounds": "time_bounds"} + ds.time.encoding = {"calendar": "noleap", "units": "days since 2000-01-01"} expected = dict() # expected['time'] = Variable(data=np.array([15]), dims=['time']) - expected['time_bounds'] = Variable(data=np.array([0, 31]), - dims=['time_bounds']) + expected["time_bounds"] = Variable(data=np.array([0, 31]), dims=["time_bounds"]) encoded, _ = cf_encoder(ds.variables, ds.attrs) - 
assert_equal(encoded['time_bounds'], expected['time_bounds']) - assert 'calendar' not in encoded['time_bounds'].attrs - assert 'units' not in encoded['time_bounds'].attrs + assert_equal(encoded["time_bounds"], expected["time_bounds"]) + assert "calendar" not in encoded["time_bounds"].attrs + assert "units" not in encoded["time_bounds"].attrs # if time_bounds attrs are same as time attrs, it doesn't matter - ds.time_bounds.encoding = {'calendar': 'noleap', - 'units': 'days since 2000-01-01'} - encoded, _ = cf_encoder({k: ds[k] for k in ds.variables}, - ds.attrs) - assert_equal(encoded['time_bounds'], expected['time_bounds']) - assert 'calendar' not in encoded['time_bounds'].attrs - assert 'units' not in encoded['time_bounds'].attrs + ds.time_bounds.encoding = {"calendar": "noleap", "units": "days since 2000-01-01"} + encoded, _ = cf_encoder({k: ds[k] for k in ds.variables}, ds.attrs) + assert_equal(encoded["time_bounds"], expected["time_bounds"]) + assert "calendar" not in encoded["time_bounds"].attrs + assert "units" not in encoded["time_bounds"].attrs # for CF-noncompliant case of time_bounds attrs being different from # time attrs; preserve them for faithful roundtrip - ds.time_bounds.encoding = {'calendar': 'noleap', - 'units': 'days since 1849-01-01'} - encoded, _ = cf_encoder({k: ds[k] for k in ds.variables}, - ds.attrs) + ds.time_bounds.encoding = {"calendar": "noleap", "units": "days since 1849-01-01"} + encoded, _ = cf_encoder({k: ds[k] for k in ds.variables}, ds.attrs) with pytest.raises(AssertionError): - assert_equal(encoded['time_bounds'], expected['time_bounds']) - assert 'calendar' not in encoded['time_bounds'].attrs - assert encoded['time_bounds'].attrs['units'] == ds.time_bounds.encoding['units'] # noqa + assert_equal(encoded["time_bounds"], expected["time_bounds"]) + assert "calendar" not in encoded["time_bounds"].attrs + assert ( + encoded["time_bounds"].attrs["units"] == ds.time_bounds.encoding["units"] + ) # noqa ds.time.encoding = {} with pytest.warns(UserWarning): @@ -728,8 +757,11 @@ def times(calendar): cftime = _import_cftime() return cftime.num2date( - np.arange(4), units='hours since 2000-01-01', calendar=calendar, - only_use_cftime_datetimes=True) + np.arange(4), + units="hours since 2000-01-01", + calendar=calendar, + only_use_cftime_datetimes=True, + ) @pytest.fixture() @@ -737,8 +769,9 @@ def data(times): data = np.random.rand(2, 2, 4) lons = np.linspace(0, 11, 2) lats = np.linspace(0, 20, 2) - return DataArray(data, coords=[lons, lats, times], - dims=['lon', 'lat', 'time'], name='data') + return DataArray( + data, coords=[lons, lats, times], dims=["lon", "lat", "time"], name="data" + ) @pytest.fixture() @@ -746,51 +779,51 @@ def times_3d(times): lons = np.linspace(0, 11, 2) lats = np.linspace(0, 20, 2) times_arr = np.random.choice(times, size=(2, 2, 4)) - return DataArray(times_arr, coords=[lons, lats, times], - dims=['lon', 'lat', 'time'], - name='data') + return DataArray( + times_arr, coords=[lons, lats, times], dims=["lon", "lat", "time"], name="data" + ) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_contains_cftime_datetimes_1d(data): assert contains_cftime_datetimes(data.time) -@pytest.mark.skipif(not has_dask, reason='dask not installed') -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_dask, reason="dask not installed") +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def 
test_contains_cftime_datetimes_dask_1d(data): assert contains_cftime_datetimes(data.time.chunk()) -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_contains_cftime_datetimes_3d(times_3d): assert contains_cftime_datetimes(times_3d) -@pytest.mark.skipif(not has_dask, reason='dask not installed') -@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.skipif(not has_dask, reason="dask not installed") +@pytest.mark.skipif(not has_cftime, reason="cftime not installed") def test_contains_cftime_datetimes_dask_3d(times_3d): assert contains_cftime_datetimes(times_3d.chunk()) -@pytest.mark.parametrize('non_cftime_data', [DataArray([]), DataArray([1, 2])]) +@pytest.mark.parametrize("non_cftime_data", [DataArray([]), DataArray([1, 2])]) def test_contains_cftime_datetimes_non_cftimes(non_cftime_data): assert not contains_cftime_datetimes(non_cftime_data) -@pytest.mark.skipif(not has_dask, reason='dask not installed') -@pytest.mark.parametrize('non_cftime_data', [DataArray([]), DataArray([1, 2])]) +@pytest.mark.skipif(not has_dask, reason="dask not installed") +@pytest.mark.parametrize("non_cftime_data", [DataArray([]), DataArray([1, 2])]) def test_contains_cftime_datetimes_non_cftimes_dask(non_cftime_data): assert not contains_cftime_datetimes(non_cftime_data.chunk()) -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -@pytest.mark.parametrize('shape', [(24,), (8, 3), (2, 4, 3)]) +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime not installed") +@pytest.mark.parametrize("shape", [(24,), (8, 3), (2, 4, 3)]) def test_encode_cf_datetime_overflow(shape): # Test for fix to GH 2272 - dates = pd.date_range('2100', periods=24).values.reshape(shape) - units = 'days since 1800-01-01' - calendar = 'standard' + dates = pd.date_range("2100", periods=24).values.reshape(shape) + units = "days since 1800-01-01" + calendar = "standard" num, _, _ = encode_cf_datetime(dates, units, calendar) roundtrip = decode_cf_datetime(num, units, calendar) @@ -800,20 +833,20 @@ def test_encode_cf_datetime_overflow(shape): def test_encode_cf_datetime_pandas_min(): # Test that encode_cf_datetime does not fail for versions # of pandas < 0.21.1 (GH 2623). - dates = pd.date_range('2000', periods=3) + dates = pd.date_range("2000", periods=3) num, units, calendar = encode_cf_datetime(dates) - expected_num = np.array([0., 1., 2.]) - expected_units = 'days since 2000-01-01 00:00:00' - expected_calendar = 'proleptic_gregorian' + expected_num = np.array([0.0, 1.0, 2.0]) + expected_units = "days since 2000-01-01 00:00:00" + expected_calendar = "proleptic_gregorian" np.testing.assert_array_equal(num, expected_num) assert units == expected_units assert calendar == expected_calendar -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime not installed") def test_time_units_with_timezone_roundtrip(calendar): # Regression test for GH 2649 - expected_units = 'days since 2000-01-01T00:00:00-05:00' + expected_units = "days since 2000-01-01T00:00:00-05:00" expected_num_dates = np.array([1, 2, 3]) dates = decode_cf_datetime(expected_num_dates, expected_units, calendar) @@ -825,7 +858,8 @@ def test_time_units_with_timezone_roundtrip(calendar): # Check that the encoded values are accurately roundtripped. 
result_num_dates, result_units, result_calendar = encode_cf_datetime( - dates, expected_units, calendar) + dates, expected_units, calendar + ) if calendar in _STANDARD_CALENDARS: np.testing.assert_array_equal(result_num_dates, expected_num_dates) @@ -837,11 +871,11 @@ def test_time_units_with_timezone_roundtrip(calendar): assert result_calendar == calendar -@pytest.mark.parametrize('calendar', _STANDARD_CALENDARS) +@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) def test_use_cftime_default_standard_calendar_in_range(calendar): numerical_dates = [0, 1] - units = 'days since 2000-01-01' - expected = pd.date_range('2000', periods=2) + units = "days since 2000-01-01" + expected = pd.date_range("2000", periods=2) with pytest.warns(None) as record: result = decode_cf_datetime(numerical_dates, units, calendar) @@ -850,17 +884,16 @@ def test_use_cftime_default_standard_calendar_in_range(calendar): @requires_cftime -@pytest.mark.parametrize('calendar', _STANDARD_CALENDARS) -@pytest.mark.parametrize('units_year', [1500, 2500]) -def test_use_cftime_default_standard_calendar_out_of_range( - calendar, - units_year): +@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) +@pytest.mark.parametrize("units_year", [1500, 2500]) +def test_use_cftime_default_standard_calendar_out_of_range(calendar, units_year): from cftime import num2date numerical_dates = [0, 1] - units = 'days since {}-01-01'.format(units_year) - expected = num2date(numerical_dates, units, calendar, - only_use_cftime_datetimes=True) + units = "days since {}-01-01".format(units_year) + expected = num2date( + numerical_dates, units, calendar, only_use_cftime_datetimes=True + ) with pytest.warns(SerializationWarning): result = decode_cf_datetime(numerical_dates, units, calendar) @@ -868,15 +901,16 @@ def test_use_cftime_default_standard_calendar_out_of_range( @requires_cftime -@pytest.mark.parametrize('calendar', _NON_STANDARD_CALENDARS) -@pytest.mark.parametrize('units_year', [1500, 2000, 2500]) +@pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS) +@pytest.mark.parametrize("units_year", [1500, 2000, 2500]) def test_use_cftime_default_non_standard_calendar(calendar, units_year): from cftime import num2date numerical_dates = [0, 1] - units = 'days since {}-01-01'.format(units_year) - expected = num2date(numerical_dates, units, calendar, - only_use_cftime_datetimes=True) + units = "days since {}-01-01".format(units_year) + expected = num2date( + numerical_dates, units, calendar, only_use_cftime_datetimes=True + ) with pytest.warns(None) as record: result = decode_cf_datetime(numerical_dates, units, calendar) @@ -885,49 +919,48 @@ def test_use_cftime_default_non_standard_calendar(calendar, units_year): @requires_cftime -@pytest.mark.parametrize('calendar', _ALL_CALENDARS) -@pytest.mark.parametrize('units_year', [1500, 2000, 2500]) +@pytest.mark.parametrize("calendar", _ALL_CALENDARS) +@pytest.mark.parametrize("units_year", [1500, 2000, 2500]) def test_use_cftime_true(calendar, units_year): from cftime import num2date numerical_dates = [0, 1] - units = 'days since {}-01-01'.format(units_year) - expected = num2date(numerical_dates, units, calendar, - only_use_cftime_datetimes=True) + units = "days since {}-01-01".format(units_year) + expected = num2date( + numerical_dates, units, calendar, only_use_cftime_datetimes=True + ) with pytest.warns(None) as record: - result = decode_cf_datetime(numerical_dates, units, calendar, - use_cftime=True) + result = decode_cf_datetime(numerical_dates, units, calendar, use_cftime=True) 
np.testing.assert_array_equal(result, expected) assert not record -@pytest.mark.parametrize('calendar', _STANDARD_CALENDARS) +@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) def test_use_cftime_false_standard_calendar_in_range(calendar): numerical_dates = [0, 1] - units = 'days since 2000-01-01' - expected = pd.date_range('2000', periods=2) + units = "days since 2000-01-01" + expected = pd.date_range("2000", periods=2) with pytest.warns(None) as record: - result = decode_cf_datetime(numerical_dates, units, calendar, - use_cftime=False) + result = decode_cf_datetime(numerical_dates, units, calendar, use_cftime=False) np.testing.assert_array_equal(result, expected) assert not record -@pytest.mark.parametrize('calendar', _STANDARD_CALENDARS) -@pytest.mark.parametrize('units_year', [1500, 2500]) +@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) +@pytest.mark.parametrize("units_year", [1500, 2500]) def test_use_cftime_false_standard_calendar_out_of_range(calendar, units_year): numerical_dates = [0, 1] - units = 'days since {}-01-01'.format(units_year) + units = "days since {}-01-01".format(units_year) with pytest.raises(OutOfBoundsDatetime): decode_cf_datetime(numerical_dates, units, calendar, use_cftime=False) -@pytest.mark.parametrize('calendar', _NON_STANDARD_CALENDARS) -@pytest.mark.parametrize('units_year', [1500, 2000, 2500]) +@pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS) +@pytest.mark.parametrize("units_year", [1500, 2000, 2500]) def test_use_cftime_false_non_standard_calendar(calendar, units_year): numerical_dates = [0, 1] - units = 'days since {}-01-01'.format(units_year) + units = "days since {}-01-01".format(units_year) with pytest.raises(OutOfBoundsDatetime): decode_cf_datetime(numerical_dates, units, calendar, use_cftime=False) diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index 8c9308466a4..e3801d02bc8 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -5,16 +5,19 @@ import numpy as np import pytest -from xarray import (DataArray, Dataset, concat, combine_by_coords, - combine_nested) +from xarray import DataArray, Dataset, concat, combine_by_coords, combine_nested from xarray import auto_combine from xarray.core import dtypes from xarray.core.combine import ( - _new_tile_id, _check_shape_tile_ids, _combine_all_along_first_dim, - _combine_nd, _infer_concat_order_from_positions, - _infer_concat_order_from_coords) - -from . import (assert_identical, assert_equal, raises_regex) + _new_tile_id, + _check_shape_tile_ids, + _combine_all_along_first_dim, + _combine_nd, + _infer_concat_order_from_positions, + _infer_concat_order_from_coords, +) + +from . 
import assert_identical, assert_equal, raises_regex from .test_dataset import create_test_data @@ -38,23 +41,38 @@ def test_2d(self): ds = create_test_data input = [[ds(0), ds(1)], [ds(2), ds(3)], [ds(4), ds(5)]] - expected = {(0, 0): ds(0), (0, 1): ds(1), - (1, 0): ds(2), (1, 1): ds(3), - (2, 0): ds(4), (2, 1): ds(5)} + expected = { + (0, 0): ds(0), + (0, 1): ds(1), + (1, 0): ds(2), + (1, 1): ds(3), + (2, 0): ds(4), + (2, 1): ds(5), + } actual = _infer_concat_order_from_positions(input) assert_combined_tile_ids_equal(expected, actual) def test_3d(self): ds = create_test_data - input = [[[ds(0), ds(1)], [ds(2), ds(3)], [ds(4), ds(5)]], - [[ds(6), ds(7)], [ds(8), ds(9)], [ds(10), ds(11)]]] - - expected = {(0, 0, 0): ds(0), (0, 0, 1): ds(1), - (0, 1, 0): ds(2), (0, 1, 1): ds(3), - (0, 2, 0): ds(4), (0, 2, 1): ds(5), - (1, 0, 0): ds(6), (1, 0, 1): ds(7), - (1, 1, 0): ds(8), (1, 1, 1): ds(9), - (1, 2, 0): ds(10), (1, 2, 1): ds(11)} + input = [ + [[ds(0), ds(1)], [ds(2), ds(3)], [ds(4), ds(5)]], + [[ds(6), ds(7)], [ds(8), ds(9)], [ds(10), ds(11)]], + ] + + expected = { + (0, 0, 0): ds(0), + (0, 0, 1): ds(1), + (0, 1, 0): ds(2), + (0, 1, 1): ds(3), + (0, 2, 0): ds(4), + (0, 2, 1): ds(5), + (1, 0, 0): ds(6), + (1, 0, 1): ds(7), + (1, 1, 0): ds(8), + (1, 1, 1): ds(9), + (1, 2, 0): ds(10), + (1, 2, 1): ds(11), + } actual = _infer_concat_order_from_positions(input) assert_combined_tile_ids_equal(expected, actual) @@ -112,102 +130,108 @@ def test_infer_from_datasets(self): class TestTileIDsFromCoords: def test_1d(self): - ds0 = Dataset({'x': [0, 1]}) - ds1 = Dataset({'x': [2, 3]}) + ds0 = Dataset({"x": [0, 1]}) + ds1 = Dataset({"x": [2, 3]}) expected = {(0,): ds0, (1,): ds1} actual, concat_dims = _infer_concat_order_from_coords([ds1, ds0]) assert_combined_tile_ids_equal(expected, actual) - assert concat_dims == ['x'] + assert concat_dims == ["x"] def test_2d(self): - ds0 = Dataset({'x': [0, 1], 'y': [10, 20, 30]}) - ds1 = Dataset({'x': [2, 3], 'y': [10, 20, 30]}) - ds2 = Dataset({'x': [0, 1], 'y': [40, 50, 60]}) - ds3 = Dataset({'x': [2, 3], 'y': [40, 50, 60]}) - ds4 = Dataset({'x': [0, 1], 'y': [70, 80, 90]}) - ds5 = Dataset({'x': [2, 3], 'y': [70, 80, 90]}) - - expected = {(0, 0): ds0, (1, 0): ds1, - (0, 1): ds2, (1, 1): ds3, - (0, 2): ds4, (1, 2): ds5} - actual, concat_dims = _infer_concat_order_from_coords([ds1, ds0, ds3, - ds5, ds2, ds4]) + ds0 = Dataset({"x": [0, 1], "y": [10, 20, 30]}) + ds1 = Dataset({"x": [2, 3], "y": [10, 20, 30]}) + ds2 = Dataset({"x": [0, 1], "y": [40, 50, 60]}) + ds3 = Dataset({"x": [2, 3], "y": [40, 50, 60]}) + ds4 = Dataset({"x": [0, 1], "y": [70, 80, 90]}) + ds5 = Dataset({"x": [2, 3], "y": [70, 80, 90]}) + + expected = { + (0, 0): ds0, + (1, 0): ds1, + (0, 1): ds2, + (1, 1): ds3, + (0, 2): ds4, + (1, 2): ds5, + } + actual, concat_dims = _infer_concat_order_from_coords( + [ds1, ds0, ds3, ds5, ds2, ds4] + ) assert_combined_tile_ids_equal(expected, actual) - assert concat_dims == ['x', 'y'] + assert concat_dims == ["x", "y"] def test_no_dimension_coords(self): - ds0 = Dataset({'foo': ('x', [0, 1])}) - ds1 = Dataset({'foo': ('x', [2, 3])}) + ds0 = Dataset({"foo": ("x", [0, 1])}) + ds1 = Dataset({"foo": ("x", [2, 3])}) with raises_regex(ValueError, "Could not find any dimension"): _infer_concat_order_from_coords([ds1, ds0]) def test_coord_not_monotonic(self): - ds0 = Dataset({'x': [0, 1]}) - ds1 = Dataset({'x': [3, 2]}) - with raises_regex(ValueError, "Coordinate variable x is neither " - "monotonically increasing nor"): + ds0 = Dataset({"x": [0, 1]}) + ds1 = 
Dataset({"x": [3, 2]}) + with raises_regex( + ValueError, + "Coordinate variable x is neither " "monotonically increasing nor", + ): _infer_concat_order_from_coords([ds1, ds0]) def test_coord_monotonically_decreasing(self): - ds0 = Dataset({'x': [3, 2]}) - ds1 = Dataset({'x': [1, 0]}) + ds0 = Dataset({"x": [3, 2]}) + ds1 = Dataset({"x": [1, 0]}) expected = {(0,): ds0, (1,): ds1} actual, concat_dims = _infer_concat_order_from_coords([ds1, ds0]) assert_combined_tile_ids_equal(expected, actual) - assert concat_dims == ['x'] + assert concat_dims == ["x"] def test_no_concatenation_needed(self): - ds = Dataset({'foo': ('x', [0, 1])}) + ds = Dataset({"foo": ("x", [0, 1])}) expected = {(): ds} actual, concat_dims = _infer_concat_order_from_coords([ds]) assert_combined_tile_ids_equal(expected, actual) assert concat_dims == [] def test_2d_plus_bystander_dim(self): - ds0 = Dataset({'x': [0, 1], 'y': [10, 20, 30], 't': [0.1, 0.2]}) - ds1 = Dataset({'x': [2, 3], 'y': [10, 20, 30], 't': [0.1, 0.2]}) - ds2 = Dataset({'x': [0, 1], 'y': [40, 50, 60], 't': [0.1, 0.2]}) - ds3 = Dataset({'x': [2, 3], 'y': [40, 50, 60], 't': [0.1, 0.2]}) - - expected = {(0, 0): ds0, (1, 0): ds1, - (0, 1): ds2, (1, 1): ds3} - actual, concat_dims = _infer_concat_order_from_coords([ds1, ds0, - ds3, ds2]) + ds0 = Dataset({"x": [0, 1], "y": [10, 20, 30], "t": [0.1, 0.2]}) + ds1 = Dataset({"x": [2, 3], "y": [10, 20, 30], "t": [0.1, 0.2]}) + ds2 = Dataset({"x": [0, 1], "y": [40, 50, 60], "t": [0.1, 0.2]}) + ds3 = Dataset({"x": [2, 3], "y": [40, 50, 60], "t": [0.1, 0.2]}) + + expected = {(0, 0): ds0, (1, 0): ds1, (0, 1): ds2, (1, 1): ds3} + actual, concat_dims = _infer_concat_order_from_coords([ds1, ds0, ds3, ds2]) assert_combined_tile_ids_equal(expected, actual) - assert concat_dims == ['x', 'y'] + assert concat_dims == ["x", "y"] def test_string_coords(self): - ds0 = Dataset({'person': ['Alice', 'Bob']}) - ds1 = Dataset({'person': ['Caroline', 'Daniel']}) + ds0 = Dataset({"person": ["Alice", "Bob"]}) + ds1 = Dataset({"person": ["Caroline", "Daniel"]}) expected = {(0,): ds0, (1,): ds1} actual, concat_dims = _infer_concat_order_from_coords([ds1, ds0]) assert_combined_tile_ids_equal(expected, actual) - assert concat_dims == ['person'] + assert concat_dims == ["person"] # Decided against natural sorting of string coords GH #2616 def test_lexicographic_sort_string_coords(self): - ds0 = Dataset({'simulation': ['run8', 'run9']}) - ds1 = Dataset({'simulation': ['run10', 'run11']}) + ds0 = Dataset({"simulation": ["run8", "run9"]}) + ds1 = Dataset({"simulation": ["run10", "run11"]}) expected = {(0,): ds1, (1,): ds0} actual, concat_dims = _infer_concat_order_from_coords([ds1, ds0]) assert_combined_tile_ids_equal(expected, actual) - assert concat_dims == ['simulation'] + assert concat_dims == ["simulation"] def test_datetime_coords(self): - ds0 = Dataset({'time': [datetime(2000, 3, 6), datetime(2001, 3, 7)]}) - ds1 = Dataset({'time': [datetime(1999, 1, 1), datetime(1999, 2, 4)]}) + ds0 = Dataset({"time": [datetime(2000, 3, 6), datetime(2001, 3, 7)]}) + ds1 = Dataset({"time": [datetime(1999, 1, 1), datetime(1999, 2, 4)]}) expected = {(0,): ds1, (1,): ds0} actual, concat_dims = _infer_concat_order_from_coords([ds0, ds1]) assert_combined_tile_ids_equal(expected, actual) - assert concat_dims == ['time'] + assert concat_dims == ["time"] -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def create_combined_ids(): return _create_combined_ids @@ -215,8 +239,7 @@ def create_combined_ids(): def _create_combined_ids(shape): tile_ids = 
_create_tile_ids(shape) nums = range(len(tile_ids)) - return {tile_id: create_test_data(num) - for tile_id, num in zip(tile_ids, nums)} + return {tile_id: create_test_data(num) for tile_id, num in zip(tile_ids, nums)} def _create_tile_ids(shape): @@ -225,11 +248,10 @@ def _create_tile_ids(shape): class TestNewTileIDs: - @pytest.mark.parametrize("old_id, new_id", [((3, 0, 1), (0, 1)), - ((0, 0), (0,)), - ((1,), ()), - ((0,), ()), - ((1, 0), (0,))]) + @pytest.mark.parametrize( + "old_id, new_id", + [((3, 0, 1), (0, 1)), ((0, 0), (0,)), ((1,), ()), ((0,), ()), ((1, 0), (0,))], + ) def test_new_tile_id(self, old_id, new_id): ds = create_test_data assert _new_tile_id((old_id, ds)) == new_id @@ -244,15 +266,18 @@ def test_get_new_tile_ids(self, create_combined_ids): class TestCombineND: - @pytest.mark.parametrize("concat_dim", ['dim1', 'new_dim']) + @pytest.mark.parametrize("concat_dim", ["dim1", "new_dim"]) def test_concat_once(self, create_combined_ids, concat_dim): shape = (2,) combined_ids = create_combined_ids(shape) ds = create_test_data - result = _combine_all_along_first_dim(combined_ids, dim=concat_dim, - data_vars='all', - coords='different', - compat='no_conflicts') + result = _combine_all_along_first_dim( + combined_ids, + dim=concat_dim, + data_vars="all", + coords="different", + compat="no_conflicts", + ) expected_ds = concat([ds(0), ds(1)], dim=concat_dim) assert_combined_tile_ids_equal(result, {(): expected_ds}) @@ -260,30 +285,33 @@ def test_concat_once(self, create_combined_ids, concat_dim): def test_concat_only_first_dim(self, create_combined_ids): shape = (2, 3) combined_ids = create_combined_ids(shape) - result = _combine_all_along_first_dim(combined_ids, dim='dim1', - data_vars='all', - coords='different', - compat='no_conflicts') + result = _combine_all_along_first_dim( + combined_ids, + dim="dim1", + data_vars="all", + coords="different", + compat="no_conflicts", + ) ds = create_test_data - partway1 = concat([ds(0), ds(3)], dim='dim1') - partway2 = concat([ds(1), ds(4)], dim='dim1') - partway3 = concat([ds(2), ds(5)], dim='dim1') + partway1 = concat([ds(0), ds(3)], dim="dim1") + partway2 = concat([ds(1), ds(4)], dim="dim1") + partway3 = concat([ds(2), ds(5)], dim="dim1") expected_datasets = [partway1, partway2, partway3] expected = {(i,): ds for i, ds in enumerate(expected_datasets)} assert_combined_tile_ids_equal(result, expected) - @pytest.mark.parametrize("concat_dim", ['dim1', 'new_dim']) + @pytest.mark.parametrize("concat_dim", ["dim1", "new_dim"]) def test_concat_twice(self, create_combined_ids, concat_dim): shape = (2, 3) combined_ids = create_combined_ids(shape) - result = _combine_nd(combined_ids, concat_dims=['dim1', concat_dim]) + result = _combine_nd(combined_ids, concat_dims=["dim1", concat_dim]) ds = create_test_data - partway1 = concat([ds(0), ds(3)], dim='dim1') - partway2 = concat([ds(1), ds(4)], dim='dim1') - partway3 = concat([ds(2), ds(5)], dim='dim1') + partway1 = concat([ds(0), ds(3)], dim="dim1") + partway2 = concat([ds(1), ds(4)], dim="dim1") + partway3 = concat([ds(2), ds(5)], dim="dim1") expected = concat([partway1, partway2, partway3], dim=concat_dim) assert_equal(result, expected) @@ -293,109 +321,110 @@ class TestCheckShapeTileIDs: def test_check_depths(self): ds = create_test_data(0) combined_tile_ids = {(0,): ds, (0, 1): ds} - with raises_regex(ValueError, 'sub-lists do not have ' - 'consistent depths'): + with raises_regex(ValueError, "sub-lists do not have " "consistent depths"): _check_shape_tile_ids(combined_tile_ids) def 
test_check_lengths(self): ds = create_test_data(0) - combined_tile_ids = {(0, 0): ds, (0, 1): ds, (0, 2): ds, - (1, 0): ds, (1, 1): ds} - with raises_regex(ValueError, 'sub-lists do not have ' - 'consistent lengths'): + combined_tile_ids = {(0, 0): ds, (0, 1): ds, (0, 2): ds, (1, 0): ds, (1, 1): ds} + with raises_regex(ValueError, "sub-lists do not have " "consistent lengths"): _check_shape_tile_ids(combined_tile_ids) class TestNestedCombine: def test_nested_concat(self): - objs = [Dataset({'x': [0]}), Dataset({'x': [1]})] - expected = Dataset({'x': [0, 1]}) - actual = combine_nested(objs, concat_dim='x') + objs = [Dataset({"x": [0]}), Dataset({"x": [1]})] + expected = Dataset({"x": [0, 1]}) + actual = combine_nested(objs, concat_dim="x") assert_identical(expected, actual) - actual = combine_nested(objs, concat_dim=['x']) + actual = combine_nested(objs, concat_dim=["x"]) assert_identical(expected, actual) actual = combine_nested([actual], concat_dim=None) assert_identical(expected, actual) - actual = combine_nested([actual], concat_dim='x') + actual = combine_nested([actual], concat_dim="x") assert_identical(expected, actual) - objs = [Dataset({'x': [0, 1]}), Dataset({'x': [2]})] - actual = combine_nested(objs, concat_dim='x') - expected = Dataset({'x': [0, 1, 2]}) + objs = [Dataset({"x": [0, 1]}), Dataset({"x": [2]})] + actual = combine_nested(objs, concat_dim="x") + expected = Dataset({"x": [0, 1, 2]}) assert_identical(expected, actual) # ensure combine_nested handles non-sorted variables - objs = [Dataset(OrderedDict([('x', ('a', [0])), ('y', ('a', [0]))])), - Dataset(OrderedDict([('y', ('a', [1])), ('x', ('a', [1]))]))] - actual = combine_nested(objs, concat_dim='a') - expected = Dataset({'x': ('a', [0, 1]), 'y': ('a', [0, 1])}) + objs = [ + Dataset(OrderedDict([("x", ("a", [0])), ("y", ("a", [0]))])), + Dataset(OrderedDict([("y", ("a", [1])), ("x", ("a", [1]))])), + ] + actual = combine_nested(objs, concat_dim="a") + expected = Dataset({"x": ("a", [0, 1]), "y": ("a", [0, 1])}) assert_identical(expected, actual) - objs = [Dataset({'x': [0], 'y': [0]}), Dataset({'x': [0]})] + objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [0]})] with pytest.raises(KeyError): - combine_nested(objs, concat_dim='x') + combine_nested(objs, concat_dim="x") @pytest.mark.parametrize( "join, expected", [ - ('outer', Dataset({'x': [0, 1], 'y': [0, 1]})), - ('inner', Dataset({'x': [0, 1], 'y': []})), - ('left', Dataset({'x': [0, 1], 'y': [0]})), - ('right', Dataset({'x': [0, 1], 'y': [1]})), - ]) + ("outer", Dataset({"x": [0, 1], "y": [0, 1]})), + ("inner", Dataset({"x": [0, 1], "y": []})), + ("left", Dataset({"x": [0, 1], "y": [0]})), + ("right", Dataset({"x": [0, 1], "y": [1]})), + ], + ) def test_combine_nested_join(self, join, expected): - objs = [Dataset({'x': [0], 'y': [0]}), - Dataset({'x': [1], 'y': [1]})] - actual = combine_nested(objs, concat_dim='x', join=join) + objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [1], "y": [1]})] + actual = combine_nested(objs, concat_dim="x", join=join) assert_identical(expected, actual) def test_combine_nested_join_exact(self): - objs = [Dataset({'x': [0], 'y': [0]}), - Dataset({'x': [1], 'y': [1]})] - with raises_regex(ValueError, 'indexes along dimension'): - combine_nested(objs, concat_dim='x', join='exact') + objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [1], "y": [1]})] + with raises_regex(ValueError, "indexes along dimension"): + combine_nested(objs, concat_dim="x", join="exact") def test_empty_input(self): - assert_identical(Dataset(), 
combine_nested([], concat_dim='x')) + assert_identical(Dataset(), combine_nested([], concat_dim="x")) # Fails because of concat's weird treatment of dimension coords, see #2975 @pytest.mark.xfail def test_nested_concat_too_many_dims_at_once(self): - objs = [Dataset({'x': [0], 'y': [1]}), Dataset({'y': [0], 'x': [1]})] + objs = [Dataset({"x": [0], "y": [1]}), Dataset({"y": [0], "x": [1]})] with pytest.raises(ValueError, match="not equal across datasets"): - combine_nested(objs, concat_dim='x', coords='minimal') + combine_nested(objs, concat_dim="x", coords="minimal") def test_nested_concat_along_new_dim(self): - objs = [Dataset({'a': ('x', [10]), 'x': [0]}), - Dataset({'a': ('x', [20]), 'x': [0]})] - expected = Dataset({'a': (('t', 'x'), [[10], [20]]), 'x': [0]}) - actual = combine_nested(objs, concat_dim='t') + objs = [ + Dataset({"a": ("x", [10]), "x": [0]}), + Dataset({"a": ("x", [20]), "x": [0]}), + ] + expected = Dataset({"a": (("t", "x"), [[10], [20]]), "x": [0]}) + actual = combine_nested(objs, concat_dim="t") assert_identical(expected, actual) # Same but with a DataArray as new dim, see GH #1988 and #2647 - dim = DataArray([100, 150], name='baz', dims='baz') - expected = Dataset({'a': (('baz', 'x'), [[10], [20]]), - 'x': [0], 'baz': [100, 150]}) + dim = DataArray([100, 150], name="baz", dims="baz") + expected = Dataset( + {"a": (("baz", "x"), [[10], [20]]), "x": [0], "baz": [100, 150]} + ) actual = combine_nested(objs, concat_dim=dim) assert_identical(expected, actual) def test_nested_merge(self): - data = Dataset({'x': 0}) + data = Dataset({"x": 0}) actual = combine_nested([data, data, data], concat_dim=None) assert_identical(data, actual) - ds1 = Dataset({'a': ('x', [1, 2]), 'x': [0, 1]}) - ds2 = Dataset({'a': ('x', [2, 3]), 'x': [1, 2]}) - expected = Dataset({'a': ('x', [1, 2, 3]), 'x': [0, 1, 2]}) + ds1 = Dataset({"a": ("x", [1, 2]), "x": [0, 1]}) + ds2 = Dataset({"a": ("x", [2, 3]), "x": [1, 2]}) + expected = Dataset({"a": ("x", [1, 2, 3]), "x": [0, 1, 2]}) actual = combine_nested([ds1, ds2], concat_dim=None) assert_identical(expected, actual) actual = combine_nested([ds1, ds2], concat_dim=[None]) assert_identical(expected, actual) - tmp1 = Dataset({'x': 0}) - tmp2 = Dataset({'x': np.nan}) + tmp1 = Dataset({"x": 0}) + tmp2 = Dataset({"x": np.nan}) actual = combine_nested([tmp1, tmp2], concat_dim=None) assert_identical(tmp1, actual) actual = combine_nested([tmp1, tmp2], concat_dim=[None]) @@ -403,44 +432,42 @@ def test_nested_merge(self): # Single object, with a concat_dim explicitly provided # Test the issue reported in GH #1988 - objs = [Dataset({'x': 0, 'y': 1})] - dim = DataArray([100], name='baz', dims='baz') + objs = [Dataset({"x": 0, "y": 1})] + dim = DataArray([100], name="baz", dims="baz") actual = combine_nested(objs, concat_dim=[dim]) - expected = Dataset({'x': ('baz', [0]), 'y': ('baz', [1])}, - {'baz': [100]}) + expected = Dataset({"x": ("baz", [0]), "y": ("baz", [1])}, {"baz": [100]}) assert_identical(expected, actual) # Just making sure that auto_combine is doing what is # expected for non-scalar values, too. 
- objs = [Dataset({'x': ('z', [0, 1]), 'y': ('z', [1, 2])})] - dim = DataArray([100], name='baz', dims='baz') + objs = [Dataset({"x": ("z", [0, 1]), "y": ("z", [1, 2])})] + dim = DataArray([100], name="baz", dims="baz") actual = combine_nested(objs, concat_dim=[dim]) - expected = Dataset({'x': (('baz', 'z'), [[0, 1]]), - 'y': (('baz', 'z'), [[1, 2]])}, - {'baz': [100]}) + expected = Dataset( + {"x": (("baz", "z"), [[0, 1]]), "y": (("baz", "z"), [[1, 2]])}, + {"baz": [100]}, + ) assert_identical(expected, actual) def test_concat_multiple_dims(self): - objs = [[Dataset({'a': (('x', 'y'), [[0]])}), - Dataset({'a': (('x', 'y'), [[1]])})], - [Dataset({'a': (('x', 'y'), [[2]])}), - Dataset({'a': (('x', 'y'), [[3]])})]] - actual = combine_nested(objs, concat_dim=['x', 'y']) - expected = Dataset({'a': (('x', 'y'), [[0, 1], [2, 3]])}) + objs = [ + [Dataset({"a": (("x", "y"), [[0]])}), Dataset({"a": (("x", "y"), [[1]])})], + [Dataset({"a": (("x", "y"), [[2]])}), Dataset({"a": (("x", "y"), [[3]])})], + ] + actual = combine_nested(objs, concat_dim=["x", "y"]) + expected = Dataset({"a": (("x", "y"), [[0, 1], [2, 3]])}) assert_identical(expected, actual) def test_concat_name_symmetry(self): """Inspired by the discussion on GH issue #2777""" - da1 = DataArray(name='a', data=[[0]], dims=['x', 'y']) - da2 = DataArray(name='b', data=[[1]], dims=['x', 'y']) - da3 = DataArray(name='a', data=[[2]], dims=['x', 'y']) - da4 = DataArray(name='b', data=[[3]], dims=['x', 'y']) + da1 = DataArray(name="a", data=[[0]], dims=["x", "y"]) + da2 = DataArray(name="b", data=[[1]], dims=["x", "y"]) + da3 = DataArray(name="a", data=[[2]], dims=["x", "y"]) + da4 = DataArray(name="b", data=[[3]], dims=["x", "y"]) - x_first = combine_nested([[da1, da2], [da3, da4]], - concat_dim=['x', 'y']) - y_first = combine_nested([[da1, da3], [da2, da4]], - concat_dim=['y', 'x']) + x_first = combine_nested([[da1, da2], [da3, da4]], concat_dim=["x", "y"]) + y_first = combine_nested([[da1, da3], [da2, da4]], concat_dim=["y", "x"]) assert_identical(x_first, y_first) @@ -449,146 +476,147 @@ def test_concat_one_dim_merge_another(self): data1 = data.copy(deep=True) data2 = data.copy(deep=True) - objs = [[data1.var1.isel(dim2=slice(4)), - data2.var1.isel(dim2=slice(4, 9))], - [data1.var2.isel(dim2=slice(4)), - data2.var2.isel(dim2=slice(4, 9))]] + objs = [ + [data1.var1.isel(dim2=slice(4)), data2.var1.isel(dim2=slice(4, 9))], + [data1.var2.isel(dim2=slice(4)), data2.var2.isel(dim2=slice(4, 9))], + ] - expected = data[['var1', 'var2']] - actual = combine_nested(objs, concat_dim=[None, 'dim2']) + expected = data[["var1", "var2"]] + actual = combine_nested(objs, concat_dim=[None, "dim2"]) assert expected.identical(actual) def test_auto_combine_2d(self): ds = create_test_data - partway1 = concat([ds(0), ds(3)], dim='dim1') - partway2 = concat([ds(1), ds(4)], dim='dim1') - partway3 = concat([ds(2), ds(5)], dim='dim1') - expected = concat([partway1, partway2, partway3], dim='dim2') + partway1 = concat([ds(0), ds(3)], dim="dim1") + partway2 = concat([ds(1), ds(4)], dim="dim1") + partway3 = concat([ds(2), ds(5)], dim="dim1") + expected = concat([partway1, partway2, partway3], dim="dim2") datasets = [[ds(0), ds(1), ds(2)], [ds(3), ds(4), ds(5)]] - result = combine_nested(datasets, concat_dim=['dim1', 'dim2']) + result = combine_nested(datasets, concat_dim=["dim1", "dim2"]) assert_equal(result, expected) def test_combine_nested_missing_data_new_dim(self): # Your data includes "time" and "station" dimensions, and each year's # data has a different set of 
stations. - datasets = [Dataset({'a': ('x', [2, 3]), 'x': [1, 2]}), - Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})] - expected = Dataset({'a': (('t', 'x'), - [[np.nan, 2, 3], [1, 2, np.nan]])}, - {'x': [0, 1, 2]}) - actual = combine_nested(datasets, concat_dim='t') + datasets = [ + Dataset({"a": ("x", [2, 3]), "x": [1, 2]}), + Dataset({"a": ("x", [1, 2]), "x": [0, 1]}), + ] + expected = Dataset( + {"a": (("t", "x"), [[np.nan, 2, 3], [1, 2, np.nan]])}, {"x": [0, 1, 2]} + ) + actual = combine_nested(datasets, concat_dim="t") assert_identical(expected, actual) def test_invalid_hypercube_input(self): ds = create_test_data datasets = [[ds(0), ds(1), ds(2)], [ds(3), ds(4)]] - with raises_regex(ValueError, 'sub-lists do not have ' - 'consistent lengths'): - combine_nested(datasets, concat_dim=['dim1', 'dim2']) + with raises_regex(ValueError, "sub-lists do not have " "consistent lengths"): + combine_nested(datasets, concat_dim=["dim1", "dim2"]) datasets = [[ds(0), ds(1)], [[ds(3), ds(4)]]] - with raises_regex(ValueError, 'sub-lists do not have ' - 'consistent depths'): - combine_nested(datasets, concat_dim=['dim1', 'dim2']) + with raises_regex(ValueError, "sub-lists do not have " "consistent depths"): + combine_nested(datasets, concat_dim=["dim1", "dim2"]) datasets = [[ds(0), ds(1)], [ds(3), ds(4)]] - with raises_regex(ValueError, 'concat_dims has length'): - combine_nested(datasets, concat_dim=['dim1']) + with raises_regex(ValueError, "concat_dims has length"): + combine_nested(datasets, concat_dim=["dim1"]) def test_merge_one_dim_concat_another(self): - objs = [[Dataset({'foo': ('x', [0, 1])}), - Dataset({'bar': ('x', [10, 20])})], - [Dataset({'foo': ('x', [2, 3])}), - Dataset({'bar': ('x', [30, 40])})]] - expected = Dataset({'foo': ('x', [0, 1, 2, 3]), - 'bar': ('x', [10, 20, 30, 40])}) + objs = [ + [Dataset({"foo": ("x", [0, 1])}), Dataset({"bar": ("x", [10, 20])})], + [Dataset({"foo": ("x", [2, 3])}), Dataset({"bar": ("x", [30, 40])})], + ] + expected = Dataset({"foo": ("x", [0, 1, 2, 3]), "bar": ("x", [10, 20, 30, 40])}) - actual = combine_nested(objs, concat_dim=['x', None], compat='equals') + actual = combine_nested(objs, concat_dim=["x", None], compat="equals") assert_identical(expected, actual) # Proving it works symmetrically - objs = [[Dataset({'foo': ('x', [0, 1])}), - Dataset({'foo': ('x', [2, 3])})], - [Dataset({'bar': ('x', [10, 20])}), - Dataset({'bar': ('x', [30, 40])})]] - actual = combine_nested(objs, concat_dim=[None, 'x'], compat='equals') + objs = [ + [Dataset({"foo": ("x", [0, 1])}), Dataset({"foo": ("x", [2, 3])})], + [Dataset({"bar": ("x", [10, 20])}), Dataset({"bar": ("x", [30, 40])})], + ] + actual = combine_nested(objs, concat_dim=[None, "x"], compat="equals") assert_identical(expected, actual) def test_combine_concat_over_redundant_nesting(self): - objs = [[Dataset({'x': [0]}), Dataset({'x': [1]})]] - actual = combine_nested(objs, concat_dim=[None, 'x']) - expected = Dataset({'x': [0, 1]}) + objs = [[Dataset({"x": [0]}), Dataset({"x": [1]})]] + actual = combine_nested(objs, concat_dim=[None, "x"]) + expected = Dataset({"x": [0, 1]}) assert_identical(expected, actual) - objs = [[Dataset({'x': [0]})], [Dataset({'x': [1]})]] - actual = combine_nested(objs, concat_dim=['x', None]) - expected = Dataset({'x': [0, 1]}) + objs = [[Dataset({"x": [0]})], [Dataset({"x": [1]})]] + actual = combine_nested(objs, concat_dim=["x", None]) + expected = Dataset({"x": [0, 1]}) assert_identical(expected, actual) - objs = [[Dataset({'x': [0]})]] + objs = [[Dataset({"x": [0]})]] actual = 
combine_nested(objs, concat_dim=[None, None]) - expected = Dataset({'x': [0]}) + expected = Dataset({"x": [0]}) assert_identical(expected, actual) def test_combine_nested_but_need_auto_combine(self): - objs = [Dataset({'x': [0, 1]}), Dataset({'x': [2], 'wall': [0]})] - with raises_regex(ValueError, 'cannot be combined'): - combine_nested(objs, concat_dim='x') + objs = [Dataset({"x": [0, 1]}), Dataset({"x": [2], "wall": [0]})] + with raises_regex(ValueError, "cannot be combined"): + combine_nested(objs, concat_dim="x") - @pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0]) + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) def test_combine_nested_fill_value(self, fill_value): - datasets = [Dataset({'a': ('x', [2, 3]), 'x': [1, 2]}), - Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})] + datasets = [ + Dataset({"a": ("x", [2, 3]), "x": [1, 2]}), + Dataset({"a": ("x", [1, 2]), "x": [0, 1]}), + ] if fill_value == dtypes.NA: # if we supply the default, we expect the missing value for a # float array fill_value = np.nan - expected = Dataset({'a': (('t', 'x'), - [[fill_value, 2, 3], [1, 2, fill_value]])}, - {'x': [0, 1, 2]}) - actual = combine_nested(datasets, concat_dim='t', - fill_value=fill_value) + expected = Dataset( + {"a": (("t", "x"), [[fill_value, 2, 3], [1, 2, fill_value]])}, + {"x": [0, 1, 2]}, + ) + actual = combine_nested(datasets, concat_dim="t", fill_value=fill_value) assert_identical(expected, actual) class TestCombineAuto: def test_combine_by_coords(self): - objs = [Dataset({'x': [0]}), Dataset({'x': [1]})] + objs = [Dataset({"x": [0]}), Dataset({"x": [1]})] actual = combine_by_coords(objs) - expected = Dataset({'x': [0, 1]}) + expected = Dataset({"x": [0, 1]}) assert_identical(expected, actual) actual = combine_by_coords([actual]) assert_identical(expected, actual) - objs = [Dataset({'x': [0, 1]}), Dataset({'x': [2]})] + objs = [Dataset({"x": [0, 1]}), Dataset({"x": [2]})] actual = combine_by_coords(objs) - expected = Dataset({'x': [0, 1, 2]}) + expected = Dataset({"x": [0, 1, 2]}) assert_identical(expected, actual) # ensure auto_combine handles non-sorted variables - objs = [Dataset({'x': ('a', [0]), 'y': ('a', [0]), 'a': [0]}), - Dataset({'x': ('a', [1]), 'y': ('a', [1]), 'a': [1]})] + objs = [ + Dataset({"x": ("a", [0]), "y": ("a", [0]), "a": [0]}), + Dataset({"x": ("a", [1]), "y": ("a", [1]), "a": [1]}), + ] actual = combine_by_coords(objs) - expected = Dataset({'x': ('a', [0, 1]), 'y': ('a', [0, 1]), - 'a': [0, 1]}) + expected = Dataset({"x": ("a", [0, 1]), "y": ("a", [0, 1]), "a": [0, 1]}) assert_identical(expected, actual) - objs = [Dataset({'x': [0], 'y': [0]}), Dataset({'y': [1], 'x': [1]})] + objs = [Dataset({"x": [0], "y": [0]}), Dataset({"y": [1], "x": [1]})] actual = combine_by_coords(objs) - expected = Dataset({'x': [0, 1], 'y': [0, 1]}) + expected = Dataset({"x": [0, 1], "y": [0, 1]}) assert_equal(actual, expected) - objs = [Dataset({'x': 0}), Dataset({'x': 1})] - with raises_regex(ValueError, 'Could not find any dimension ' - 'coordinates'): + objs = [Dataset({"x": 0}), Dataset({"x": 1})] + with raises_regex(ValueError, "Could not find any dimension " "coordinates"): combine_by_coords(objs) - objs = [Dataset({'x': [0], 'y': [0]}), Dataset({'x': [0]})] - with raises_regex(ValueError, 'Every dimension needs a coordinate'): + objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [0]})] + with raises_regex(ValueError, "Every dimension needs a coordinate"): combine_by_coords(objs) def test_empty_input(self): @@ -597,22 +625,21 @@ def test_empty_input(self): 
@pytest.mark.parametrize( "join, expected", [ - ('outer', Dataset({'x': [0, 1], 'y': [0, 1]})), - ('inner', Dataset({'x': [0, 1], 'y': []})), - ('left', Dataset({'x': [0, 1], 'y': [0]})), - ('right', Dataset({'x': [0, 1], 'y': [1]})), - ]) + ("outer", Dataset({"x": [0, 1], "y": [0, 1]})), + ("inner", Dataset({"x": [0, 1], "y": []})), + ("left", Dataset({"x": [0, 1], "y": [0]})), + ("right", Dataset({"x": [0, 1], "y": [1]})), + ], + ) def test_combine_coords_join(self, join, expected): - objs = [Dataset({'x': [0], 'y': [0]}), - Dataset({'x': [1], 'y': [1]})] - actual = combine_nested(objs, concat_dim='x', join=join) + objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [1], "y": [1]})] + actual = combine_nested(objs, concat_dim="x", join=join) assert_identical(expected, actual) def test_combine_coords_join_exact(self): - objs = [Dataset({'x': [0], 'y': [0]}), - Dataset({'x': [1], 'y': [1]})] - with raises_regex(ValueError, 'indexes along dimension'): - combine_nested(objs, concat_dim='x', join='exact') + objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [1], "y": [1]})] + with raises_regex(ValueError, "indexes along dimension"): + combine_nested(objs, concat_dim="x", join="exact") def test_infer_order_from_coords(self): data = create_test_data() @@ -624,18 +651,24 @@ def test_infer_order_from_coords(self): def test_combine_leaving_bystander_dimensions(self): # Check non-monotonic bystander dimension coord doesn't raise # ValueError on combine (https://github.com/pydata/xarray/issues/3150) - ycoord = ['a', 'c', 'b'] + ycoord = ["a", "c", "b"] data = np.random.rand(7, 3) - ds1 = Dataset(data_vars=dict(data=(['x', 'y'], data[:3, :])), - coords=dict(x=[1, 2, 3], y=ycoord)) + ds1 = Dataset( + data_vars=dict(data=(["x", "y"], data[:3, :])), + coords=dict(x=[1, 2, 3], y=ycoord), + ) - ds2 = Dataset(data_vars=dict(data=(['x', 'y'], data[3:, :])), - coords=dict(x=[4, 5, 6, 7], y=ycoord)) + ds2 = Dataset( + data_vars=dict(data=(["x", "y"], data[3:, :])), + coords=dict(x=[4, 5, 6, 7], y=ycoord), + ) - expected = Dataset(data_vars=dict(data=(['x', 'y'], data)), - coords=dict(x=[1, 2, 3, 4, 5, 6, 7], y=ycoord)) + expected = Dataset( + data_vars=dict(data=(["x", "y"], data)), + coords=dict(x=[1, 2, 3, 4, 5, 6, 7], y=ycoord), + ) actual = combine_by_coords((ds1, ds2)) assert_identical(expected, actual) @@ -643,43 +676,45 @@ def test_combine_leaving_bystander_dimensions(self): def test_combine_by_coords_previously_failed(self): # In the above scenario, one file is missing, containing the data for # one year's data for one variable. 
- datasets = [Dataset({'a': ('x', [0]), 'x': [0]}), - Dataset({'b': ('x', [0]), 'x': [0]}), - Dataset({'a': ('x', [1]), 'x': [1]})] - expected = Dataset({'a': ('x', [0, 1]), 'b': ('x', [0, np.nan])}, - {'x': [0, 1]}) + datasets = [ + Dataset({"a": ("x", [0]), "x": [0]}), + Dataset({"b": ("x", [0]), "x": [0]}), + Dataset({"a": ("x", [1]), "x": [1]}), + ] + expected = Dataset({"a": ("x", [0, 1]), "b": ("x", [0, np.nan])}, {"x": [0, 1]}) actual = combine_by_coords(datasets) assert_identical(expected, actual) def test_combine_by_coords_still_fails(self): # concat can't handle new variables (yet): # https://github.com/pydata/xarray/issues/508 - datasets = [Dataset({'x': 0}, {'y': 0}), - Dataset({'x': 1}, {'y': 1, 'z': 1})] + datasets = [Dataset({"x": 0}, {"y": 0}), Dataset({"x": 1}, {"y": 1, "z": 1})] with pytest.raises(ValueError): - combine_by_coords(datasets, 'y') + combine_by_coords(datasets, "y") def test_combine_by_coords_no_concat(self): - objs = [Dataset({'x': 0}), Dataset({'y': 1})] + objs = [Dataset({"x": 0}), Dataset({"y": 1})] actual = combine_by_coords(objs) - expected = Dataset({'x': 0, 'y': 1}) + expected = Dataset({"x": 0, "y": 1}) assert_identical(expected, actual) - objs = [Dataset({'x': 0, 'y': 1}), Dataset({'y': np.nan, 'z': 2})] + objs = [Dataset({"x": 0, "y": 1}), Dataset({"y": np.nan, "z": 2})] actual = combine_by_coords(objs) - expected = Dataset({'x': 0, 'y': 1, 'z': 2}) + expected = Dataset({"x": 0, "y": 1, "z": 2}) assert_identical(expected, actual) def test_check_for_impossible_ordering(self): - ds0 = Dataset({'x': [0, 1, 5]}) - ds1 = Dataset({'x': [2, 3]}) - with raises_regex(ValueError, "does not have monotonic global indexes" - " along dimension x"): + ds0 = Dataset({"x": [0, 1, 5]}) + ds1 = Dataset({"x": [2, 3]}) + with raises_regex( + ValueError, "does not have monotonic global indexes" " along dimension x" + ): combine_by_coords([ds1, ds0]) -@pytest.mark.filterwarnings("ignore:In xarray version 0.13 `auto_combine` " - "will be deprecated") +@pytest.mark.filterwarnings( + "ignore:In xarray version 0.13 `auto_combine` " "will be deprecated" +) @pytest.mark.filterwarnings("ignore:Also `open_mfdataset` will no longer") @pytest.mark.filterwarnings("ignore:The datasets supplied") class TestAutoCombineOldAPI: @@ -687,122 +722,131 @@ class TestAutoCombineOldAPI: Set of tests which check that old 1-dimensional auto_combine behaviour is still satisfied. 
#2616 """ + def test_auto_combine(self): - objs = [Dataset({'x': [0]}), Dataset({'x': [1]})] + objs = [Dataset({"x": [0]}), Dataset({"x": [1]})] actual = auto_combine(objs) - expected = Dataset({'x': [0, 1]}) + expected = Dataset({"x": [0, 1]}) assert_identical(expected, actual) actual = auto_combine([actual]) assert_identical(expected, actual) - objs = [Dataset({'x': [0, 1]}), Dataset({'x': [2]})] + objs = [Dataset({"x": [0, 1]}), Dataset({"x": [2]})] actual = auto_combine(objs) - expected = Dataset({'x': [0, 1, 2]}) + expected = Dataset({"x": [0, 1, 2]}) assert_identical(expected, actual) # ensure auto_combine handles non-sorted variables - objs = [Dataset(OrderedDict([('x', ('a', [0])), ('y', ('a', [0]))])), - Dataset(OrderedDict([('y', ('a', [1])), ('x', ('a', [1]))]))] + objs = [ + Dataset(OrderedDict([("x", ("a", [0])), ("y", ("a", [0]))])), + Dataset(OrderedDict([("y", ("a", [1])), ("x", ("a", [1]))])), + ] actual = auto_combine(objs) - expected = Dataset({'x': ('a', [0, 1]), 'y': ('a', [0, 1])}) + expected = Dataset({"x": ("a", [0, 1]), "y": ("a", [0, 1])}) assert_identical(expected, actual) - objs = [Dataset({'x': [0], 'y': [0]}), Dataset({'y': [1], 'x': [1]})] - with raises_regex(ValueError, 'too many .* dimensions'): + objs = [Dataset({"x": [0], "y": [0]}), Dataset({"y": [1], "x": [1]})] + with raises_regex(ValueError, "too many .* dimensions"): auto_combine(objs) - objs = [Dataset({'x': 0}), Dataset({'x': 1})] - with raises_regex(ValueError, 'cannot infer dimension'): + objs = [Dataset({"x": 0}), Dataset({"x": 1})] + with raises_regex(ValueError, "cannot infer dimension"): auto_combine(objs) - objs = [Dataset({'x': [0], 'y': [0]}), Dataset({'x': [0]})] + objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [0]})] with pytest.raises(KeyError): auto_combine(objs) def test_auto_combine_previously_failed(self): # In the above scenario, one file is missing, containing the data for # one year's data for one variable. - datasets = [Dataset({'a': ('x', [0]), 'x': [0]}), - Dataset({'b': ('x', [0]), 'x': [0]}), - Dataset({'a': ('x', [1]), 'x': [1]})] - expected = Dataset({'a': ('x', [0, 1]), 'b': ('x', [0, np.nan])}, - {'x': [0, 1]}) + datasets = [ + Dataset({"a": ("x", [0]), "x": [0]}), + Dataset({"b": ("x", [0]), "x": [0]}), + Dataset({"a": ("x", [1]), "x": [1]}), + ] + expected = Dataset({"a": ("x", [0, 1]), "b": ("x", [0, np.nan])}, {"x": [0, 1]}) actual = auto_combine(datasets) assert_identical(expected, actual) # Your data includes "time" and "station" dimensions, and each year's # data has a different set of stations. 
- datasets = [Dataset({'a': ('x', [2, 3]), 'x': [1, 2]}), - Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})] - expected = Dataset({'a': (('t', 'x'), - [[np.nan, 2, 3], [1, 2, np.nan]])}, - {'x': [0, 1, 2]}) - actual = auto_combine(datasets, concat_dim='t') + datasets = [ + Dataset({"a": ("x", [2, 3]), "x": [1, 2]}), + Dataset({"a": ("x", [1, 2]), "x": [0, 1]}), + ] + expected = Dataset( + {"a": (("t", "x"), [[np.nan, 2, 3], [1, 2, np.nan]])}, {"x": [0, 1, 2]} + ) + actual = auto_combine(datasets, concat_dim="t") assert_identical(expected, actual) def test_auto_combine_still_fails(self): # concat can't handle new variables (yet): # https://github.com/pydata/xarray/issues/508 - datasets = [Dataset({'x': 0}, {'y': 0}), - Dataset({'x': 1}, {'y': 1, 'z': 1})] + datasets = [Dataset({"x": 0}, {"y": 0}), Dataset({"x": 1}, {"y": 1, "z": 1})] with pytest.raises(ValueError): - auto_combine(datasets, 'y') + auto_combine(datasets, "y") def test_auto_combine_no_concat(self): - objs = [Dataset({'x': 0}), Dataset({'y': 1})] + objs = [Dataset({"x": 0}), Dataset({"y": 1})] actual = auto_combine(objs) - expected = Dataset({'x': 0, 'y': 1}) + expected = Dataset({"x": 0, "y": 1}) assert_identical(expected, actual) - objs = [Dataset({'x': 0, 'y': 1}), Dataset({'y': np.nan, 'z': 2})] + objs = [Dataset({"x": 0, "y": 1}), Dataset({"y": np.nan, "z": 2})] actual = auto_combine(objs) - expected = Dataset({'x': 0, 'y': 1, 'z': 2}) + expected = Dataset({"x": 0, "y": 1, "z": 2}) assert_identical(expected, actual) - data = Dataset({'x': 0}) + data = Dataset({"x": 0}) actual = auto_combine([data, data, data], concat_dim=None) assert_identical(data, actual) # Single object, with a concat_dim explicitly provided # Test the issue reported in GH #1988 - objs = [Dataset({'x': 0, 'y': 1})] - dim = DataArray([100], name='baz', dims='baz') + objs = [Dataset({"x": 0, "y": 1})] + dim = DataArray([100], name="baz", dims="baz") actual = auto_combine(objs, concat_dim=dim) - expected = Dataset({'x': ('baz', [0]), 'y': ('baz', [1])}, - {'baz': [100]}) + expected = Dataset({"x": ("baz", [0]), "y": ("baz", [1])}, {"baz": [100]}) assert_identical(expected, actual) # Just making sure that auto_combine is doing what is # expected for non-scalar values, too. 
- objs = [Dataset({'x': ('z', [0, 1]), 'y': ('z', [1, 2])})] - dim = DataArray([100], name='baz', dims='baz') + objs = [Dataset({"x": ("z", [0, 1]), "y": ("z", [1, 2])})] + dim = DataArray([100], name="baz", dims="baz") actual = auto_combine(objs, concat_dim=dim) - expected = Dataset({'x': (('baz', 'z'), [[0, 1]]), - 'y': (('baz', 'z'), [[1, 2]])}, - {'baz': [100]}) + expected = Dataset( + {"x": (("baz", "z"), [[0, 1]]), "y": (("baz", "z"), [[1, 2]])}, + {"baz": [100]}, + ) assert_identical(expected, actual) def test_auto_combine_order_by_appearance_not_coords(self): - objs = [Dataset({'foo': ('x', [0])}, coords={'x': ('x', [1])}), - Dataset({'foo': ('x', [1])}, coords={'x': ('x', [0])})] + objs = [ + Dataset({"foo": ("x", [0])}, coords={"x": ("x", [1])}), + Dataset({"foo": ("x", [1])}, coords={"x": ("x", [0])}), + ] actual = auto_combine(objs) - expected = Dataset({'foo': ('x', [0, 1])}, - coords={'x': ('x', [1, 0])}) + expected = Dataset({"foo": ("x", [0, 1])}, coords={"x": ("x", [1, 0])}) assert_identical(expected, actual) - @pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0]) + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) def test_auto_combine_fill_value(self, fill_value): - datasets = [Dataset({'a': ('x', [2, 3]), 'x': [1, 2]}), - Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})] + datasets = [ + Dataset({"a": ("x", [2, 3]), "x": [1, 2]}), + Dataset({"a": ("x", [1, 2]), "x": [0, 1]}), + ] if fill_value == dtypes.NA: # if we supply the default, we expect the missing value for a # float array fill_value = np.nan - expected = Dataset({'a': (('t', 'x'), - [[fill_value, 2, 3], [1, 2, fill_value]])}, - {'x': [0, 1, 2]}) - actual = auto_combine(datasets, concat_dim='t', fill_value=fill_value) + expected = Dataset( + {"a": (("t", "x"), [[fill_value, 2, 3], [1, 2, fill_value]])}, + {"x": [0, 1, 2]}, + ) + actual = auto_combine(datasets, concat_dim="t", fill_value=fill_value) assert_identical(expected, actual) @@ -811,26 +855,26 @@ class TestAutoCombineDeprecation: Set of tests to check that FutureWarnings are correctly raised until the deprecation cycle is complete. 
#2616 """ + def test_auto_combine_with_concat_dim(self): - objs = [Dataset({'x': [0]}), Dataset({'x': [1]})] + objs = [Dataset({"x": [0]}), Dataset({"x": [1]})] with pytest.warns(FutureWarning, match="`concat_dim`"): - auto_combine(objs, concat_dim='x') + auto_combine(objs, concat_dim="x") def test_auto_combine_with_merge_and_concat(self): - objs = [Dataset({'x': [0]}), - Dataset({'x': [1]}), - Dataset({'z': ((), 99)})] + objs = [Dataset({"x": [0]}), Dataset({"x": [1]}), Dataset({"z": ((), 99)})] with pytest.warns(FutureWarning, match="require both concatenation"): auto_combine(objs) def test_auto_combine_with_coords(self): - objs = [Dataset({'foo': ('x', [0])}, coords={'x': ('x', [0])}), - Dataset({'foo': ('x', [1])}, coords={'x': ('x', [1])})] + objs = [ + Dataset({"foo": ("x", [0])}, coords={"x": ("x", [0])}), + Dataset({"foo": ("x", [1])}, coords={"x": ("x", [1])}), + ] with pytest.warns(FutureWarning, match="supplied have global"): auto_combine(objs) def test_auto_combine_without_coords(self): - objs = [Dataset({'foo': ('x', [0])}), - Dataset({'foo': ('x', [1])})] + objs = [Dataset({"foo": ("x", [0])}), Dataset({"foo": ("x", [1])})] with pytest.warns(FutureWarning, match="supplied do not have global"): auto_combine(objs) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index f2f4be2e082..784a988b7cc 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -11,46 +11,52 @@ import xarray as xr from xarray.core.computation import ( - _UFuncSignature, apply_ufunc, broadcast_compat_data, collect_dict_values, - join_dict_keys, ordered_set_intersection, ordered_set_union, result_name, - unified_dim_sizes) + _UFuncSignature, + apply_ufunc, + broadcast_compat_data, + collect_dict_values, + join_dict_keys, + ordered_set_intersection, + ordered_set_union, + result_name, + unified_dim_sizes, +) from . 
import has_dask, raises_regex, requires_dask def assert_identical(a, b): - if hasattr(a, 'identical'): - msg = 'not identical:\n%r\n%r' % (a, b) + if hasattr(a, "identical"): + msg = "not identical:\n%r\n%r" % (a, b) assert a.identical(b), msg else: assert_array_equal(a, b) def test_signature_properties(): - sig = _UFuncSignature([['x'], ['x', 'y']], [['z']]) - assert sig.input_core_dims == (('x',), ('x', 'y')) - assert sig.output_core_dims == (('z',),) - assert sig.all_input_core_dims == frozenset(['x', 'y']) - assert sig.all_output_core_dims == frozenset(['z']) + sig = _UFuncSignature([["x"], ["x", "y"]], [["z"]]) + assert sig.input_core_dims == (("x",), ("x", "y")) + assert sig.output_core_dims == (("z",),) + assert sig.all_input_core_dims == frozenset(["x", "y"]) + assert sig.all_output_core_dims == frozenset(["z"]) assert sig.num_inputs == 2 assert sig.num_outputs == 1 - assert str(sig) == '(x),(x,y)->(z)' - assert sig.to_gufunc_string() == '(dim0),(dim0,dim1)->(dim2)' + assert str(sig) == "(x),(x,y)->(z)" + assert sig.to_gufunc_string() == "(dim0),(dim0,dim1)->(dim2)" # dimension names matter - assert _UFuncSignature([['x']]) != _UFuncSignature([['y']]) + assert _UFuncSignature([["x"]]) != _UFuncSignature([["y"]]) def test_result_name(): - class Named: def __init__(self, name=None): self.name = name assert result_name([1, 2]) is None assert result_name([Named()]) is None - assert result_name([Named('foo'), 2]) == 'foo' - assert result_name([Named('foo'), Named('bar')]) is None - assert result_name([Named('foo'), Named()]) is None + assert result_name([Named("foo"), 2]) == "foo" + assert result_name([Named("foo"), Named("bar")]) is None + assert result_name([Named("foo"), Named()]) is None def test_ordered_set_union(): @@ -67,21 +73,21 @@ def test_ordered_set_intersection(): def test_join_dict_keys(): - dicts = [OrderedDict.fromkeys(keys) for keys in [['x', 'y'], ['y', 'z']]] - assert list(join_dict_keys(dicts, 'left')) == ['x', 'y'] - assert list(join_dict_keys(dicts, 'right')) == ['y', 'z'] - assert list(join_dict_keys(dicts, 'inner')) == ['y'] - assert list(join_dict_keys(dicts, 'outer')) == ['x', 'y', 'z'] + dicts = [OrderedDict.fromkeys(keys) for keys in [["x", "y"], ["y", "z"]]] + assert list(join_dict_keys(dicts, "left")) == ["x", "y"] + assert list(join_dict_keys(dicts, "right")) == ["y", "z"] + assert list(join_dict_keys(dicts, "inner")) == ["y"] + assert list(join_dict_keys(dicts, "outer")) == ["x", "y", "z"] with pytest.raises(ValueError): - join_dict_keys(dicts, 'exact') + join_dict_keys(dicts, "exact") with pytest.raises(KeyError): - join_dict_keys(dicts, 'foobar') + join_dict_keys(dicts, "foobar") def test_collect_dict_values(): - dicts = [{'x': 1, 'y': 2, 'z': 3}, {'z': 4}, 5] + dicts = [{"x": 1, "y": 2, "z": 3}, {"z": 4}, 5] expected = [[1, 0, 5], [2, 0, 5], [3, 4, 5]] - collected = collect_dict_values(dicts, ['x', 'y', 'z'], fill_value=0) + collected = collect_dict_values(dicts, ["x", "y", "z"], fill_value=0) assert collected == expected @@ -91,18 +97,18 @@ def identity(x): def test_apply_identity(): array = np.arange(10) - variable = xr.Variable('x', array) - data_array = xr.DataArray(variable, [('x', -array)]) - dataset = xr.Dataset({'y': variable}, {'x': -array}) + variable = xr.Variable("x", array) + data_array = xr.DataArray(variable, [("x", -array)]) + dataset = xr.Dataset({"y": variable}, {"x": -array}) apply_identity = functools.partial(apply_ufunc, identity) assert_identical(array, apply_identity(array)) assert_identical(variable, apply_identity(variable)) 
assert_identical(data_array, apply_identity(data_array)) - assert_identical(data_array, apply_identity(data_array.groupby('x'))) + assert_identical(data_array, apply_identity(data_array.groupby("x"))) assert_identical(dataset, apply_identity(dataset)) - assert_identical(dataset, apply_identity(dataset.groupby('x'))) + assert_identical(dataset, apply_identity(dataset.groupby("x"))) def add(a, b): @@ -111,14 +117,14 @@ def add(a, b): def test_apply_two_inputs(): array = np.array([1, 2, 3]) - variable = xr.Variable('x', array) - data_array = xr.DataArray(variable, [('x', -array)]) - dataset = xr.Dataset({'y': variable}, {'x': -array}) + variable = xr.Variable("x", array) + data_array = xr.DataArray(variable, [("x", -array)]) + dataset = xr.Dataset({"y": variable}, {"x": -array}) zero_array = np.zeros_like(array) - zero_variable = xr.Variable('x', zero_array) - zero_data_array = xr.DataArray(zero_variable, [('x', -array)]) - zero_dataset = xr.Dataset({'y': zero_variable}, {'x': -array}) + zero_variable = xr.Variable("x", zero_array) + zero_data_array = xr.DataArray(zero_variable, [("x", -array)]) + zero_dataset = xr.Dataset({"y": zero_variable}, {"x": -array}) assert_identical(array, add(array, zero_array)) assert_identical(array, add(zero_array, array)) @@ -144,28 +150,28 @@ def test_apply_two_inputs(): assert_identical(dataset, add(zero_data_array, dataset)) assert_identical(dataset, add(zero_dataset, dataset)) - assert_identical(data_array, add(data_array.groupby('x'), zero_data_array)) - assert_identical(data_array, add(zero_data_array, data_array.groupby('x'))) + assert_identical(data_array, add(data_array.groupby("x"), zero_data_array)) + assert_identical(data_array, add(zero_data_array, data_array.groupby("x"))) - assert_identical(dataset, add(data_array.groupby('x'), zero_dataset)) - assert_identical(dataset, add(zero_dataset, data_array.groupby('x'))) + assert_identical(dataset, add(data_array.groupby("x"), zero_dataset)) + assert_identical(dataset, add(zero_dataset, data_array.groupby("x"))) - assert_identical(dataset, add(dataset.groupby('x'), zero_data_array)) - assert_identical(dataset, add(dataset.groupby('x'), zero_dataset)) - assert_identical(dataset, add(zero_data_array, dataset.groupby('x'))) - assert_identical(dataset, add(zero_dataset, dataset.groupby('x'))) + assert_identical(dataset, add(dataset.groupby("x"), zero_data_array)) + assert_identical(dataset, add(dataset.groupby("x"), zero_dataset)) + assert_identical(dataset, add(zero_data_array, dataset.groupby("x"))) + assert_identical(dataset, add(zero_dataset, dataset.groupby("x"))) def test_apply_1d_and_0d(): array = np.array([1, 2, 3]) - variable = xr.Variable('x', array) - data_array = xr.DataArray(variable, [('x', -array)]) - dataset = xr.Dataset({'y': variable}, {'x': -array}) + variable = xr.Variable("x", array) + data_array = xr.DataArray(variable, [("x", -array)]) + dataset = xr.Dataset({"y": variable}, {"x": -array}) zero_array = 0 zero_variable = xr.Variable((), zero_array) zero_data_array = xr.DataArray(zero_variable) - zero_dataset = xr.Dataset({'y': zero_variable}) + zero_dataset = xr.Dataset({"y": zero_variable}) assert_identical(array, add(array, zero_array)) assert_identical(array, add(zero_array, array)) @@ -191,27 +197,28 @@ def test_apply_1d_and_0d(): assert_identical(dataset, add(zero_data_array, dataset)) assert_identical(dataset, add(zero_dataset, dataset)) - assert_identical(data_array, add(data_array.groupby('x'), zero_data_array)) - assert_identical(data_array, add(zero_data_array, 
data_array.groupby('x'))) + assert_identical(data_array, add(data_array.groupby("x"), zero_data_array)) + assert_identical(data_array, add(zero_data_array, data_array.groupby("x"))) - assert_identical(dataset, add(data_array.groupby('x'), zero_dataset)) - assert_identical(dataset, add(zero_dataset, data_array.groupby('x'))) + assert_identical(dataset, add(data_array.groupby("x"), zero_dataset)) + assert_identical(dataset, add(zero_dataset, data_array.groupby("x"))) - assert_identical(dataset, add(dataset.groupby('x'), zero_data_array)) - assert_identical(dataset, add(dataset.groupby('x'), zero_dataset)) - assert_identical(dataset, add(zero_data_array, dataset.groupby('x'))) - assert_identical(dataset, add(zero_dataset, dataset.groupby('x'))) + assert_identical(dataset, add(dataset.groupby("x"), zero_data_array)) + assert_identical(dataset, add(dataset.groupby("x"), zero_dataset)) + assert_identical(dataset, add(zero_data_array, dataset.groupby("x"))) + assert_identical(dataset, add(zero_dataset, dataset.groupby("x"))) def test_apply_two_outputs(): array = np.arange(5) - variable = xr.Variable('x', array) - data_array = xr.DataArray(variable, [('x', -array)]) - dataset = xr.Dataset({'y': variable}, {'x': -array}) + variable = xr.Variable("x", array) + data_array = xr.DataArray(variable, [("x", -array)]) + dataset = xr.Dataset({"y": variable}, {"x": -array}) def twice(obj): def func(x): return (x, x) + return apply_ufunc(func, obj, output_core_dims=[[], []]) out0, out1 = twice(array) @@ -230,49 +237,46 @@ def func(x): assert_identical(out0, dataset) assert_identical(out1, dataset) - out0, out1 = twice(data_array.groupby('x')) + out0, out1 = twice(data_array.groupby("x")) assert_identical(out0, data_array) assert_identical(out1, data_array) - out0, out1 = twice(dataset.groupby('x')) + out0, out1 = twice(dataset.groupby("x")) assert_identical(out0, dataset) assert_identical(out1, dataset) def test_apply_input_core_dimension(): - def first_element(obj, dim): def func(x): return x[..., 0] + return apply_ufunc(func, obj, input_core_dims=[[dim]]) array = np.array([[1, 2], [3, 4]]) - variable = xr.Variable(['x', 'y'], array) - data_array = xr.DataArray(variable, {'x': ['a', 'b'], 'y': [-1, -2]}) - dataset = xr.Dataset({'data': data_array}) + variable = xr.Variable(["x", "y"], array) + data_array = xr.DataArray(variable, {"x": ["a", "b"], "y": [-1, -2]}) + dataset = xr.Dataset({"data": data_array}) - expected_variable_x = xr.Variable(['y'], [1, 2]) - expected_data_array_x = xr.DataArray(expected_variable_x, {'y': [-1, -2]}) - expected_dataset_x = xr.Dataset({'data': expected_data_array_x}) + expected_variable_x = xr.Variable(["y"], [1, 2]) + expected_data_array_x = xr.DataArray(expected_variable_x, {"y": [-1, -2]}) + expected_dataset_x = xr.Dataset({"data": expected_data_array_x}) - expected_variable_y = xr.Variable(['x'], [1, 3]) - expected_data_array_y = xr.DataArray(expected_variable_y, - {'x': ['a', 'b']}) - expected_dataset_y = xr.Dataset({'data': expected_data_array_y}) + expected_variable_y = xr.Variable(["x"], [1, 3]) + expected_data_array_y = xr.DataArray(expected_variable_y, {"x": ["a", "b"]}) + expected_dataset_y = xr.Dataset({"data": expected_data_array_y}) - assert_identical(expected_variable_x, first_element(variable, 'x')) - assert_identical(expected_variable_y, first_element(variable, 'y')) + assert_identical(expected_variable_x, first_element(variable, "x")) + assert_identical(expected_variable_y, first_element(variable, "y")) - assert_identical(expected_data_array_x, 
first_element(data_array, 'x')) - assert_identical(expected_data_array_y, first_element(data_array, 'y')) + assert_identical(expected_data_array_x, first_element(data_array, "x")) + assert_identical(expected_data_array_y, first_element(data_array, "y")) - assert_identical(expected_dataset_x, first_element(dataset, 'x')) - assert_identical(expected_dataset_y, first_element(dataset, 'y')) + assert_identical(expected_dataset_x, first_element(dataset, "x")) + assert_identical(expected_dataset_y, first_element(dataset, "y")) - assert_identical(expected_data_array_x, - first_element(data_array.groupby('y'), 'x')) - assert_identical(expected_dataset_x, - first_element(dataset.groupby('y'), 'x')) + assert_identical(expected_data_array_x, first_element(data_array.groupby("y"), "x")) + assert_identical(expected_dataset_x, first_element(dataset.groupby("y"), "x")) def multiply(*args): val = args[0] @@ -282,51 +286,61 @@ def multiply(*args): # regression test for GH:2341 with pytest.raises(ValueError): - apply_ufunc(multiply, data_array, data_array['y'].values, - input_core_dims=[['y']], output_core_dims=[['y']]) - expected = xr.DataArray(multiply(data_array, data_array['y']), - dims=['x', 'y'], coords=data_array.coords) - actual = apply_ufunc(multiply, data_array, data_array['y'].values, - input_core_dims=[['y'], []], output_core_dims=[['y']]) + apply_ufunc( + multiply, + data_array, + data_array["y"].values, + input_core_dims=[["y"]], + output_core_dims=[["y"]], + ) + expected = xr.DataArray( + multiply(data_array, data_array["y"]), dims=["x", "y"], coords=data_array.coords + ) + actual = apply_ufunc( + multiply, + data_array, + data_array["y"].values, + input_core_dims=[["y"], []], + output_core_dims=[["y"]], + ) assert_identical(expected, actual) def test_apply_output_core_dimension(): - def stack_negative(obj): def func(x): return np.stack([x, -x], axis=-1) - result = apply_ufunc(func, obj, output_core_dims=[['sign']]) + + result = apply_ufunc(func, obj, output_core_dims=[["sign"]]) if isinstance(result, (xr.Dataset, xr.DataArray)): - result.coords['sign'] = [1, -1] + result.coords["sign"] = [1, -1] return result array = np.array([[1, 2], [3, 4]]) - variable = xr.Variable(['x', 'y'], array) - data_array = xr.DataArray(variable, {'x': ['a', 'b'], 'y': [-1, -2]}) - dataset = xr.Dataset({'data': data_array}) + variable = xr.Variable(["x", "y"], array) + data_array = xr.DataArray(variable, {"x": ["a", "b"], "y": [-1, -2]}) + dataset = xr.Dataset({"data": data_array}) stacked_array = np.array([[[1, -1], [2, -2]], [[3, -3], [4, -4]]]) - stacked_variable = xr.Variable(['x', 'y', 'sign'], stacked_array) - stacked_coords = {'x': ['a', 'b'], 'y': [-1, -2], 'sign': [1, -1]} + stacked_variable = xr.Variable(["x", "y", "sign"], stacked_array) + stacked_coords = {"x": ["a", "b"], "y": [-1, -2], "sign": [1, -1]} stacked_data_array = xr.DataArray(stacked_variable, stacked_coords) - stacked_dataset = xr.Dataset({'data': stacked_data_array}) + stacked_dataset = xr.Dataset({"data": stacked_data_array}) assert_identical(stacked_array, stack_negative(array)) assert_identical(stacked_variable, stack_negative(variable)) assert_identical(stacked_data_array, stack_negative(data_array)) assert_identical(stacked_dataset, stack_negative(dataset)) - assert_identical(stacked_data_array, - stack_negative(data_array.groupby('x'))) - assert_identical(stacked_dataset, - stack_negative(dataset.groupby('x'))) + assert_identical(stacked_data_array, stack_negative(data_array.groupby("x"))) + assert_identical(stacked_dataset, 
stack_negative(dataset.groupby("x"))) def original_and_stack_negative(obj): def func(x): return (x, np.stack([x, -x], axis=-1)) - result = apply_ufunc(func, obj, output_core_dims=[[], ['sign']]) + + result = apply_ufunc(func, obj, output_core_dims=[[], ["sign"]]) if isinstance(result[1], (xr.Dataset, xr.DataArray)): - result[1].coords['sign'] = [1, -1] + result[1].coords["sign"] = [1, -1] return result out0, out1 = original_and_stack_negative(array) @@ -345,24 +359,27 @@ def func(x): assert_identical(dataset, out0) assert_identical(stacked_dataset, out1) - out0, out1 = original_and_stack_negative(data_array.groupby('x')) + out0, out1 = original_and_stack_negative(data_array.groupby("x")) assert_identical(data_array, out0) assert_identical(stacked_data_array, out1) - out0, out1 = original_and_stack_negative(dataset.groupby('x')) + out0, out1 = original_and_stack_negative(dataset.groupby("x")) assert_identical(dataset, out0) assert_identical(stacked_dataset, out1) def test_apply_exclude(): - - def concatenate(objects, dim='x'): + def concatenate(objects, dim="x"): def func(*x): return np.concatenate(x, axis=-1) - result = apply_ufunc(func, *objects, - input_core_dims=[[dim]] * len(objects), - output_core_dims=[[dim]], - exclude_dims={dim}) + + result = apply_ufunc( + func, + *objects, + input_core_dims=[[dim]] * len(objects), + output_core_dims=[[dim]], + exclude_dims={dim} + ) if isinstance(result, (xr.Dataset, xr.DataArray)): # note: this will fail if dim is not a coordinate on any input new_coord = np.concatenate([obj.coords[dim] for obj in objects]) @@ -370,15 +387,17 @@ def func(*x): return result arrays = [np.array([1]), np.array([2, 3])] - variables = [xr.Variable('x', a) for a in arrays] - data_arrays = [xr.DataArray(v, {'x': c, 'y': ('x', range(len(c)))}) - for v, c in zip(variables, [['a'], ['b', 'c']])] - datasets = [xr.Dataset({'data': data_array}) for data_array in data_arrays] + variables = [xr.Variable("x", a) for a in arrays] + data_arrays = [ + xr.DataArray(v, {"x": c, "y": ("x", range(len(c)))}) + for v, c in zip(variables, [["a"], ["b", "c"]]) + ] + datasets = [xr.Dataset({"data": data_array}) for data_array in data_arrays] expected_array = np.array([1, 2, 3]) - expected_variable = xr.Variable('x', expected_array) - expected_data_array = xr.DataArray(expected_variable, [('x', list('abc'))]) - expected_dataset = xr.Dataset({'data': expected_data_array}) + expected_variable = xr.Variable("x", expected_array) + expected_data_array = xr.DataArray(expected_variable, [("x", list("abc"))]) + expected_dataset = xr.Dataset({"data": expected_data_array}) assert_identical(expected_array, concatenate(arrays)) assert_identical(expected_variable, concatenate(variables)) @@ -387,83 +406,79 @@ def func(*x): # must also be a core dimension with pytest.raises(ValueError): - apply_ufunc(identity, variables[0], exclude_dims={'x'}) + apply_ufunc(identity, variables[0], exclude_dims={"x"}) def test_apply_groupby_add(): array = np.arange(5) - variable = xr.Variable('x', array) - coords = {'x': -array, 'y': ('x', [0, 0, 1, 1, 2])} - data_array = xr.DataArray(variable, coords, dims='x') - dataset = xr.Dataset({'z': variable}, coords) - - other_variable = xr.Variable('y', [0, 10]) - other_data_array = xr.DataArray(other_variable, dims='y') - other_dataset = xr.Dataset({'z': other_variable}) - - expected_variable = xr.Variable('x', [0, 1, 12, 13, np.nan]) - expected_data_array = xr.DataArray(expected_variable, coords, dims='x') - expected_dataset = xr.Dataset({'z': expected_variable}, coords) - - 
assert_identical(expected_data_array, - add(data_array.groupby('y'), other_data_array)) - assert_identical(expected_dataset, - add(data_array.groupby('y'), other_dataset)) - assert_identical(expected_dataset, - add(dataset.groupby('y'), other_data_array)) - assert_identical(expected_dataset, - add(dataset.groupby('y'), other_dataset)) + variable = xr.Variable("x", array) + coords = {"x": -array, "y": ("x", [0, 0, 1, 1, 2])} + data_array = xr.DataArray(variable, coords, dims="x") + dataset = xr.Dataset({"z": variable}, coords) + + other_variable = xr.Variable("y", [0, 10]) + other_data_array = xr.DataArray(other_variable, dims="y") + other_dataset = xr.Dataset({"z": other_variable}) + + expected_variable = xr.Variable("x", [0, 1, 12, 13, np.nan]) + expected_data_array = xr.DataArray(expected_variable, coords, dims="x") + expected_dataset = xr.Dataset({"z": expected_variable}, coords) + + assert_identical( + expected_data_array, add(data_array.groupby("y"), other_data_array) + ) + assert_identical(expected_dataset, add(data_array.groupby("y"), other_dataset)) + assert_identical(expected_dataset, add(dataset.groupby("y"), other_data_array)) + assert_identical(expected_dataset, add(dataset.groupby("y"), other_dataset)) # cannot be performed with xarray.Variable objects that share a dimension with pytest.raises(ValueError): - add(data_array.groupby('y'), other_variable) + add(data_array.groupby("y"), other_variable) # if they are all grouped the same way with pytest.raises(ValueError): - add(data_array.groupby('y'), data_array[:4].groupby('y')) + add(data_array.groupby("y"), data_array[:4].groupby("y")) with pytest.raises(ValueError): - add(data_array.groupby('y'), data_array[1:].groupby('y')) + add(data_array.groupby("y"), data_array[1:].groupby("y")) with pytest.raises(ValueError): - add(data_array.groupby('y'), other_data_array.groupby('y')) + add(data_array.groupby("y"), other_data_array.groupby("y")) with pytest.raises(ValueError): - add(data_array.groupby('y'), data_array.groupby('x')) + add(data_array.groupby("y"), data_array.groupby("x")) def test_unified_dim_sizes(): assert unified_dim_sizes([xr.Variable((), 0)]) == OrderedDict() - assert (unified_dim_sizes([xr.Variable('x', [1]), - xr.Variable('x', [1])]) == - OrderedDict([('x', 1)])) - assert (unified_dim_sizes([xr.Variable('x', [1]), - xr.Variable('y', [1, 2])]) == - OrderedDict([('x', 1), ('y', 2)])) - assert (unified_dim_sizes([xr.Variable(('x', 'z'), [[1]]), - xr.Variable(('y', 'z'), [[1, 2], [3, 4]])], - exclude_dims={'z'}) == - OrderedDict([('x', 1), ('y', 2)])) + assert unified_dim_sizes( + [xr.Variable("x", [1]), xr.Variable("x", [1])] + ) == OrderedDict([("x", 1)]) + assert unified_dim_sizes( + [xr.Variable("x", [1]), xr.Variable("y", [1, 2])] + ) == OrderedDict([("x", 1), ("y", 2)]) + assert unified_dim_sizes( + [xr.Variable(("x", "z"), [[1]]), xr.Variable(("y", "z"), [[1, 2], [3, 4]])], + exclude_dims={"z"}, + ) == OrderedDict([("x", 1), ("y", 2)]) # duplicate dimensions with pytest.raises(ValueError): - unified_dim_sizes([xr.Variable(('x', 'x'), [[1]])]) + unified_dim_sizes([xr.Variable(("x", "x"), [[1]])]) # mismatched lengths with pytest.raises(ValueError): - unified_dim_sizes( - [xr.Variable('x', [1]), xr.Variable('x', [1, 2])]) + unified_dim_sizes([xr.Variable("x", [1]), xr.Variable("x", [1, 2])]) def test_broadcast_compat_data_1d(): data = np.arange(5) - var = xr.Variable('x', data) + var = xr.Variable("x", data) - assert_identical(data, broadcast_compat_data(var, ('x',), ())) - assert_identical(data, 
broadcast_compat_data(var, (), ('x',))) - assert_identical(data[:], broadcast_compat_data(var, ('w',), ('x',))) - assert_identical(data[:, None], - broadcast_compat_data(var, ('w', 'x', 'y'), ())) + assert_identical(data, broadcast_compat_data(var, ("x",), ())) + assert_identical(data, broadcast_compat_data(var, (), ("x",))) + assert_identical(data[:], broadcast_compat_data(var, ("w",), ("x",))) + assert_identical(data[:, None], broadcast_compat_data(var, ("w", "x", "y"), ())) with pytest.raises(ValueError): - broadcast_compat_data(var, ('x',), ('w',)) + broadcast_compat_data(var, ("x",), ("w",)) with pytest.raises(ValueError): broadcast_compat_data(var, (), ()) @@ -471,50 +486,51 @@ def test_broadcast_compat_data_1d(): def test_broadcast_compat_data_2d(): data = np.arange(12).reshape(3, 4) - var = xr.Variable(['x', 'y'], data) - - assert_identical(data, broadcast_compat_data(var, ('x', 'y'), ())) - assert_identical(data, broadcast_compat_data(var, ('x',), ('y',))) - assert_identical(data, broadcast_compat_data(var, (), ('x', 'y'))) - assert_identical(data.T, broadcast_compat_data(var, ('y', 'x'), ())) - assert_identical(data.T, broadcast_compat_data(var, ('y',), ('x',))) - assert_identical(data, broadcast_compat_data(var, ('w', 'x'), ('y',))) - assert_identical(data, broadcast_compat_data(var, ('w',), ('x', 'y'))) - assert_identical(data.T, broadcast_compat_data(var, ('w',), ('y', 'x'))) - assert_identical(data[:, :, None], - broadcast_compat_data(var, ('w', 'x', 'y', 'z'), ())) - assert_identical(data[None, :, :].T, - broadcast_compat_data(var, ('w', 'y', 'x', 'z'), ())) + var = xr.Variable(["x", "y"], data) + + assert_identical(data, broadcast_compat_data(var, ("x", "y"), ())) + assert_identical(data, broadcast_compat_data(var, ("x",), ("y",))) + assert_identical(data, broadcast_compat_data(var, (), ("x", "y"))) + assert_identical(data.T, broadcast_compat_data(var, ("y", "x"), ())) + assert_identical(data.T, broadcast_compat_data(var, ("y",), ("x",))) + assert_identical(data, broadcast_compat_data(var, ("w", "x"), ("y",))) + assert_identical(data, broadcast_compat_data(var, ("w",), ("x", "y"))) + assert_identical(data.T, broadcast_compat_data(var, ("w",), ("y", "x"))) + assert_identical( + data[:, :, None], broadcast_compat_data(var, ("w", "x", "y", "z"), ()) + ) + assert_identical( + data[None, :, :].T, broadcast_compat_data(var, ("w", "y", "x", "z"), ()) + ) def test_keep_attrs(): - def add(a, b, keep_attrs): if keep_attrs: return apply_ufunc(operator.add, a, b, keep_attrs=keep_attrs) else: return apply_ufunc(operator.add, a, b) - a = xr.DataArray([0, 1], [('x', [0, 1])]) - a.attrs['attr'] = 'da' - a['x'].attrs['attr'] = 'da_coord' - b = xr.DataArray([1, 2], [('x', [0, 1])]) + a = xr.DataArray([0, 1], [("x", [0, 1])]) + a.attrs["attr"] = "da" + a["x"].attrs["attr"] = "da_coord" + b = xr.DataArray([1, 2], [("x", [0, 1])]) actual = add(a, b, keep_attrs=False) assert not actual.attrs actual = add(a, b, keep_attrs=True) assert_identical(actual.attrs, a.attrs) - assert_identical(actual['x'].attrs, a['x'].attrs) + assert_identical(actual["x"].attrs, a["x"].attrs) actual = add(a.variable, b.variable, keep_attrs=False) assert not actual.attrs actual = add(a.variable, b.variable, keep_attrs=True) assert_identical(actual.attrs, a.attrs) - a = xr.Dataset({'x': [0, 1]}) - a.attrs['attr'] = 'ds' - a.x.attrs['attr'] = 'da' - b = xr.Dataset({'x': [0, 1]}) + a = xr.Dataset({"x": [0, 1]}) + a.attrs["attr"] = "ds" + a.x.attrs["attr"] = "da" + b = xr.Dataset({"x": [0, 1]}) actual = add(a, b, 
keep_attrs=False) assert not actual.attrs @@ -524,41 +540,49 @@ def add(a, b, keep_attrs): def test_dataset_join(): - ds0 = xr.Dataset({'a': ('x', [1, 2]), 'x': [0, 1]}) - ds1 = xr.Dataset({'a': ('x', [99, 3]), 'x': [1, 2]}) + ds0 = xr.Dataset({"a": ("x", [1, 2]), "x": [0, 1]}) + ds1 = xr.Dataset({"a": ("x", [99, 3]), "x": [1, 2]}) # by default, cannot have different labels - with raises_regex(ValueError, 'indexes .* are not equal'): + with raises_regex(ValueError, "indexes .* are not equal"): apply_ufunc(operator.add, ds0, ds1) - with raises_regex(TypeError, 'must supply'): - apply_ufunc(operator.add, ds0, ds1, dataset_join='outer') + with raises_regex(TypeError, "must supply"): + apply_ufunc(operator.add, ds0, ds1, dataset_join="outer") def add(a, b, join, dataset_join): - return apply_ufunc(operator.add, a, b, join=join, - dataset_join=dataset_join, - dataset_fill_value=np.nan) - - actual = add(ds0, ds1, 'outer', 'inner') - expected = xr.Dataset({'a': ('x', [np.nan, 101, np.nan]), - 'x': [0, 1, 2]}) + return apply_ufunc( + operator.add, + a, + b, + join=join, + dataset_join=dataset_join, + dataset_fill_value=np.nan, + ) + + actual = add(ds0, ds1, "outer", "inner") + expected = xr.Dataset({"a": ("x", [np.nan, 101, np.nan]), "x": [0, 1, 2]}) assert_identical(actual, expected) - actual = add(ds0, ds1, 'outer', 'outer') + actual = add(ds0, ds1, "outer", "outer") assert_identical(actual, expected) - with raises_regex(ValueError, 'data variable names'): - apply_ufunc(operator.add, ds0, xr.Dataset({'b': 1})) + with raises_regex(ValueError, "data variable names"): + apply_ufunc(operator.add, ds0, xr.Dataset({"b": 1})) - ds2 = xr.Dataset({'b': ('x', [99, 3]), 'x': [1, 2]}) - actual = add(ds0, ds2, 'outer', 'inner') - expected = xr.Dataset({'x': [0, 1, 2]}) + ds2 = xr.Dataset({"b": ("x", [99, 3]), "x": [1, 2]}) + actual = add(ds0, ds2, "outer", "inner") + expected = xr.Dataset({"x": [0, 1, 2]}) assert_identical(actual, expected) # we used np.nan as the fill_value in add() above - actual = add(ds0, ds2, 'outer', 'outer') - expected = xr.Dataset({'a': ('x', [np.nan, np.nan, np.nan]), - 'b': ('x', [np.nan, np.nan, np.nan]), - 'x': [0, 1, 2]}) + actual = add(ds0, ds2, "outer", "outer") + expected = xr.Dataset( + { + "a": ("x", [np.nan, np.nan, np.nan]), + "b": ("x", [np.nan, np.nan, np.nan]), + "x": [0, 1, 2], + } + ) assert_identical(actual, expected) @@ -567,10 +591,10 @@ def test_apply_dask(): import dask.array as da array = da.ones((2,), chunks=2) - variable = xr.Variable('x', array) + variable = xr.Variable("x", array) coords = xr.DataArray(variable).coords.variables - data_array = xr.DataArray(variable, dims=['x'], coords=coords) - dataset = xr.Dataset({'y': variable}) + data_array = xr.DataArray(variable, dims=["x"], coords=coords) + dataset = xr.Dataset({"y": variable}) # encountered dask array, but did not set dask='allowed' with pytest.raises(ValueError): @@ -584,10 +608,10 @@ def test_apply_dask(): # unknown setting for dask array handling with pytest.raises(ValueError): - apply_ufunc(identity, array, dask='unknown') + apply_ufunc(identity, array, dask="unknown") def dask_safe_identity(x): - return apply_ufunc(identity, x, dask='allowed') + return apply_ufunc(identity, x, dask="allowed") assert array is dask_safe_identity(array) @@ -600,7 +624,7 @@ def dask_safe_identity(x): assert_identical(data_array, actual) actual = dask_safe_identity(dataset) - assert isinstance(actual['y'].data, da.Array) + assert isinstance(actual["y"].data, da.Array) assert_identical(dataset, actual) @@ -609,11 
+633,10 @@ def test_apply_dask_parallelized_one_arg(): import dask.array as da array = da.ones((2, 2), chunks=(1, 1)) - data_array = xr.DataArray(array, dims=('x', 'y')) + data_array = xr.DataArray(array, dims=("x", "y")) def parallel_identity(x): - return apply_ufunc(identity, x, dask='parallelized', - output_dtypes=[x.dtype]) + return apply_ufunc(identity, x, dask="parallelized", output_dtypes=[x.dtype]) actual = parallel_identity(data_array) assert isinstance(actual.data, da.Array) @@ -630,13 +653,13 @@ def test_apply_dask_parallelized_two_args(): import dask.array as da array = da.ones((2, 2), chunks=(1, 1), dtype=np.int64) - data_array = xr.DataArray(array, dims=('x', 'y')) + data_array = xr.DataArray(array, dims=("x", "y")) data_array.name = None def parallel_add(x, y): - return apply_ufunc(operator.add, x, y, - dask='parallelized', - output_dtypes=[np.int64]) + return apply_ufunc( + operator.add, x, y, dask="parallelized", output_dtypes=[np.int64] + ) def check(x, y): actual = parallel_add(x, y) @@ -658,61 +681,84 @@ def test_apply_dask_parallelized_errors(): import dask.array as da array = da.ones((2, 2), chunks=(1, 1)) - data_array = xr.DataArray(array, dims=('x', 'y')) + data_array = xr.DataArray(array, dims=("x", "y")) with pytest.raises(NotImplementedError): - apply_ufunc(identity, data_array, output_core_dims=[['z'], ['z']], - dask='parallelized') - with raises_regex(ValueError, 'dtypes'): - apply_ufunc(identity, data_array, dask='parallelized') - with raises_regex(TypeError, 'list'): - apply_ufunc(identity, data_array, dask='parallelized', - output_dtypes=float) - with raises_regex(ValueError, 'must have the same length'): - apply_ufunc(identity, data_array, dask='parallelized', - output_dtypes=[float, float]) - with raises_regex(ValueError, 'output_sizes'): - apply_ufunc(identity, data_array, output_core_dims=[['z']], - output_dtypes=[float], dask='parallelized') - with raises_regex(ValueError, 'at least one input is an xarray object'): - apply_ufunc(identity, array, dask='parallelized') - - with raises_regex(ValueError, 'consists of multiple chunks'): - apply_ufunc(identity, data_array, dask='parallelized', - output_dtypes=[float], - input_core_dims=[('y',)], - output_core_dims=[('y',)]) + apply_ufunc( + identity, data_array, output_core_dims=[["z"], ["z"]], dask="parallelized" + ) + with raises_regex(ValueError, "dtypes"): + apply_ufunc(identity, data_array, dask="parallelized") + with raises_regex(TypeError, "list"): + apply_ufunc(identity, data_array, dask="parallelized", output_dtypes=float) + with raises_regex(ValueError, "must have the same length"): + apply_ufunc( + identity, data_array, dask="parallelized", output_dtypes=[float, float] + ) + with raises_regex(ValueError, "output_sizes"): + apply_ufunc( + identity, + data_array, + output_core_dims=[["z"]], + output_dtypes=[float], + dask="parallelized", + ) + with raises_regex(ValueError, "at least one input is an xarray object"): + apply_ufunc(identity, array, dask="parallelized") + + with raises_regex(ValueError, "consists of multiple chunks"): + apply_ufunc( + identity, + data_array, + dask="parallelized", + output_dtypes=[float], + input_core_dims=[("y",)], + output_core_dims=[("y",)], + ) # it's currently impossible to silence these warnings from inside dask.array: # https://github.com/dask/dask/issues/3245 @requires_dask -@pytest.mark.filterwarnings('ignore:Mean of empty slice') +@pytest.mark.filterwarnings("ignore:Mean of empty slice") def test_apply_dask_multiple_inputs(): import dask.array as da def 
covariance(x, y): - return ((x - x.mean(axis=-1, keepdims=True)) * - (y - y.mean(axis=-1, keepdims=True))).mean(axis=-1) + return ( + (x - x.mean(axis=-1, keepdims=True)) * (y - y.mean(axis=-1, keepdims=True)) + ).mean(axis=-1) rs = np.random.RandomState(42) array1 = da.from_array(rs.randn(4, 4), chunks=(2, 4)) array2 = da.from_array(rs.randn(4, 4), chunks=(2, 4)) - data_array_1 = xr.DataArray(array1, dims=('x', 'z')) - data_array_2 = xr.DataArray(array2, dims=('y', 'z')) + data_array_1 = xr.DataArray(array1, dims=("x", "z")) + data_array_2 = xr.DataArray(array2, dims=("y", "z")) expected = apply_ufunc( - covariance, data_array_1.compute(), data_array_2.compute(), - input_core_dims=[['z'], ['z']]) + covariance, + data_array_1.compute(), + data_array_2.compute(), + input_core_dims=[["z"], ["z"]], + ) allowed = apply_ufunc( - covariance, data_array_1, data_array_2, input_core_dims=[['z'], ['z']], - dask='allowed') + covariance, + data_array_1, + data_array_2, + input_core_dims=[["z"], ["z"]], + dask="allowed", + ) assert isinstance(allowed.data, da.Array) xr.testing.assert_allclose(expected, allowed.compute()) parallelized = apply_ufunc( - covariance, data_array_1, data_array_2, input_core_dims=[['z'], ['z']], - dask='parallelized', output_dtypes=[float]) + covariance, + data_array_1, + data_array_2, + input_core_dims=[["z"], ["z"]], + dask="parallelized", + output_dtypes=[float], + ) assert isinstance(parallelized.data, da.Array) xr.testing.assert_allclose(expected, parallelized.compute()) @@ -722,19 +768,25 @@ def test_apply_dask_new_output_dimension(): import dask.array as da array = da.ones((2, 2), chunks=(1, 1)) - data_array = xr.DataArray(array, dims=('x', 'y')) + data_array = xr.DataArray(array, dims=("x", "y")) def stack_negative(obj): def func(x): return np.stack([x, -x], axis=-1) - return apply_ufunc(func, obj, output_core_dims=[['sign']], - dask='parallelized', output_dtypes=[obj.dtype], - output_sizes={'sign': 2}) + + return apply_ufunc( + func, + obj, + output_core_dims=[["sign"]], + dask="parallelized", + output_dtypes=[obj.dtype], + output_sizes={"sign": 2}, + ) expected = stack_negative(data_array.compute()) actual = stack_negative(data_array) - assert actual.dims == ('x', 'y', 'sign') + assert actual.dims == ("x", "y", "sign") assert actual.shape == (2, 2, 2) assert isinstance(actual.data, da.Array) assert_identical(expected, actual) @@ -745,28 +797,31 @@ def pandas_median(x): def test_vectorize(): - data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=('x', 'y')) - expected = xr.DataArray([1, 2], dims=['x']) - actual = apply_ufunc(pandas_median, data_array, - input_core_dims=[['y']], - vectorize=True) + data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) + expected = xr.DataArray([1, 2], dims=["x"]) + actual = apply_ufunc( + pandas_median, data_array, input_core_dims=[["y"]], vectorize=True + ) assert_identical(expected, actual) @requires_dask def test_vectorize_dask(): - data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=('x', 'y')) - expected = xr.DataArray([1, 2], dims=['x']) - actual = apply_ufunc(pandas_median, data_array.chunk({'x': 1}), - input_core_dims=[['y']], - vectorize=True, - dask='parallelized', - output_dtypes=[float]) + data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) + expected = xr.DataArray([1, 2], dims=["x"]) + actual = apply_ufunc( + pandas_median, + data_array.chunk({"x": 1}), + input_core_dims=[["y"]], + vectorize=True, + dask="parallelized", + output_dtypes=[float], + ) assert_identical(expected, actual) def 
test_output_wrong_number(): - variable = xr.Variable('x', np.arange(10)) + variable = xr.Variable("x", np.arange(10)) def identity(x): return x @@ -774,15 +829,15 @@ def identity(x): def tuple3x(x): return (x, x, x) - with raises_regex(ValueError, 'number of outputs'): + with raises_regex(ValueError, "number of outputs"): apply_ufunc(identity, variable, output_core_dims=[(), ()]) - with raises_regex(ValueError, 'number of outputs'): + with raises_regex(ValueError, "number of outputs"): apply_ufunc(tuple3x, variable, output_core_dims=[(), ()]) def test_output_wrong_dims(): - variable = xr.Variable('x', np.arange(10)) + variable = xr.Variable("x", np.arange(10)) def add_dim(x): return x[..., np.newaxis] @@ -790,21 +845,21 @@ def add_dim(x): def remove_dim(x): return x[..., 0] - with raises_regex(ValueError, 'unexpected number of dimensions'): - apply_ufunc(add_dim, variable, output_core_dims=[('y', 'z')]) + with raises_regex(ValueError, "unexpected number of dimensions"): + apply_ufunc(add_dim, variable, output_core_dims=[("y", "z")]) - with raises_regex(ValueError, 'unexpected number of dimensions'): + with raises_regex(ValueError, "unexpected number of dimensions"): apply_ufunc(add_dim, variable) - with raises_regex(ValueError, 'unexpected number of dimensions'): + with raises_regex(ValueError, "unexpected number of dimensions"): apply_ufunc(remove_dim, variable) def test_output_wrong_dim_size(): array = np.arange(10) - variable = xr.Variable('x', array) - data_array = xr.DataArray(variable, [('x', -array)]) - dataset = xr.Dataset({'y': variable}, {'x': -array}) + variable = xr.Variable("x", array) + data_array = xr.DataArray(variable, [("x", -array)]) + dataset = xr.Dataset({"y": variable}, {"x": -array}) def truncate(array): return array[:5] @@ -812,79 +867,85 @@ def truncate(array): def apply_truncate_broadcast_invalid(obj): return apply_ufunc(truncate, obj) - with raises_regex(ValueError, 'size of dimension'): + with raises_regex(ValueError, "size of dimension"): apply_truncate_broadcast_invalid(variable) - with raises_regex(ValueError, 'size of dimension'): + with raises_regex(ValueError, "size of dimension"): apply_truncate_broadcast_invalid(data_array) - with raises_regex(ValueError, 'size of dimension'): + with raises_regex(ValueError, "size of dimension"): apply_truncate_broadcast_invalid(dataset) def apply_truncate_x_x_invalid(obj): - return apply_ufunc(truncate, obj, input_core_dims=[['x']], - output_core_dims=[['x']]) + return apply_ufunc( + truncate, obj, input_core_dims=[["x"]], output_core_dims=[["x"]] + ) - with raises_regex(ValueError, 'size of dimension'): + with raises_regex(ValueError, "size of dimension"): apply_truncate_x_x_invalid(variable) - with raises_regex(ValueError, 'size of dimension'): + with raises_regex(ValueError, "size of dimension"): apply_truncate_x_x_invalid(data_array) - with raises_regex(ValueError, 'size of dimension'): + with raises_regex(ValueError, "size of dimension"): apply_truncate_x_x_invalid(dataset) def apply_truncate_x_z(obj): - return apply_ufunc(truncate, obj, input_core_dims=[['x']], - output_core_dims=[['z']]) + return apply_ufunc( + truncate, obj, input_core_dims=[["x"]], output_core_dims=[["z"]] + ) - assert_identical(xr.Variable('z', array[:5]), - apply_truncate_x_z(variable)) - assert_identical(xr.DataArray(array[:5], dims=['z']), - apply_truncate_x_z(data_array)) - assert_identical(xr.Dataset({'y': ('z', array[:5])}), - apply_truncate_x_z(dataset)) + assert_identical(xr.Variable("z", array[:5]), apply_truncate_x_z(variable)) + 
assert_identical( + xr.DataArray(array[:5], dims=["z"]), apply_truncate_x_z(data_array) + ) + assert_identical(xr.Dataset({"y": ("z", array[:5])}), apply_truncate_x_z(dataset)) def apply_truncate_x_x_valid(obj): - return apply_ufunc(truncate, obj, input_core_dims=[['x']], - output_core_dims=[['x']], exclude_dims={'x'}) - - assert_identical(xr.Variable('x', array[:5]), - apply_truncate_x_x_valid(variable)) - assert_identical(xr.DataArray(array[:5], dims=['x']), - apply_truncate_x_x_valid(data_array)) - assert_identical(xr.Dataset({'y': ('x', array[:5])}), - apply_truncate_x_x_valid(dataset)) - - -@pytest.mark.parametrize('use_dask', [True, False]) + return apply_ufunc( + truncate, + obj, + input_core_dims=[["x"]], + output_core_dims=[["x"]], + exclude_dims={"x"}, + ) + + assert_identical(xr.Variable("x", array[:5]), apply_truncate_x_x_valid(variable)) + assert_identical( + xr.DataArray(array[:5], dims=["x"]), apply_truncate_x_x_valid(data_array) + ) + assert_identical( + xr.Dataset({"y": ("x", array[:5])}), apply_truncate_x_x_valid(dataset) + ) + + +@pytest.mark.parametrize("use_dask", [True, False]) def test_dot(use_dask): if use_dask: if not has_dask: - pytest.skip('test for dask.') + pytest.skip("test for dask.") a = np.arange(30 * 4).reshape(30, 4) b = np.arange(30 * 4 * 5).reshape(30, 4, 5) c = np.arange(5 * 60).reshape(5, 60) - da_a = xr.DataArray(a, dims=['a', 'b'], - coords={'a': np.linspace(0, 1, 30)}) - da_b = xr.DataArray(b, dims=['a', 'b', 'c'], - coords={'a': np.linspace(0, 1, 30)}) - da_c = xr.DataArray(c, dims=['c', 'e']) + da_a = xr.DataArray(a, dims=["a", "b"], coords={"a": np.linspace(0, 1, 30)}) + da_b = xr.DataArray(b, dims=["a", "b", "c"], coords={"a": np.linspace(0, 1, 30)}) + da_c = xr.DataArray(c, dims=["c", "e"]) if use_dask: - da_a = da_a.chunk({'a': 3}) - da_b = da_b.chunk({'a': 3}) - da_c = da_c.chunk({'c': 3}) + da_a = da_a.chunk({"a": 3}) + da_b = da_b.chunk({"a": 3}) + da_c = da_c.chunk({"c": 3}) - actual = xr.dot(da_a, da_b, dims=['a', 'b']) - assert actual.dims == ('c', ) - assert (actual.data == np.einsum('ij,ijk->k', a, b)).all() + actual = xr.dot(da_a, da_b, dims=["a", "b"]) + assert actual.dims == ("c",) + assert (actual.data == np.einsum("ij,ijk->k", a, b)).all() assert isinstance(actual.variable.data, type(da_a.variable.data)) actual = xr.dot(da_a, da_b) - assert actual.dims == ('c', ) - assert (actual.data == np.einsum('ij,ijk->k', a, b)).all() + assert actual.dims == ("c",) + assert (actual.data == np.einsum("ij,ijk->k", a, b)).all() assert isinstance(actual.variable.data, type(da_a.variable.data)) if use_dask: import dask - if LooseVersion(dask.__version__) < LooseVersion('0.17.3'): + + if LooseVersion(dask.__version__) < LooseVersion("0.17.3"): pytest.skip("needs dask.array.einsum") # for only a single array is passed without dims argument, just return @@ -894,78 +955,78 @@ def test_dot(use_dask): # test for variable actual = xr.dot(da_a.variable, da_b.variable) - assert actual.dims == ('c', ) - assert (actual.data == np.einsum('ij,ijk->k', a, b)).all() + assert actual.dims == ("c",) + assert (actual.data == np.einsum("ij,ijk->k", a, b)).all() assert isinstance(actual.data, type(da_a.variable.data)) if use_dask: - da_a = da_a.chunk({'a': 3}) - da_b = da_b.chunk({'a': 3}) - actual = xr.dot(da_a, da_b, dims=['b']) - assert actual.dims == ('a', 'c') - assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all() + da_a = da_a.chunk({"a": 3}) + da_b = da_b.chunk({"a": 3}) + actual = xr.dot(da_a, da_b, dims=["b"]) + assert actual.dims == ("a", "c") + 
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() assert isinstance(actual.variable.data, type(da_a.variable.data)) - actual = xr.dot(da_a, da_b, dims=['b']) - assert actual.dims == ('a', 'c') - assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all() + actual = xr.dot(da_a, da_b, dims=["b"]) + assert actual.dims == ("a", "c") + assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() - actual = xr.dot(da_a, da_b, dims='b') - assert actual.dims == ('a', 'c') - assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all() + actual = xr.dot(da_a, da_b, dims="b") + assert actual.dims == ("a", "c") + assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() - actual = xr.dot(da_a, da_b, dims='a') - assert actual.dims == ('b', 'c') - assert (actual.data == np.einsum('ij,ijk->jk', a, b)).all() + actual = xr.dot(da_a, da_b, dims="a") + assert actual.dims == ("b", "c") + assert (actual.data == np.einsum("ij,ijk->jk", a, b)).all() - actual = xr.dot(da_a, da_b, dims='c') - assert actual.dims == ('a', 'b') - assert (actual.data == np.einsum('ij,ijk->ij', a, b)).all() + actual = xr.dot(da_a, da_b, dims="c") + assert actual.dims == ("a", "b") + assert (actual.data == np.einsum("ij,ijk->ij", a, b)).all() - actual = xr.dot(da_a, da_b, da_c, dims=['a', 'b']) - assert actual.dims == ('c', 'e') - assert (actual.data == np.einsum('ij,ijk,kl->kl ', a, b, c)).all() + actual = xr.dot(da_a, da_b, da_c, dims=["a", "b"]) + assert actual.dims == ("c", "e") + assert (actual.data == np.einsum("ij,ijk,kl->kl ", a, b, c)).all() # should work with tuple - actual = xr.dot(da_a, da_b, dims=('c', )) - assert actual.dims == ('a', 'b') - assert (actual.data == np.einsum('ij,ijk->ij', a, b)).all() + actual = xr.dot(da_a, da_b, dims=("c",)) + assert actual.dims == ("a", "b") + assert (actual.data == np.einsum("ij,ijk->ij", a, b)).all() # default dims actual = xr.dot(da_a, da_b, da_c) - assert actual.dims == ('e', ) - assert (actual.data == np.einsum('ij,ijk,kl->l ', a, b, c)).all() + assert actual.dims == ("e",) + assert (actual.data == np.einsum("ij,ijk,kl->l ", a, b, c)).all() # 1 array summation - actual = xr.dot(da_a, dims='a') - assert actual.dims == ('b', ) - assert (actual.data == np.einsum('ij->j ', a)).all() + actual = xr.dot(da_a, dims="a") + assert actual.dims == ("b",) + assert (actual.data == np.einsum("ij->j ", a)).all() # empty dim - actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims='a') - assert actual.dims == ('b', ) + actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims="a") + assert actual.dims == ("b",) assert (actual.data == np.zeros(actual.shape)).all() # Invalid cases - if not use_dask or LooseVersion(dask.__version__) > LooseVersion('0.17.4'): + if not use_dask or LooseVersion(dask.__version__) > LooseVersion("0.17.4"): with pytest.raises(TypeError): - xr.dot(da_a, dims='a', invalid=None) + xr.dot(da_a, dims="a", invalid=None) with pytest.raises(TypeError): - xr.dot(da_a.to_dataset(name='da'), dims='a') + xr.dot(da_a.to_dataset(name="da"), dims="a") with pytest.raises(TypeError): - xr.dot(dims='a') + xr.dot(dims="a") # einsum parameters - actual = xr.dot(da_a, da_b, dims=['b'], order='C') - assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all() - assert actual.values.flags['C_CONTIGUOUS'] - assert not actual.values.flags['F_CONTIGUOUS'] - actual = xr.dot(da_a, da_b, dims=['b'], order='F') - assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all() + actual = xr.dot(da_a, da_b, dims=["b"], order="C") + assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() + assert 
actual.values.flags["C_CONTIGUOUS"] + assert not actual.values.flags["F_CONTIGUOUS"] + actual = xr.dot(da_a, da_b, dims=["b"], order="F") + assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() # dask converts Fortran arrays to C order when merging the final array if not use_dask: - assert not actual.values.flags['C_CONTIGUOUS'] - assert actual.values.flags['F_CONTIGUOUS'] + assert not actual.values.flags["C_CONTIGUOUS"] + assert actual.values.flags["F_CONTIGUOUS"] # einsum has a constant string as of the first parameter, which makes # it hard to pass to xarray.apply_ufunc. @@ -975,7 +1036,7 @@ def test_dot(use_dask): def test_where(): - cond = xr.DataArray([True, False], dims='x') + cond = xr.DataArray([True, False], dims="x") actual = xr.where(cond, 1, 0) - expected = xr.DataArray([1, 0], dims='x') + expected = xr.DataArray([1, 0], dims="x") assert_identical(expected, actual) diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index ff188305c83..d16ebeeb53d 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -7,8 +7,13 @@ from xarray import DataArray, Dataset, Variable, concat from xarray.core import dtypes from . import ( - InaccessibleArray, assert_array_equal, - assert_equal, assert_identical, raises_regex, requires_dask) + InaccessibleArray, + assert_array_equal, + assert_equal, + assert_identical, + raises_regex, + requires_dask, +) from .test_dataset import create_test_data @@ -19,309 +24,323 @@ def test_concat(self): # drop the third dimension to keep things relatively understandable data = create_test_data() for k in list(data.variables): - if 'dim3' in data[k].dims: + if "dim3" in data[k].dims: del data[k] - split_data = [data.isel(dim1=slice(3)), - data.isel(dim1=slice(3, None))] - assert_identical(data, concat(split_data, 'dim1')) + split_data = [data.isel(dim1=slice(3)), data.isel(dim1=slice(3, None))] + assert_identical(data, concat(split_data, "dim1")) def rectify_dim_order(dataset): # return a new dataset with all variable dimensions transposed into # the order in which they are found in `data` return Dataset( - { - k: v.transpose(*data[k].dims) - for k, v in dataset.data_vars.items() - }, + {k: v.transpose(*data[k].dims) for k, v in dataset.data_vars.items()}, dataset.coords, - attrs=dataset.attrs + attrs=dataset.attrs, ) - for dim in ['dim1', 'dim2']: + for dim in ["dim1", "dim2"]: datasets = [g for _, g in data.groupby(dim, squeeze=False)] assert_identical(data, concat(datasets, dim)) - dim = 'dim2' - assert_identical( - data, concat(datasets, data[dim])) - assert_identical( - data, concat(datasets, data[dim], coords='minimal')) + dim = "dim2" + assert_identical(data, concat(datasets, data[dim])) + assert_identical(data, concat(datasets, data[dim], coords="minimal")) datasets = [g for _, g in data.groupby(dim, squeeze=True)] - concat_over = [k for k, v in data.coords.items() - if dim in v.dims and k != dim] + concat_over = [k for k, v in data.coords.items() if dim in v.dims and k != dim] actual = concat(datasets, data[dim], coords=concat_over) assert_identical(data, rectify_dim_order(actual)) - actual = concat(datasets, data[dim], coords='different') + actual = concat(datasets, data[dim], coords="different") assert_identical(data, rectify_dim_order(actual)) # make sure the coords argument behaves as expected - data.coords['extra'] = ('dim4', np.arange(3)) - for dim in ['dim1', 'dim2']: + data.coords["extra"] = ("dim4", np.arange(3)) + for dim in ["dim1", "dim2"]: datasets = [g for _, g in data.groupby(dim, 
squeeze=True)] - actual = concat(datasets, data[dim], coords='all') - expected = np.array([data['extra'].values - for _ in range(data.dims[dim])]) - assert_array_equal(actual['extra'].values, expected) + actual = concat(datasets, data[dim], coords="all") + expected = np.array([data["extra"].values for _ in range(data.dims[dim])]) + assert_array_equal(actual["extra"].values, expected) - actual = concat(datasets, data[dim], coords='different') - assert_equal(data['extra'], actual['extra']) - actual = concat(datasets, data[dim], coords='minimal') - assert_equal(data['extra'], actual['extra']) + actual = concat(datasets, data[dim], coords="different") + assert_equal(data["extra"], actual["extra"]) + actual = concat(datasets, data[dim], coords="minimal") + assert_equal(data["extra"], actual["extra"]) # verify that the dim argument takes precedence over # concatenating dataset variables of the same name - dim = (2 * data['dim1']).rename('dim1') - datasets = [g for _, g in data.groupby('dim1', squeeze=False)] + dim = (2 * data["dim1"]).rename("dim1") + datasets = [g for _, g in data.groupby("dim1", squeeze=False)] expected = data.copy() - expected['dim1'] = dim + expected["dim1"] = dim assert_identical(expected, concat(datasets, dim)) def test_concat_data_vars(self): - data = Dataset({'foo': ('x', np.random.randn(10))}) + data = Dataset({"foo": ("x", np.random.randn(10))}) objs = [data.isel(x=slice(5)), data.isel(x=slice(5, None))] - for data_vars in ['minimal', 'different', 'all', [], ['foo']]: - actual = concat(objs, dim='x', data_vars=data_vars) + for data_vars in ["minimal", "different", "all", [], ["foo"]]: + actual = concat(objs, dim="x", data_vars=data_vars) assert_identical(data, actual) def test_concat_coords(self): - data = Dataset({'foo': ('x', np.random.randn(10))}) - expected = data.assign_coords(c=('x', [0] * 5 + [1] * 5)) - objs = [data.isel(x=slice(5)).assign_coords(c=0), - data.isel(x=slice(5, None)).assign_coords(c=1)] - for coords in ['different', 'all', ['c']]: - actual = concat(objs, dim='x', coords=coords) + data = Dataset({"foo": ("x", np.random.randn(10))}) + expected = data.assign_coords(c=("x", [0] * 5 + [1] * 5)) + objs = [ + data.isel(x=slice(5)).assign_coords(c=0), + data.isel(x=slice(5, None)).assign_coords(c=1), + ] + for coords in ["different", "all", ["c"]]: + actual = concat(objs, dim="x", coords=coords) assert_identical(expected, actual) - for coords in ['minimal', []]: - with raises_regex(ValueError, 'not equal across'): - concat(objs, dim='x', coords=coords) + for coords in ["minimal", []]: + with raises_regex(ValueError, "not equal across"): + concat(objs, dim="x", coords=coords) def test_concat_constant_index(self): # GH425 - ds1 = Dataset({'foo': 1.5}, {'y': 1}) - ds2 = Dataset({'foo': 2.5}, {'y': 1}) - expected = Dataset({'foo': ('y', [1.5, 2.5]), 'y': [1, 1]}) - for mode in ['different', 'all', ['foo']]: - actual = concat([ds1, ds2], 'y', data_vars=mode) + ds1 = Dataset({"foo": 1.5}, {"y": 1}) + ds2 = Dataset({"foo": 2.5}, {"y": 1}) + expected = Dataset({"foo": ("y", [1.5, 2.5]), "y": [1, 1]}) + for mode in ["different", "all", ["foo"]]: + actual = concat([ds1, ds2], "y", data_vars=mode) assert_identical(expected, actual) - with raises_regex(ValueError, 'not equal across datasets'): - concat([ds1, ds2], 'y', data_vars='minimal') + with raises_regex(ValueError, "not equal across datasets"): + concat([ds1, ds2], "y", data_vars="minimal") def test_concat_size0(self): data = create_test_data() split_data = [data.isel(dim1=slice(0, 0)), data] - actual = 
concat(split_data, 'dim1') + actual = concat(split_data, "dim1") assert_identical(data, actual) - actual = concat(split_data[::-1], 'dim1') + actual = concat(split_data[::-1], "dim1") assert_identical(data, actual) def test_concat_autoalign(self): - ds1 = Dataset({'foo': DataArray([1, 2], coords=[('x', [1, 2])])}) - ds2 = Dataset({'foo': DataArray([1, 2], coords=[('x', [1, 3])])}) - actual = concat([ds1, ds2], 'y') - expected = Dataset({'foo': DataArray([[1, 2, np.nan], [1, np.nan, 2]], - dims=['y', 'x'], - coords={'x': [1, 2, 3]})}) + ds1 = Dataset({"foo": DataArray([1, 2], coords=[("x", [1, 2])])}) + ds2 = Dataset({"foo": DataArray([1, 2], coords=[("x", [1, 3])])}) + actual = concat([ds1, ds2], "y") + expected = Dataset( + { + "foo": DataArray( + [[1, 2, np.nan], [1, np.nan, 2]], + dims=["y", "x"], + coords={"x": [1, 2, 3]}, + ) + } + ) assert_identical(expected, actual) def test_concat_errors(self): data = create_test_data() - split_data = [data.isel(dim1=slice(3)), - data.isel(dim1=slice(3, None))] + split_data = [data.isel(dim1=slice(3)), data.isel(dim1=slice(3, None))] - with raises_regex(ValueError, 'must supply at least one'): - concat([], 'dim1') + with raises_regex(ValueError, "must supply at least one"): + concat([], "dim1") - with raises_regex(ValueError, 'are not coordinates'): - concat([data, data], 'new_dim', coords=['not_found']) + with raises_regex(ValueError, "are not coordinates"): + concat([data, data], "new_dim", coords=["not_found"]) - with raises_regex(ValueError, 'global attributes not'): + with raises_regex(ValueError, "global attributes not"): data0, data1 = deepcopy(split_data) - data1.attrs['foo'] = 'bar' - concat([data0, data1], 'dim1', compat='identical') - assert_identical( - data, concat([data0, data1], 'dim1', compat='equals')) + data1.attrs["foo"] = "bar" + concat([data0, data1], "dim1", compat="identical") + assert_identical(data, concat([data0, data1], "dim1", compat="equals")) - with raises_regex(ValueError, 'encountered unexpected'): + with raises_regex(ValueError, "encountered unexpected"): data0, data1 = deepcopy(split_data) - data1['foo'] = ('bar', np.random.randn(10)) - concat([data0, data1], 'dim1') + data1["foo"] = ("bar", np.random.randn(10)) + concat([data0, data1], "dim1") - with raises_regex(ValueError, 'compat.* invalid'): - concat(split_data, 'dim1', compat='foobar') + with raises_regex(ValueError, "compat.* invalid"): + concat(split_data, "dim1", compat="foobar") - with raises_regex(ValueError, 'unexpected value for'): - concat([data, data], 'new_dim', coords='foobar') + with raises_regex(ValueError, "unexpected value for"): + concat([data, data], "new_dim", coords="foobar") - with raises_regex( - ValueError, 'coordinate in some datasets but not others'): - concat([Dataset({'x': 0}), Dataset({'x': [1]})], dim='z') + with raises_regex(ValueError, "coordinate in some datasets but not others"): + concat([Dataset({"x": 0}), Dataset({"x": [1]})], dim="z") - with raises_regex( - ValueError, 'coordinate in some datasets but not others'): - concat([Dataset({'x': 0}), Dataset({}, {'x': 1})], dim='z') + with raises_regex(ValueError, "coordinate in some datasets but not others"): + concat([Dataset({"x": 0}), Dataset({}, {"x": 1})], dim="z") - with raises_regex(ValueError, 'no longer a valid'): - concat([data, data], 'new_dim', mode='different') - with raises_regex(ValueError, 'no longer a valid'): - concat([data, data], 'new_dim', concat_over='different') + with raises_regex(ValueError, "no longer a valid"): + concat([data, data], "new_dim", 
mode="different") + with raises_regex(ValueError, "no longer a valid"): + concat([data, data], "new_dim", concat_over="different") def test_concat_join_kwarg(self): - ds1 = Dataset({'a': (('x', 'y'), [[0]])}, - coords={'x': [0], 'y': [0]}) - ds2 = Dataset({'a': (('x', 'y'), [[0]])}, - coords={'x': [1], 'y': [0.0001]}) + ds1 = Dataset({"a": (("x", "y"), [[0]])}, coords={"x": [0], "y": [0]}) + ds2 = Dataset({"a": (("x", "y"), [[0]])}, coords={"x": [1], "y": [0.0001]}) expected = dict() - expected['outer'] = Dataset({'a': (('x', 'y'), - [[0, np.nan], [np.nan, 0]])}, - {'x': [0, 1], 'y': [0, 0.0001]}) - expected['inner'] = Dataset({'a': (('x', 'y'), [[], []])}, - {'x': [0, 1], 'y': []}) - expected['left'] = Dataset({'a': (('x', 'y'), - np.array([0, np.nan], ndmin=2).T)}, - coords={'x': [0, 1], 'y': [0]}) - expected['right'] = Dataset({'a': (('x', 'y'), - np.array([np.nan, 0], ndmin=2).T)}, - coords={'x': [0, 1], 'y': [0.0001]}) + expected["outer"] = Dataset( + {"a": (("x", "y"), [[0, np.nan], [np.nan, 0]])}, + {"x": [0, 1], "y": [0, 0.0001]}, + ) + expected["inner"] = Dataset( + {"a": (("x", "y"), [[], []])}, {"x": [0, 1], "y": []} + ) + expected["left"] = Dataset( + {"a": (("x", "y"), np.array([0, np.nan], ndmin=2).T)}, + coords={"x": [0, 1], "y": [0]}, + ) + expected["right"] = Dataset( + {"a": (("x", "y"), np.array([np.nan, 0], ndmin=2).T)}, + coords={"x": [0, 1], "y": [0.0001]}, + ) with raises_regex(ValueError, "indexes along dimension 'y'"): - actual = concat([ds1, ds2], join='exact', dim='x') + actual = concat([ds1, ds2], join="exact", dim="x") for join in expected: - actual = concat([ds1, ds2], join=join, dim='x') + actual = concat([ds1, ds2], join=join, dim="x") assert_equal(actual, expected[join]) def test_concat_promote_shape(self): # mixed dims within variables - objs = [Dataset({}, {'x': 0}), Dataset({'x': [1]})] - actual = concat(objs, 'x') - expected = Dataset({'x': [0, 1]}) + objs = [Dataset({}, {"x": 0}), Dataset({"x": [1]})] + actual = concat(objs, "x") + expected = Dataset({"x": [0, 1]}) assert_identical(actual, expected) - objs = [Dataset({'x': [0]}), Dataset({}, {'x': 1})] - actual = concat(objs, 'x') + objs = [Dataset({"x": [0]}), Dataset({}, {"x": 1})] + actual = concat(objs, "x") assert_identical(actual, expected) # mixed dims between variables - objs = [Dataset({'x': [2], 'y': 3}), Dataset({'x': [4], 'y': 5})] - actual = concat(objs, 'x') - expected = Dataset({'x': [2, 4], 'y': ('x', [3, 5])}) + objs = [Dataset({"x": [2], "y": 3}), Dataset({"x": [4], "y": 5})] + actual = concat(objs, "x") + expected = Dataset({"x": [2, 4], "y": ("x", [3, 5])}) assert_identical(actual, expected) # mixed dims in coord variable - objs = [Dataset({'x': [0]}, {'y': -1}), - Dataset({'x': [1]}, {'y': ('x', [-2])})] - actual = concat(objs, 'x') - expected = Dataset({'x': [0, 1]}, {'y': ('x', [-1, -2])}) + objs = [Dataset({"x": [0]}, {"y": -1}), Dataset({"x": [1]}, {"y": ("x", [-2])})] + actual = concat(objs, "x") + expected = Dataset({"x": [0, 1]}, {"y": ("x", [-1, -2])}) assert_identical(actual, expected) # scalars with mixed lengths along concat dim -- values should repeat - objs = [Dataset({'x': [0]}, {'y': -1}), - Dataset({'x': [1, 2]}, {'y': -2})] - actual = concat(objs, 'x') - expected = Dataset({'x': [0, 1, 2]}, {'y': ('x', [-1, -2, -2])}) + objs = [Dataset({"x": [0]}, {"y": -1}), Dataset({"x": [1, 2]}, {"y": -2})] + actual = concat(objs, "x") + expected = Dataset({"x": [0, 1, 2]}, {"y": ("x", [-1, -2, -2])}) assert_identical(actual, expected) # broadcast 1d x 1d -> 2d - objs = 
[Dataset({'z': ('x', [-1])}, {'x': [0], 'y': [0]}), - Dataset({'z': ('y', [1])}, {'x': [1], 'y': [0]})] - actual = concat(objs, 'x') - expected = Dataset({'z': (('x', 'y'), [[-1], [1]])}, - {'x': [0, 1], 'y': [0]}) + objs = [ + Dataset({"z": ("x", [-1])}, {"x": [0], "y": [0]}), + Dataset({"z": ("y", [1])}, {"x": [1], "y": [0]}), + ] + actual = concat(objs, "x") + expected = Dataset({"z": (("x", "y"), [[-1], [1]])}, {"x": [0, 1], "y": [0]}) assert_identical(actual, expected) def test_concat_do_not_promote(self): # GH438 - objs = [Dataset({'y': ('t', [1])}, {'x': 1, 't': [0]}), - Dataset({'y': ('t', [2])}, {'x': 1, 't': [0]})] - expected = Dataset({'y': ('t', [1, 2])}, {'x': 1, 't': [0, 0]}) - actual = concat(objs, 't') + objs = [ + Dataset({"y": ("t", [1])}, {"x": 1, "t": [0]}), + Dataset({"y": ("t", [2])}, {"x": 1, "t": [0]}), + ] + expected = Dataset({"y": ("t", [1, 2])}, {"x": 1, "t": [0, 0]}) + actual = concat(objs, "t") assert_identical(expected, actual) - objs = [Dataset({'y': ('t', [1])}, {'x': 1, 't': [0]}), - Dataset({'y': ('t', [2])}, {'x': 2, 't': [0]})] + objs = [ + Dataset({"y": ("t", [1])}, {"x": 1, "t": [0]}), + Dataset({"y": ("t", [2])}, {"x": 2, "t": [0]}), + ] with pytest.raises(ValueError): - concat(objs, 't', coords='minimal') + concat(objs, "t", coords="minimal") def test_concat_dim_is_variable(self): - objs = [Dataset({'x': 0}), Dataset({'x': 1})] - coord = Variable('y', [3, 4]) - expected = Dataset({'x': ('y', [0, 1]), 'y': [3, 4]}) + objs = [Dataset({"x": 0}), Dataset({"x": 1})] + coord = Variable("y", [3, 4]) + expected = Dataset({"x": ("y", [0, 1]), "y": [3, 4]}) actual = concat(objs, coord) assert_identical(actual, expected) def test_concat_multiindex(self): - x = pd.MultiIndex.from_product([[1, 2, 3], ['a', 'b']]) - expected = Dataset({'x': x}) - actual = concat([expected.isel(x=slice(2)), - expected.isel(x=slice(2, None))], 'x') + x = pd.MultiIndex.from_product([[1, 2, 3], ["a", "b"]]) + expected = Dataset({"x": x}) + actual = concat( + [expected.isel(x=slice(2)), expected.isel(x=slice(2, None))], "x" + ) assert expected.equals(actual) assert isinstance(actual.x.to_index(), pd.MultiIndex) - @pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0]) + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) def test_concat_fill_value(self, fill_value): - datasets = [Dataset({'a': ('x', [2, 3]), 'x': [1, 2]}), - Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})] + datasets = [ + Dataset({"a": ("x", [2, 3]), "x": [1, 2]}), + Dataset({"a": ("x", [1, 2]), "x": [0, 1]}), + ] if fill_value == dtypes.NA: # if we supply the default, we expect the missing value for a # float array fill_value = np.nan - expected = Dataset({'a': (('t', 'x'), - [[fill_value, 2, 3], - [1, 2, fill_value]])}, - {'x': [0, 1, 2]}) - actual = concat(datasets, dim='t', fill_value=fill_value) + expected = Dataset( + {"a": (("t", "x"), [[fill_value, 2, 3], [1, 2, fill_value]])}, + {"x": [0, 1, 2]}, + ) + actual = concat(datasets, dim="t", fill_value=fill_value) assert_identical(actual, expected) class TestConcatDataArray: def test_concat(self): - ds = Dataset({'foo': (['x', 'y'], np.random.random((2, 3))), - 'bar': (['x', 'y'], np.random.random((2, 3)))}, - {'x': [0, 1]}) - foo = ds['foo'] - bar = ds['bar'] + ds = Dataset( + { + "foo": (["x", "y"], np.random.random((2, 3))), + "bar": (["x", "y"], np.random.random((2, 3))), + }, + {"x": [0, 1]}, + ) + foo = ds["foo"] + bar = ds["bar"] # from dataset array: - expected = DataArray(np.array([foo.values, bar.values]), - dims=['w', 'x', 'y'], coords={'x': [0, 
1]}) - actual = concat([foo, bar], 'w') + expected = DataArray( + np.array([foo.values, bar.values]), + dims=["w", "x", "y"], + coords={"x": [0, 1]}, + ) + actual = concat([foo, bar], "w") assert_equal(expected, actual) # from iteration: - grouped = [g for _, g in foo.groupby('x')] - stacked = concat(grouped, ds['x']) + grouped = [g for _, g in foo.groupby("x")] + stacked = concat(grouped, ds["x"]) assert_identical(foo, stacked) # with an index as the 'dim' argument - stacked = concat(grouped, ds.indexes['x']) + stacked = concat(grouped, ds.indexes["x"]) assert_identical(foo, stacked) - actual = concat([foo[0], foo[1]], pd.Index([0, 1]) - ).reset_coords(drop=True) - expected = foo[:2].rename({'x': 'concat_dim'}) + actual = concat([foo[0], foo[1]], pd.Index([0, 1])).reset_coords(drop=True) + expected = foo[:2].rename({"x": "concat_dim"}) assert_identical(expected, actual) actual = concat([foo[0], foo[1]], [0, 1]).reset_coords(drop=True) - expected = foo[:2].rename({'x': 'concat_dim'}) + expected = foo[:2].rename({"x": "concat_dim"}) assert_identical(expected, actual) - with raises_regex(ValueError, 'not identical'): - concat([foo, bar], dim='w', compat='identical') + with raises_regex(ValueError, "not identical"): + concat([foo, bar], dim="w", compat="identical") - with raises_regex(ValueError, 'not a valid argument'): - concat([foo, bar], dim='w', data_vars='minimal') + with raises_regex(ValueError, "not a valid argument"): + concat([foo, bar], dim="w", data_vars="minimal") def test_concat_encoding(self): # Regression test for GH1297 - ds = Dataset({'foo': (['x', 'y'], np.random.random((2, 3))), - 'bar': (['x', 'y'], np.random.random((2, 3)))}, - {'x': [0, 1]}) - foo = ds['foo'] + ds = Dataset( + { + "foo": (["x", "y"], np.random.random((2, 3))), + "bar": (["x", "y"], np.random.random((2, 3))), + }, + {"x": [0, 1]}, + ) + foo = ds["foo"] foo.encoding = {"complevel": 5} - ds.encoding = {"unlimited_dims": 'x'} + ds.encoding = {"unlimited_dims": "x"} assert concat([foo, foo], dim="x").encoding == foo.encoding assert concat([ds, ds], dim="x").encoding == ds.encoding @@ -329,49 +348,61 @@ def test_concat_encoding(self): def test_concat_lazy(self): import dask.array as da - arrays = [DataArray( - da.from_array(InaccessibleArray(np.zeros((3, 3))), 3), - dims=['x', 'y']) for _ in range(2)] + arrays = [ + DataArray( + da.from_array(InaccessibleArray(np.zeros((3, 3))), 3), dims=["x", "y"] + ) + for _ in range(2) + ] # should not raise - combined = concat(arrays, dim='z') + combined = concat(arrays, dim="z") assert combined.shape == (2, 3, 3) - assert combined.dims == ('z', 'x', 'y') + assert combined.dims == ("z", "x", "y") - @pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0]) + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) def test_concat_fill_value(self, fill_value): - foo = DataArray([1, 2], coords=[('x', [1, 2])]) - bar = DataArray([1, 2], coords=[('x', [1, 3])]) + foo = DataArray([1, 2], coords=[("x", [1, 2])]) + bar = DataArray([1, 2], coords=[("x", [1, 3])]) if fill_value == dtypes.NA: # if we supply the default, we expect the missing value for a # float array fill_value = np.nan - expected = DataArray([[1, 2, fill_value], [1, fill_value, 2]], - dims=['y', 'x'], coords={'x': [1, 2, 3]}) - actual = concat((foo, bar), dim='y', fill_value=fill_value) + expected = DataArray( + [[1, 2, fill_value], [1, fill_value, 2]], + dims=["y", "x"], + coords={"x": [1, 2, 3]}, + ) + actual = concat((foo, bar), dim="y", fill_value=fill_value) assert_identical(actual, expected) def 
test_concat_join_kwarg(self): - ds1 = Dataset({'a': (('x', 'y'), [[0]])}, - coords={'x': [0], 'y': [0]}).to_array() - ds2 = Dataset({'a': (('x', 'y'), [[0]])}, - coords={'x': [1], 'y': [0.0001]}).to_array() + ds1 = Dataset( + {"a": (("x", "y"), [[0]])}, coords={"x": [0], "y": [0]} + ).to_array() + ds2 = Dataset( + {"a": (("x", "y"), [[0]])}, coords={"x": [1], "y": [0.0001]} + ).to_array() expected = dict() - expected['outer'] = Dataset({'a': (('x', 'y'), - [[0, np.nan], [np.nan, 0]])}, - {'x': [0, 1], 'y': [0, 0.0001]}) - expected['inner'] = Dataset({'a': (('x', 'y'), [[], []])}, - {'x': [0, 1], 'y': []}) - expected['left'] = Dataset({'a': (('x', 'y'), - np.array([0, np.nan], ndmin=2).T)}, - coords={'x': [0, 1], 'y': [0]}) - expected['right'] = Dataset({'a': (('x', 'y'), - np.array([np.nan, 0], ndmin=2).T)}, - coords={'x': [0, 1], 'y': [0.0001]}) + expected["outer"] = Dataset( + {"a": (("x", "y"), [[0, np.nan], [np.nan, 0]])}, + {"x": [0, 1], "y": [0, 0.0001]}, + ) + expected["inner"] = Dataset( + {"a": (("x", "y"), [[], []])}, {"x": [0, 1], "y": []} + ) + expected["left"] = Dataset( + {"a": (("x", "y"), np.array([0, np.nan], ndmin=2).T)}, + coords={"x": [0, 1], "y": [0]}, + ) + expected["right"] = Dataset( + {"a": (("x", "y"), np.array([np.nan, 0], ndmin=2).T)}, + coords={"x": [0, 1], "y": [0.0001]}, + ) with raises_regex(ValueError, "indexes along dimension 'y'"): - actual = concat([ds1, ds2], join='exact', dim='x') + actual = concat([ds1, ds2], join="exact", dim="x") for join in expected: - actual = concat([ds1, ds2], join=join, dim='x') + actual = concat([ds1, ds2], join=join, dim="x") assert_equal(actual, expected[join].to_array()) diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index e7cb8006b08..36c1d845f8e 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -7,31 +7,42 @@ import pytest from xarray import ( - Dataset, SerializationWarning, Variable, coding, conventions, open_dataset) + Dataset, + SerializationWarning, + Variable, + coding, + conventions, + open_dataset, +) from xarray.backends.common import WritableCFDataStore from xarray.backends.memory import InMemoryDataStore from xarray.conventions import decode_cf from xarray.testing import assert_identical from . 
import ( - assert_array_equal, raises_regex, requires_cftime_or_netCDF4, - requires_dask, requires_netCDF4) + assert_array_equal, + raises_regex, + requires_cftime_or_netCDF4, + requires_dask, + requires_netCDF4, +) from .test_backends import CFEncodedBase class TestBoolTypeArray: def test_booltype_array(self): - x = np.array([1, 0, 1, 1, 0], dtype='i1') + x = np.array([1, 0, 1, 1, 0], dtype="i1") bx = conventions.BoolTypeArray(x) assert bx.dtype == np.bool - assert_array_equal(bx, np.array([True, False, True, True, False], - dtype=np.bool)) + assert_array_equal( + bx, np.array([True, False, True, True, False], dtype=np.bool) + ) class TestNativeEndiannessArray: def test(self): - x = np.arange(5, dtype='>i8') - expected = np.arange(5, dtype='int64') + x = np.arange(5, dtype=">i8") + expected = np.arange(5, dtype="int64") a = conventions.NativeEndiannessArray(x) assert a.dtype == expected.dtype assert a.dtype == expected[:].dtype @@ -39,30 +50,35 @@ def test(self): def test_decode_cf_with_conflicting_fill_missing_value(): - expected = Variable(['t'], [np.nan, np.nan, 2], {'units': 'foobar'}) - var = Variable(['t'], np.arange(3), - {'units': 'foobar', - 'missing_value': 0, - '_FillValue': 1}) + expected = Variable(["t"], [np.nan, np.nan, 2], {"units": "foobar"}) + var = Variable( + ["t"], np.arange(3), {"units": "foobar", "missing_value": 0, "_FillValue": 1} + ) with warnings.catch_warnings(record=True) as w: - actual = conventions.decode_cf_variable('t', var) + actual = conventions.decode_cf_variable("t", var) assert_identical(actual, expected) - assert 'has multiple fill' in str(w[0].message) + assert "has multiple fill" in str(w[0].message) - expected = Variable(['t'], np.arange(10), {'units': 'foobar'}) + expected = Variable(["t"], np.arange(10), {"units": "foobar"}) - var = Variable(['t'], np.arange(10), - {'units': 'foobar', - 'missing_value': np.nan, - '_FillValue': np.nan}) - actual = conventions.decode_cf_variable('t', var) + var = Variable( + ["t"], + np.arange(10), + {"units": "foobar", "missing_value": np.nan, "_FillValue": np.nan}, + ) + actual = conventions.decode_cf_variable("t", var) assert_identical(actual, expected) - var = Variable(['t'], np.arange(10), - {'units': 'foobar', - 'missing_value': np.float32(np.nan), - '_FillValue': np.float32(np.nan)}) - actual = conventions.decode_cf_variable('t', var) + var = Variable( + ["t"], + np.arange(10), + { + "units": "foobar", + "missing_value": np.float32(np.nan), + "_FillValue": np.float32(np.nan), + }, + ) + actual = conventions.decode_cf_variable("t", var) assert_identical(actual, expected) @@ -70,20 +86,21 @@ def test_decode_cf_with_conflicting_fill_missing_value(): class TestEncodeCFVariable: def test_incompatible_attributes(self): invalid_vars = [ - Variable(['t'], pd.date_range('2000-01-01', periods=3), - {'units': 'foobar'}), - Variable(['t'], pd.to_timedelta(['1 day']), {'units': 'foobar'}), - Variable(['t'], [0, 1, 2], {'add_offset': 0}, {'add_offset': 2}), - Variable(['t'], [0, 1, 2], {'_FillValue': 0}, {'_FillValue': 2}), + Variable( + ["t"], pd.date_range("2000-01-01", periods=3), {"units": "foobar"} + ), + Variable(["t"], pd.to_timedelta(["1 day"]), {"units": "foobar"}), + Variable(["t"], [0, 1, 2], {"add_offset": 0}, {"add_offset": 2}), + Variable(["t"], [0, 1, 2], {"_FillValue": 0}, {"_FillValue": 2}), ] for var in invalid_vars: with pytest.raises(ValueError): conventions.encode_cf_variable(var) def test_missing_fillvalue(self): - v = Variable(['x'], np.array([np.nan, 1, 2, 3])) - v.encoding = {'dtype': 'int16'} - 
with pytest.warns(Warning, match='floating point data as an integer'): + v = Variable(["x"], np.array([np.nan, 1, 2, 3])) + v.encoding = {"dtype": "int16"} + with pytest.warns(Warning, match="floating point data as an integer"): conventions.encode_cf_variable(v) def test_multidimensional_coordinates(self): @@ -93,40 +110,37 @@ def test_multidimensional_coordinates(self): zeros1 = np.zeros((1, 5, 3)) zeros2 = np.zeros((1, 6, 3)) zeros3 = np.zeros((1, 5, 4)) - orig = Dataset({ - 'lon1': (['x1', 'y1'], zeros1.squeeze(0), {}), - 'lon2': (['x2', 'y1'], zeros2.squeeze(0), {}), - 'lon3': (['x1', 'y2'], zeros3.squeeze(0), {}), - 'lat1': (['x1', 'y1'], zeros1.squeeze(0), {}), - 'lat2': (['x2', 'y1'], zeros2.squeeze(0), {}), - 'lat3': (['x1', 'y2'], zeros3.squeeze(0), {}), - 'foo1': (['time', 'x1', 'y1'], zeros1, - {'coordinates': 'lon1 lat1'}), - 'foo2': (['time', 'x2', 'y1'], zeros2, - {'coordinates': 'lon2 lat2'}), - 'foo3': (['time', 'x1', 'y2'], zeros3, - {'coordinates': 'lon3 lat3'}), - 'time': ('time', [0.], {'units': 'hours since 2017-01-01'}), - }) + orig = Dataset( + { + "lon1": (["x1", "y1"], zeros1.squeeze(0), {}), + "lon2": (["x2", "y1"], zeros2.squeeze(0), {}), + "lon3": (["x1", "y2"], zeros3.squeeze(0), {}), + "lat1": (["x1", "y1"], zeros1.squeeze(0), {}), + "lat2": (["x2", "y1"], zeros2.squeeze(0), {}), + "lat3": (["x1", "y2"], zeros3.squeeze(0), {}), + "foo1": (["time", "x1", "y1"], zeros1, {"coordinates": "lon1 lat1"}), + "foo2": (["time", "x2", "y1"], zeros2, {"coordinates": "lon2 lat2"}), + "foo3": (["time", "x1", "y2"], zeros3, {"coordinates": "lon3 lat3"}), + "time": ("time", [0.0], {"units": "hours since 2017-01-01"}), + } + ) orig = conventions.decode_cf(orig) # Encode the coordinates, as they would be in a netCDF output file. enc, attrs = conventions.encode_dataset_coordinates(orig) # Make sure we have the right coordinates for each variable. - foo1_coords = enc['foo1'].attrs.get('coordinates', '') - foo2_coords = enc['foo2'].attrs.get('coordinates', '') - foo3_coords = enc['foo3'].attrs.get('coordinates', '') - assert set(foo1_coords.split()) == {'lat1', 'lon1'} - assert set(foo2_coords.split()) == {'lat2', 'lon2'} - assert set(foo3_coords.split()) == {'lat3', 'lon3'} + foo1_coords = enc["foo1"].attrs.get("coordinates", "") + foo2_coords = enc["foo2"].attrs.get("coordinates", "") + foo3_coords = enc["foo3"].attrs.get("coordinates", "") + assert set(foo1_coords.split()) == {"lat1", "lon1"} + assert set(foo2_coords.split()) == {"lat2", "lon2"} + assert set(foo3_coords.split()) == {"lat3", "lon3"} # Should not have any global coordinates. 
- assert 'coordinates' not in attrs + assert "coordinates" not in attrs @requires_dask def test_string_object_warning(self): - original = Variable( - ('x',), np.array(['foo', 'bar'], dtype=object)).chunk() - with pytest.warns(SerializationWarning, - match='dask array with dtype=object'): + original = Variable(("x",), np.array(["foo", "bar"], dtype=object)).chunk() + with pytest.warns(SerializationWarning, match="dask array with dtype=object"): encoded = conventions.encode_cf_variable(original) assert_identical(original, encoded) @@ -134,113 +148,134 @@ def test_string_object_warning(self): @requires_cftime_or_netCDF4 class TestDecodeCF: def test_dataset(self): - original = Dataset({ - 't': ('t', [0, 1, 2], {'units': 'days since 2000-01-01'}), - 'foo': ('t', [0, 0, 0], {'coordinates': 'y', 'units': 'bar'}), - 'y': ('t', [5, 10, -999], {'_FillValue': -999}) - }) - expected = Dataset({'foo': ('t', [0, 0, 0], {'units': 'bar'})}, - {'t': pd.date_range('2000-01-01', periods=3), - 'y': ('t', [5.0, 10.0, np.nan])}) + original = Dataset( + { + "t": ("t", [0, 1, 2], {"units": "days since 2000-01-01"}), + "foo": ("t", [0, 0, 0], {"coordinates": "y", "units": "bar"}), + "y": ("t", [5, 10, -999], {"_FillValue": -999}), + } + ) + expected = Dataset( + {"foo": ("t", [0, 0, 0], {"units": "bar"})}, + { + "t": pd.date_range("2000-01-01", periods=3), + "y": ("t", [5.0, 10.0, np.nan]), + }, + ) actual = conventions.decode_cf(original) assert_identical(expected, actual) def test_invalid_coordinates(self): # regression test for GH308 - original = Dataset({'foo': ('t', [1, 2], {'coordinates': 'invalid'})}) + original = Dataset({"foo": ("t", [1, 2], {"coordinates": "invalid"})}) actual = conventions.decode_cf(original) assert_identical(original, actual) def test_decode_coordinates(self): # regression test for GH610 - original = Dataset({'foo': ('t', [1, 2], {'coordinates': 'x'}), - 'x': ('t', [4, 5])}) + original = Dataset( + {"foo": ("t", [1, 2], {"coordinates": "x"}), "x": ("t", [4, 5])} + ) actual = conventions.decode_cf(original) - assert actual.foo.encoding['coordinates'] == 'x' + assert actual.foo.encoding["coordinates"] == "x" def test_0d_int32_encoding(self): - original = Variable((), np.int32(0), encoding={'dtype': 'int64'}) + original = Variable((), np.int32(0), encoding={"dtype": "int64"}) expected = Variable((), np.int64(0)) actual = conventions.maybe_encode_nonstring_dtype(original) assert_identical(expected, actual) def test_decode_cf_with_multiple_missing_values(self): - original = Variable(['t'], [0, 1, 2], - {'missing_value': np.array([0, 1])}) - expected = Variable(['t'], [np.nan, np.nan, 2], {}) + original = Variable(["t"], [0, 1, 2], {"missing_value": np.array([0, 1])}) + expected = Variable(["t"], [np.nan, np.nan, 2], {}) with warnings.catch_warnings(record=True) as w: - actual = conventions.decode_cf_variable('t', original) + actual = conventions.decode_cf_variable("t", original) assert_identical(expected, actual) - assert 'has multiple fill' in str(w[0].message) + assert "has multiple fill" in str(w[0].message) def test_decode_cf_with_drop_variables(self): - original = Dataset({ - 't': ('t', [0, 1, 2], {'units': 'days since 2000-01-01'}), - 'x': ("x", [9, 8, 7], {'units': 'km'}), - 'foo': (('t', 'x'), [[0, 0, 0], [1, 1, 1], [2, 2, 2]], - {'units': 'bar'}), - 'y': ('t', [5, 10, -999], {'_FillValue': -999}) - }) - expected = Dataset({ - 't': pd.date_range('2000-01-01', periods=3), - 'foo': (('t', 'x'), [[0, 0, 0], [1, 1, 1], [2, 2, 2]], - {'units': 'bar'}), - 'y': ('t', [5, 10, np.nan]) - }) 
+ original = Dataset( + { + "t": ("t", [0, 1, 2], {"units": "days since 2000-01-01"}), + "x": ("x", [9, 8, 7], {"units": "km"}), + "foo": ( + ("t", "x"), + [[0, 0, 0], [1, 1, 1], [2, 2, 2]], + {"units": "bar"}, + ), + "y": ("t", [5, 10, -999], {"_FillValue": -999}), + } + ) + expected = Dataset( + { + "t": pd.date_range("2000-01-01", periods=3), + "foo": ( + ("t", "x"), + [[0, 0, 0], [1, 1, 1], [2, 2, 2]], + {"units": "bar"}, + ), + "y": ("t", [5, 10, np.nan]), + } + ) actual = conventions.decode_cf(original, drop_variables=("x",)) actual2 = conventions.decode_cf(original, drop_variables="x") assert_identical(expected, actual) assert_identical(expected, actual2) def test_invalid_time_units_raises_eagerly(self): - ds = Dataset({'time': ('time', [0, 1], {'units': 'foobar since 123'})}) - with raises_regex(ValueError, 'unable to decode time'): + ds = Dataset({"time": ("time", [0, 1], {"units": "foobar since 123"})}) + with raises_regex(ValueError, "unable to decode time"): decode_cf(ds) @requires_cftime_or_netCDF4 def test_dataset_repr_with_netcdf4_datetimes(self): # regression test for #347 - attrs = {'units': 'days since 0001-01-01', 'calendar': 'noleap'} + attrs = {"units": "days since 0001-01-01", "calendar": "noleap"} with warnings.catch_warnings(): - warnings.filterwarnings('ignore', 'unable to decode time') - ds = decode_cf(Dataset({'time': ('time', [0, 1], attrs)})) - assert '(time) object' in repr(ds) + warnings.filterwarnings("ignore", "unable to decode time") + ds = decode_cf(Dataset({"time": ("time", [0, 1], attrs)})) + assert "(time) object" in repr(ds) - attrs = {'units': 'days since 1900-01-01'} - ds = decode_cf(Dataset({'time': ('time', [0, 1], attrs)})) - assert '(time) datetime64[ns]' in repr(ds) + attrs = {"units": "days since 1900-01-01"} + ds = decode_cf(Dataset({"time": ("time", [0, 1], attrs)})) + assert "(time) datetime64[ns]" in repr(ds) @requires_cftime_or_netCDF4 def test_decode_cf_datetime_transition_to_invalid(self): # manually create dataset with not-decoded date from datetime import datetime - ds = Dataset(coords={'time': [0, 266 * 365]}) - units = 'days since 2000-01-01 00:00:00' + + ds = Dataset(coords={"time": [0, 266 * 365]}) + units = "days since 2000-01-01 00:00:00" ds.time.attrs = dict(units=units) with warnings.catch_warnings(): - warnings.filterwarnings('ignore', 'unable to decode time') + warnings.filterwarnings("ignore", "unable to decode time") ds_decoded = conventions.decode_cf(ds) - expected = [datetime(2000, 1, 1, 0, 0), - datetime(2265, 10, 28, 0, 0)] + expected = [datetime(2000, 1, 1, 0, 0), datetime(2265, 10, 28, 0, 0)] assert_array_equal(ds_decoded.time.values, expected) @requires_dask def test_decode_cf_with_dask(self): import dask.array as da - original = Dataset({ - 't': ('t', [0, 1, 2], {'units': 'days since 2000-01-01'}), - 'foo': ('t', [0, 0, 0], {'coordinates': 'y', 'units': 'bar'}), - 'bar': ('string2', [b'a', b'b']), - 'baz': (('x'), [b'abc'], {'_Encoding': 'utf-8'}), - 'y': ('t', [5, 10, -999], {'_FillValue': -999}) - }).chunk() + + original = Dataset( + { + "t": ("t", [0, 1, 2], {"units": "days since 2000-01-01"}), + "foo": ("t", [0, 0, 0], {"coordinates": "y", "units": "bar"}), + "bar": ("string2", [b"a", b"b"]), + "baz": (("x"), [b"abc"], {"_Encoding": "utf-8"}), + "y": ("t", [5, 10, -999], {"_FillValue": -999}), + } + ).chunk() decoded = conventions.decode_cf(original) print(decoded) - assert all(isinstance(var.data, da.Array) - for name, var in decoded.variables.items() - if name not in decoded.indexes) + assert all( + 
isinstance(var.data, da.Array) + for name, var in decoded.variables.items() + if name not in decoded.indexes + ) assert_identical(decoded, conventions.decode_cf(original).compute()) @@ -259,14 +294,14 @@ def create_store(self): yield CFEncodedInMemoryStore() @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}, - allow_cleanup_failure=False): + def roundtrip( + self, data, save_kwargs={}, open_kwargs={}, allow_cleanup_failure=False + ): store = CFEncodedInMemoryStore() data.dump_to_store(store, **save_kwargs) yield open_dataset(store, **open_kwargs) - @pytest.mark.skip('cannot roundtrip coordinates yet for ' - 'CFEncodedInMemoryStore') + @pytest.mark.skip("cannot roundtrip coordinates yet for " "CFEncodedInMemoryStore") def test_roundtrip_coordinates(self): pass diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 0c55fe919d6..1ae3069f926 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -14,20 +14,27 @@ from xarray.tests import mock from . import ( - assert_allclose, assert_array_equal, assert_equal, assert_frame_equal, - assert_identical, raises_regex) + assert_allclose, + assert_array_equal, + assert_equal, + assert_frame_equal, + assert_identical, + raises_regex, +) -dask = pytest.importorskip('dask') -da = pytest.importorskip('dask.array') -dd = pytest.importorskip('dask.dataframe') +dask = pytest.importorskip("dask") +da = pytest.importorskip("dask.array") +dd = pytest.importorskip("dask.dataframe") class DaskTestCase: def assertLazyAnd(self, expected, actual, test): - with (dask.config.set(scheduler='single-threaded') - if LooseVersion(dask.__version__) >= LooseVersion('0.18.0') - else dask.set_options(get=dask.get)): + with ( + dask.config.set(scheduler="single-threaded") + if LooseVersion(dask.__version__) >= LooseVersion("0.18.0") + else dask.set_options(get=dask.get) + ): test(actual, expected) if isinstance(actual, Dataset): @@ -61,8 +68,8 @@ def setUp(self): self.values = np.random.RandomState(0).randn(4, 6) self.data = da.from_array(self.values, chunks=(2, 2)) - self.eager_var = Variable(('x', 'y'), self.values) - self.lazy_var = Variable(('x', 'y'), self.data) + self.eager_var = Variable(("x", "y"), self.values) + self.lazy_var = Variable(("x", "y"), self.data) def test_basics(self): v = self.lazy_var @@ -72,15 +79,16 @@ def test_basics(self): def test_copy(self): self.assertLazyAndIdentical(self.eager_var, self.lazy_var.copy()) - self.assertLazyAndIdentical(self.eager_var, - self.lazy_var.copy(deep=True)) + self.assertLazyAndIdentical(self.eager_var, self.lazy_var.copy(deep=True)) def test_chunk(self): - for chunks, expected in [(None, ((2, 2), (2, 2, 2))), - (3, ((3, 1), (3, 3))), - ({'x': 3, 'y': 3}, ((3, 1), (3, 3))), - ({'x': 3}, ((3, 1), (2, 2, 2))), - ({'x': (3, 1)}, ((3, 1), (2, 2, 2)))]: + for chunks, expected in [ + (None, ((2, 2), (2, 2, 2))), + (3, ((3, 1), (3, 3))), + ({"x": 3, "y": 3}, ((3, 1), (3, 3))), + ({"x": 3}, ((3, 1), (2, 2, 2))), + ({"x": (3, 1)}, ((3, 1), (2, 2, 2))), + ]: rechunked = self.lazy_var.chunk(chunks) assert rechunked.chunks == expected self.assertLazyAndIdentical(self.eager_var, rechunked) @@ -91,7 +99,7 @@ def test_indexing(self): self.assertLazyAndIdentical(u[0], v[0]) self.assertLazyAndIdentical(u[:1], v[:1]) self.assertLazyAndIdentical(u[[0, 1], [0, 1, 2]], v[[0, 1], [0, 1, 2]]) - with raises_regex(TypeError, 'stored in a dask array'): + with raises_regex(TypeError, "stored in a dask array"): v[:1] = 0 def test_squeeze(self): @@ -139,15 +147,17 @@ def 
test_binary_op(self): self.assertLazyAndIdentical(u[0] + u, v[0] + v) def test_repr(self): - expected = dedent("""\ + expected = dedent( + """\ - dask.array""") + dask.array""" + ) assert expected == repr(self.lazy_var) def test_pickle(self): # Test that pickling/unpickling does not convert the dask # backend to numpy - a1 = Variable(['x'], build_dask_array('x')) + a1 = Variable(["x"], build_dask_array("x")) a1.compute() assert not a1._in_memory assert kernel_call_count == 1 @@ -162,52 +172,51 @@ def test_reduce(self): v = self.lazy_var self.assertLazyAndAllClose(u.mean(), v.mean()) self.assertLazyAndAllClose(u.std(), v.std()) - self.assertLazyAndAllClose(u.argmax(dim='x'), v.argmax(dim='x')) + self.assertLazyAndAllClose(u.argmax(dim="x"), v.argmax(dim="x")) self.assertLazyAndAllClose((u > 1).any(), (v > 1).any()) - self.assertLazyAndAllClose((u < 1).all('x'), (v < 1).all('x')) - with raises_regex(NotImplementedError, 'dask'): + self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x")) + with raises_regex(NotImplementedError, "dask"): v.median() def test_missing_values(self): values = np.array([0, 1, np.nan, 3]) data = da.from_array(values, chunks=(2,)) - eager_var = Variable('x', values) - lazy_var = Variable('x', data) + eager_var = Variable("x", values) + lazy_var = Variable("x", data) self.assertLazyAndIdentical(eager_var, lazy_var.fillna(lazy_var)) - self.assertLazyAndIdentical(Variable('x', range(4)), - lazy_var.fillna(2)) + self.assertLazyAndIdentical(Variable("x", range(4)), lazy_var.fillna(2)) self.assertLazyAndIdentical(eager_var.count(), lazy_var.count()) def test_concat(self): u = self.eager_var v = self.lazy_var - self.assertLazyAndIdentical(u, Variable.concat([v[:2], v[2:]], 'x')) - self.assertLazyAndIdentical(u[:2], Variable.concat([v[0], v[1]], 'x')) - self.assertLazyAndIdentical(u[:2], Variable.concat([u[0], v[1]], 'x')) - self.assertLazyAndIdentical(u[:2], Variable.concat([v[0], u[1]], 'x')) + self.assertLazyAndIdentical(u, Variable.concat([v[:2], v[2:]], "x")) + self.assertLazyAndIdentical(u[:2], Variable.concat([v[0], v[1]], "x")) + self.assertLazyAndIdentical(u[:2], Variable.concat([u[0], v[1]], "x")) + self.assertLazyAndIdentical(u[:2], Variable.concat([v[0], u[1]], "x")) self.assertLazyAndIdentical( - u[:3], - Variable.concat([v[[0, 2]], v[[1]]], 'x', positions=[[0, 2], [1]])) + u[:3], Variable.concat([v[[0, 2]], v[[1]]], "x", positions=[[0, 2], [1]]) + ) def test_missing_methods(self): v = self.lazy_var try: v.argsort() except NotImplementedError as err: - assert 'dask' in str(err) + assert "dask" in str(err) try: v[0].item() except NotImplementedError as err: - assert 'dask' in str(err) + assert "dask" in str(err) - @pytest.mark.filterwarnings('ignore::PendingDeprecationWarning') + @pytest.mark.filterwarnings("ignore::PendingDeprecationWarning") def test_univariate_ufunc(self): u = self.eager_var v = self.lazy_var self.assertLazyAndAllClose(np.sin(u), xu.sin(v)) - @pytest.mark.filterwarnings('ignore::PendingDeprecationWarning') + @pytest.mark.filterwarnings("ignore::PendingDeprecationWarning") def test_bivariate_ufunc(self): u = self.eager_var v = self.lazy_var @@ -253,22 +262,24 @@ def assertLazyAndEqual(self, expected, actual): def setUp(self): self.values = np.random.randn(4, 6) self.data = da.from_array(self.values, chunks=(2, 2)) - self.eager_array = DataArray(self.values, coords={'x': range(4)}, - dims=('x', 'y'), name='foo') - self.lazy_array = DataArray(self.data, coords={'x': range(4)}, - dims=('x', 'y'), name='foo') + self.eager_array = DataArray( + 
self.values, coords={"x": range(4)}, dims=("x", "y"), name="foo" + ) + self.lazy_array = DataArray( + self.data, coords={"x": range(4)}, dims=("x", "y"), name="foo" + ) def test_rechunk(self): - chunked = self.eager_array.chunk({'x': 2}).chunk({'y': 2}) + chunked = self.eager_array.chunk({"x": 2}).chunk({"y": 2}) assert chunked.chunks == ((2,) * 2, (2,) * 3) self.assertLazyAndIdentical(self.lazy_array, chunked) def test_new_chunk(self): chunked = self.eager_array.chunk() - assert chunked.data.name.startswith('xarray-') + assert chunked.data.name.startswith("xarray-") def test_lazy_dataset(self): - lazy_ds = Dataset({'foo': (('x', 'y'), self.data)}) + lazy_ds = Dataset({"foo": (("x", "y"), self.data)}) assert isinstance(lazy_ds.foo.variable.data, da.Array) def test_lazy_array(self): @@ -281,7 +292,7 @@ def test_lazy_array(self): self.assertLazyAndAllClose(u.mean(), v.mean()) self.assertLazyAndAllClose(1 + u, 1 + v) - actual = xr.concat([v[:2], v[2:]], 'x') + actual = xr.concat([v[:2], v[2:]], "x") self.assertLazyAndAllClose(u, actual) def test_compute(self): @@ -311,83 +322,87 @@ def test_persist(self): def test_concat_loads_variables(self): # Test that concat() computes not-in-memory variables at most once # and loads them in the output, while leaving the input unaltered. - d1 = build_dask_array('d1') - c1 = build_dask_array('c1') - d2 = build_dask_array('d2') - c2 = build_dask_array('c2') - d3 = build_dask_array('d3') - c3 = build_dask_array('c3') + d1 = build_dask_array("d1") + c1 = build_dask_array("c1") + d2 = build_dask_array("d2") + c2 = build_dask_array("c2") + d3 = build_dask_array("d3") + c3 = build_dask_array("c3") # Note: c is a non-index coord. # Index coords are loaded by IndexVariable.__init__. - ds1 = Dataset(data_vars={'d': ('x', d1)}, coords={'c': ('x', c1)}) - ds2 = Dataset(data_vars={'d': ('x', d2)}, coords={'c': ('x', c2)}) - ds3 = Dataset(data_vars={'d': ('x', d3)}, coords={'c': ('x', c3)}) + ds1 = Dataset(data_vars={"d": ("x", d1)}, coords={"c": ("x", c1)}) + ds2 = Dataset(data_vars={"d": ("x", d2)}, coords={"c": ("x", c2)}) + ds3 = Dataset(data_vars={"d": ("x", d3)}, coords={"c": ("x", c3)}) assert kernel_call_count == 0 - out = xr.concat([ds1, ds2, ds3], dim='n', data_vars='different', - coords='different') + out = xr.concat( + [ds1, ds2, ds3], dim="n", data_vars="different", coords="different" + ) # each kernel is computed exactly once assert kernel_call_count == 6 # variables are loaded in the output - assert isinstance(out['d'].data, np.ndarray) - assert isinstance(out['c'].data, np.ndarray) + assert isinstance(out["d"].data, np.ndarray) + assert isinstance(out["c"].data, np.ndarray) - out = xr.concat( - [ds1, ds2, ds3], dim='n', data_vars='all', coords='all') + out = xr.concat([ds1, ds2, ds3], dim="n", data_vars="all", coords="all") # no extra kernel calls assert kernel_call_count == 6 - assert isinstance(out['d'].data, dask.array.Array) - assert isinstance(out['c'].data, dask.array.Array) + assert isinstance(out["d"].data, dask.array.Array) + assert isinstance(out["c"].data, dask.array.Array) - out = xr.concat( - [ds1, ds2, ds3], dim='n', data_vars=['d'], coords=['c']) + out = xr.concat([ds1, ds2, ds3], dim="n", data_vars=["d"], coords=["c"]) # no extra kernel calls assert kernel_call_count == 6 - assert isinstance(out['d'].data, dask.array.Array) - assert isinstance(out['c'].data, dask.array.Array) + assert isinstance(out["d"].data, dask.array.Array) + assert isinstance(out["c"].data, dask.array.Array) - out = xr.concat([ds1, ds2, ds3], dim='n', data_vars=[], 
coords=[]) + out = xr.concat([ds1, ds2, ds3], dim="n", data_vars=[], coords=[]) # variables are loaded once as we are validing that they're identical assert kernel_call_count == 12 - assert isinstance(out['d'].data, np.ndarray) - assert isinstance(out['c'].data, np.ndarray) + assert isinstance(out["d"].data, np.ndarray) + assert isinstance(out["c"].data, np.ndarray) - out = xr.concat([ds1, ds2, ds3], dim='n', data_vars='different', - coords='different', compat='identical') + out = xr.concat( + [ds1, ds2, ds3], + dim="n", + data_vars="different", + coords="different", + compat="identical", + ) # compat=identical doesn't do any more kernel calls than compat=equals assert kernel_call_count == 18 - assert isinstance(out['d'].data, np.ndarray) - assert isinstance(out['c'].data, np.ndarray) + assert isinstance(out["d"].data, np.ndarray) + assert isinstance(out["c"].data, np.ndarray) # When the test for different turns true halfway through, # stop computing variables as it would not have any benefit - ds4 = Dataset(data_vars={'d': ('x', [2.0])}, - coords={'c': ('x', [2.0])}) - out = xr.concat([ds1, ds2, ds4, ds3], dim='n', data_vars='different', - coords='different') + ds4 = Dataset(data_vars={"d": ("x", [2.0])}, coords={"c": ("x", [2.0])}) + out = xr.concat( + [ds1, ds2, ds4, ds3], dim="n", data_vars="different", coords="different" + ) # the variables of ds1 and ds2 were computed, but those of ds3 didn't assert kernel_call_count == 22 - assert isinstance(out['d'].data, dask.array.Array) - assert isinstance(out['c'].data, dask.array.Array) + assert isinstance(out["d"].data, dask.array.Array) + assert isinstance(out["c"].data, dask.array.Array) # the data of ds1 and ds2 was loaded into numpy and then # concatenated to the data of ds3. Thus, only ds3 is computed now. 
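# For reference, kernel_call_count is the module-level counter bumped by the
# kernel task wired into build_dask_array() further down, so the assertions in
# this test count how many chunk evaluations dask actually performed. A minimal
# sketch of that counting pattern (names here are illustrative, not the exact
# helper used by the suite):
#
#     calls = 0
#     def counting_kernel(name):
#         global calls
#         calls += 1
#         return np.ones(1, dtype=np.int64)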
out.compute() assert kernel_call_count == 24 # Finally, test that riginals are unaltered - assert ds1['d'].data is d1 - assert ds1['c'].data is c1 - assert ds2['d'].data is d2 - assert ds2['c'].data is c2 - assert ds3['d'].data is d3 - assert ds3['c'].data is c3 + assert ds1["d"].data is d1 + assert ds1["c"].data is c1 + assert ds2["d"].data is d2 + assert ds2["c"].data is c2 + assert ds3["d"].data is d3 + assert ds3["c"].data is c3 def test_groupby(self): u = self.eager_array v = self.lazy_array - expected = u.groupby('x').mean(xr.ALL_DIMS) - actual = v.groupby('x').mean(xr.ALL_DIMS) + expected = u.groupby("x").mean(xr.ALL_DIMS) + actual = v.groupby("x").mean(xr.ALL_DIMS) self.assertLazyAndAllClose(expected, actual) def test_groupby_first(self): @@ -395,20 +410,22 @@ def test_groupby_first(self): v = self.lazy_array for coords in [u.coords, v.coords]: - coords['ab'] = ('x', ['a', 'a', 'b', 'b']) - with raises_regex(NotImplementedError, 'dask'): - v.groupby('ab').first() - expected = u.groupby('ab').first() - actual = v.groupby('ab').first(skipna=False) + coords["ab"] = ("x", ["a", "a", "b", "b"]) + with raises_regex(NotImplementedError, "dask"): + v.groupby("ab").first() + expected = u.groupby("ab").first() + actual = v.groupby("ab").first(skipna=False) self.assertLazyAndAllClose(expected, actual) def test_reindex(self): u = self.eager_array.assign_coords(y=range(6)) v = self.lazy_array.assign_coords(y=range(6)) - for kwargs in [{'x': [2, 3, 4]}, - {'x': [1, 100, 2, 101, 3]}, - {'x': [2.5, 3, 3.5], 'y': [2, 2.5, 3]}]: + for kwargs in [ + {"x": [2, 3, 4]}, + {"x": [1, 100, 2, 101, 3]}, + {"x": [2.5, 3, 3.5], "y": [2, 2.5, 3]}, + ]: expected = u.reindex(**kwargs) actual = v.reindex(**kwargs) self.assertLazyAndAllClose(expected, actual) @@ -417,19 +434,18 @@ def test_to_dataset_roundtrip(self): u = self.eager_array v = self.lazy_array - expected = u.assign_coords(x=u['x']) - self.assertLazyAndEqual(expected, v.to_dataset('x').to_array('x')) + expected = u.assign_coords(x=u["x"]) + self.assertLazyAndEqual(expected, v.to_dataset("x").to_array("x")) def test_merge(self): - def duplicate_and_merge(array): - return xr.merge([array, array.rename('bar')]).to_array() + return xr.merge([array, array.rename("bar")]).to_array() expected = duplicate_and_merge(self.eager_array) actual = duplicate_and_merge(self.lazy_array) self.assertLazyAndEqual(expected, actual) - @pytest.mark.filterwarnings('ignore::PendingDeprecationWarning') + @pytest.mark.filterwarnings("ignore::PendingDeprecationWarning") def test_ufuncs(self): u = self.eager_array v = self.lazy_array @@ -446,8 +462,7 @@ def test_where_dispatching(self): self.assertLazyAndEqual(expected, DataArray(x).where(y)) def test_simultaneous_compute(self): - ds = Dataset({'foo': ('x', range(5)), - 'bar': ('x', range(5))}).chunk() + ds = Dataset({"foo": ("x", range(5)), "bar": ("x", range(5))}).chunk() count = [0] @@ -455,7 +470,7 @@ def counting_get(*args, **kwargs): count[0] += 1 return dask.get(*args, **kwargs) - if dask.__version__ < '0.19.4': + if dask.__version__ < "0.19.4": ds.load(get=counting_get) else: ds.load(scheduler=counting_get) @@ -464,11 +479,10 @@ def counting_get(*args, **kwargs): def test_stack(self): data = da.random.normal(size=(2, 3, 4), chunks=(1, 3, 4)) - arr = DataArray(data, dims=('w', 'x', 'y')) - stacked = arr.stack(z=('x', 'y')) - z = pd.MultiIndex.from_product([np.arange(3), np.arange(4)], - names=['x', 'y']) - expected = DataArray(data.reshape(2, -1), {'z': z}, dims=['w', 'z']) + arr = DataArray(data, dims=("w", "x", "y")) + 
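# stack() folds the ("w", "x", "y") array's ("x", "y") dimensions into a single
# "z" dimension indexed by a pandas MultiIndex; on a dask-backed array this
# should amount to a lazy reshape, which is why the chunk layout is compared
# against data.reshape(2, -1) in the assertions below.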
stacked = arr.stack(z=("x", "y")) + z = pd.MultiIndex.from_product([np.arange(3), np.arange(4)], names=["x", "y"]) + expected = DataArray(data.reshape(2, -1), {"z": z}, dims=["w", "z"]) assert stacked.data.chunks == expected.data.chunks self.assertLazyAndEqual(expected, stacked) @@ -480,94 +494,94 @@ def test_dot(self): def test_dataarray_repr(self): # Test that __repr__ converts the dask backend to numpy # in neither the data variable nor the non-index coords - data = build_dask_array('data') - nonindex_coord = build_dask_array('coord') - a = DataArray(data, dims=['x'], coords={'y': ('x', nonindex_coord)}) - expected = dedent("""\ + data = build_dask_array("data") + nonindex_coord = build_dask_array("coord") + a = DataArray(data, dims=["x"], coords={"y": ("x", nonindex_coord)}) + expected = dedent( + """\ dask.array Coordinates: y (x) int64 dask.array - Dimensions without coordinates: x""") + Dimensions without coordinates: x""" + ) assert expected == repr(a) assert kernel_call_count == 0 def test_dataset_repr(self): # Test that pickling/unpickling converts the dask backend # to numpy in neither the data variables nor the non-index coords - data = build_dask_array('data') - nonindex_coord = build_dask_array('coord') - ds = Dataset(data_vars={'a': ('x', data)}, - coords={'y': ('x', nonindex_coord)}) - expected = dedent("""\ + data = build_dask_array("data") + nonindex_coord = build_dask_array("coord") + ds = Dataset(data_vars={"a": ("x", data)}, coords={"y": ("x", nonindex_coord)}) + expected = dedent( + """\ Dimensions: (x: 1) Coordinates: y (x) int64 dask.array Dimensions without coordinates: x Data variables: - a (x) int64 dask.array""") + a (x) int64 dask.array""" + ) assert expected == repr(ds) assert kernel_call_count == 0 def test_dataarray_pickle(self): # Test that pickling/unpickling converts the dask backend # to numpy in neither the data variable nor the non-index coords - data = build_dask_array('data') - nonindex_coord = build_dask_array('coord') - a1 = DataArray(data, dims=['x'], coords={'y': ('x', nonindex_coord)}) + data = build_dask_array("data") + nonindex_coord = build_dask_array("coord") + a1 = DataArray(data, dims=["x"], coords={"y": ("x", nonindex_coord)}) a1.compute() assert not a1._in_memory - assert not a1.coords['y']._in_memory + assert not a1.coords["y"]._in_memory assert kernel_call_count == 2 a2 = pickle.loads(pickle.dumps(a1)) assert kernel_call_count == 2 assert_identical(a1, a2) assert not a1._in_memory assert not a2._in_memory - assert not a1.coords['y']._in_memory - assert not a2.coords['y']._in_memory + assert not a1.coords["y"]._in_memory + assert not a2.coords["y"]._in_memory def test_dataset_pickle(self): # Test that pickling/unpickling converts the dask backend # to numpy in neither the data variables nor the non-index coords - data = build_dask_array('data') - nonindex_coord = build_dask_array('coord') - ds1 = Dataset(data_vars={'a': ('x', data)}, - coords={'y': ('x', nonindex_coord)}) + data = build_dask_array("data") + nonindex_coord = build_dask_array("coord") + ds1 = Dataset(data_vars={"a": ("x", data)}, coords={"y": ("x", nonindex_coord)}) ds1.compute() - assert not ds1['a']._in_memory - assert not ds1['y']._in_memory + assert not ds1["a"]._in_memory + assert not ds1["y"]._in_memory assert kernel_call_count == 2 ds2 = pickle.loads(pickle.dumps(ds1)) assert kernel_call_count == 2 assert_identical(ds1, ds2) - assert not ds1['a']._in_memory - assert not ds2['a']._in_memory - assert not ds1['y']._in_memory - assert not ds2['y']._in_memory + assert 
not ds1["a"]._in_memory + assert not ds2["a"]._in_memory + assert not ds1["y"]._in_memory + assert not ds2["y"]._in_memory def test_dataarray_getattr(self): # ipython/jupyter does a long list of getattr() calls to when trying to # represent an object. # Make sure we're not accidentally computing dask variables. - data = build_dask_array('data') - nonindex_coord = build_dask_array('coord') - a = DataArray(data, dims=['x'], - coords={'y': ('x', nonindex_coord)}) + data = build_dask_array("data") + nonindex_coord = build_dask_array("coord") + a = DataArray(data, dims=["x"], coords={"y": ("x", nonindex_coord)}) with suppress(AttributeError): - getattr(a, 'NOTEXIST') + getattr(a, "NOTEXIST") assert kernel_call_count == 0 def test_dataset_getattr(self): # Test that pickling/unpickling converts the dask backend # to numpy in neither the data variables nor the non-index coords - data = build_dask_array('data') - nonindex_coord = build_dask_array('coord') - ds = Dataset(data_vars={'a': ('x', data)}, - coords={'y': ('x', nonindex_coord)}) + data = build_dask_array("data") + nonindex_coord = build_dask_array("coord") + ds = Dataset(data_vars={"a": ("x", data)}, coords={"y": ("x", nonindex_coord)}) with suppress(AttributeError): - getattr(ds, 'NOTEXIST') + getattr(ds, "NOTEXIST") assert kernel_call_count == 0 def test_values(self): @@ -581,26 +595,20 @@ def test_values(self): def test_from_dask_variable(self): # Test array creation from Variable with dask backend. # This is used e.g. in broadcast() - a = DataArray(self.lazy_array.variable, - coords={'x': range(4)}, name='foo') + a = DataArray(self.lazy_array.variable, coords={"x": range(4)}, name="foo") self.assertLazyAndIdentical(self.lazy_array, a) class TestToDaskDataFrame: - def test_to_dask_dataframe(self): # Test conversion of Datasets to dask DataFrames x = da.from_array(np.random.randn(10), chunks=4) - y = np.arange(10, dtype='uint8') - t = list('abcdefghij') + y = np.arange(10, dtype="uint8") + t = list("abcdefghij") - ds = Dataset(OrderedDict([('a', ('t', x)), - ('b', ('t', y)), - ('t', ('t', t))])) + ds = Dataset(OrderedDict([("a", ("t", x)), ("b", ("t", y)), ("t", ("t", t))])) - expected_pd = pd.DataFrame({'a': x, - 'b': y}, - index=pd.Index(t, name='t')) + expected_pd = pd.DataFrame({"a": x, "b": y}, index=pd.Index(t, name="t")) # test if 1-D index is correctly set up expected = dd.from_pandas(expected_pd, chunksize=4) @@ -612,8 +620,7 @@ def test_to_dask_dataframe(self): assert_frame_equal(expected.compute(), actual.compute()) # test if no index is given - expected = dd.from_pandas(expected_pd.reset_index(drop=False), - chunksize=4) + expected = dd.from_pandas(expected_pd.reset_index(drop=False), chunksize=4) actual = ds.to_dask_dataframe(set_index=False) @@ -623,17 +630,16 @@ def test_to_dask_dataframe(self): def test_to_dask_dataframe_2D(self): # Test if 2-D dataset is supplied w = da.from_array(np.random.randn(2, 3), chunks=(1, 2)) - ds = Dataset({'w': (('x', 'y'), w)}) - ds['x'] = ('x', np.array([0, 1], np.int64)) - ds['y'] = ('y', list('abc')) + ds = Dataset({"w": (("x", "y"), w)}) + ds["x"] = ("x", np.array([0, 1], np.int64)) + ds["y"] = ("y", list("abc")) # dask dataframes do not (yet) support multiindex, # but when it does, this would be the expected index: exp_index = pd.MultiIndex.from_arrays( - [[0, 0, 0, 1, 1, 1], ['a', 'b', 'c', 'a', 'b', 'c']], - names=['x', 'y']) - expected = pd.DataFrame({'w': w.reshape(-1)}, - index=exp_index) + [[0, 0, 0, 1, 1, 1], ["a", "b", "c", "a", "b", "c"]], names=["x", "y"] + ) + expected = 
pd.DataFrame({"w": w.reshape(-1)}, index=exp_index) # so for now, reset the index expected = expected.reset_index(drop=False) actual = ds.to_dask_dataframe(set_index=False) @@ -645,9 +651,9 @@ def test_to_dask_dataframe_2D(self): def test_to_dask_dataframe_2D_set_index(self): # This will fail until dask implements MultiIndex support w = da.from_array(np.random.randn(2, 3), chunks=(1, 2)) - ds = Dataset({'w': (('x', 'y'), w)}) - ds['x'] = ('x', np.array([0, 1], np.int64)) - ds['y'] = ('y', list('abc')) + ds = Dataset({"w": (("x", "y"), w)}) + ds["x"] = ("x", np.array([0, 1], np.int64)) + ds["y"] = ("y", list("abc")) expected = ds.compute().to_dataframe() actual = ds.to_dask_dataframe(set_index=True) @@ -659,11 +665,9 @@ def test_to_dask_dataframe_coordinates(self): x = da.from_array(np.random.randn(10), chunks=4) t = da.from_array(np.arange(10) * 2, chunks=4) - ds = Dataset(OrderedDict([('a', ('t', x)), - ('t', ('t', t))])) + ds = Dataset(OrderedDict([("a", ("t", x)), ("t", ("t", t))])) - expected_pd = pd.DataFrame({'a': x}, - index=pd.Index(t, name='t')) + expected_pd = pd.DataFrame({"a": x}, index=pd.Index(t, name="t")) expected = dd.from_pandas(expected_pd, chunksize=4) actual = ds.to_dask_dataframe(set_index=True) assert isinstance(actual, dd.DataFrame) @@ -672,15 +676,12 @@ def test_to_dask_dataframe_coordinates(self): def test_to_dask_dataframe_not_daskarray(self): # Test if DataArray is not a dask array x = np.random.randn(10) - y = np.arange(10, dtype='uint8') - t = list('abcdefghij') + y = np.arange(10, dtype="uint8") + t = list("abcdefghij") - ds = Dataset(OrderedDict([('a', ('t', x)), - ('b', ('t', y)), - ('t', ('t', t))])) + ds = Dataset(OrderedDict([("a", ("t", x)), ("b", ("t", y)), ("t", ("t", t))])) - expected = pd.DataFrame({'a': x, 'b': y}, - index=pd.Index(t, name='t')) + expected = pd.DataFrame({"a": x, "b": y}, index=pd.Index(t, name="t")) actual = ds.to_dask_dataframe(set_index=True) assert isinstance(actual, dd.DataFrame) @@ -688,7 +689,7 @@ def test_to_dask_dataframe_not_daskarray(self): def test_to_dask_dataframe_no_coordinate(self): x = da.from_array(np.random.randn(10), chunks=4) - ds = Dataset({'x': ('dim_0', x)}) + ds = Dataset({"x": ("dim_0", x)}) expected = ds.compute().to_dataframe().reset_index() actual = ds.to_dask_dataframe() @@ -702,58 +703,59 @@ def test_to_dask_dataframe_no_coordinate(self): def test_to_dask_dataframe_dim_order(self): values = np.array([[1, 2], [3, 4]], dtype=np.int64) - ds = Dataset({'w': (('x', 'y'), values)}).chunk(1) + ds = Dataset({"w": (("x", "y"), values)}).chunk(1) - expected = ds['w'].to_series().reset_index() - actual = ds.to_dask_dataframe(dim_order=['x', 'y']) + expected = ds["w"].to_series().reset_index() + actual = ds.to_dask_dataframe(dim_order=["x", "y"]) assert isinstance(actual, dd.DataFrame) assert_frame_equal(expected, actual.compute()) - expected = ds['w'].T.to_series().reset_index() - actual = ds.to_dask_dataframe(dim_order=['y', 'x']) + expected = ds["w"].T.to_series().reset_index() + actual = ds.to_dask_dataframe(dim_order=["y", "x"]) assert isinstance(actual, dd.DataFrame) assert_frame_equal(expected, actual.compute()) - with raises_regex(ValueError, 'does not match the set of dimensions'): - ds.to_dask_dataframe(dim_order=['x']) + with raises_regex(ValueError, "does not match the set of dimensions"): + ds.to_dask_dataframe(dim_order=["x"]) -@pytest.mark.parametrize("method", ['load', 'compute']) +@pytest.mark.parametrize("method", ["load", "compute"]) def test_dask_kwargs_variable(method): - x = Variable('y', 
da.from_array(np.arange(3), chunks=(2,))) + x = Variable("y", da.from_array(np.arange(3), chunks=(2,))) # args should be passed on to da.Array.compute() - with mock.patch.object(da.Array, 'compute', - return_value=np.arange(3)) as mock_compute: - getattr(x, method)(foo='bar') - mock_compute.assert_called_with(foo='bar') + with mock.patch.object( + da.Array, "compute", return_value=np.arange(3) + ) as mock_compute: + getattr(x, method)(foo="bar") + mock_compute.assert_called_with(foo="bar") -@pytest.mark.parametrize("method", ['load', 'compute', 'persist']) +@pytest.mark.parametrize("method", ["load", "compute", "persist"]) def test_dask_kwargs_dataarray(method): data = da.from_array(np.arange(3), chunks=(2,)) x = DataArray(data) - if method in ['load', 'compute']: - dask_func = 'dask.array.compute' + if method in ["load", "compute"]: + dask_func = "dask.array.compute" else: - dask_func = 'dask.persist' + dask_func = "dask.persist" # args should be passed on to "dask_func" with mock.patch(dask_func) as mock_func: - getattr(x, method)(foo='bar') - mock_func.assert_called_with(data, foo='bar') + getattr(x, method)(foo="bar") + mock_func.assert_called_with(data, foo="bar") -@pytest.mark.parametrize("method", ['load', 'compute', 'persist']) +@pytest.mark.parametrize("method", ["load", "compute", "persist"]) def test_dask_kwargs_dataset(method): data = da.from_array(np.arange(3), chunks=(2,)) - x = Dataset({'x': (('y'), data)}) - if method in ['load', 'compute']: - dask_func = 'dask.array.compute' + x = Dataset({"x": (("y"), data)}) + if method in ["load", "compute"]: + dask_func = "dask.array.compute" else: - dask_func = 'dask.persist' + dask_func = "dask.persist" # args should be passed on to "dask_func" with mock.patch(dask_func) as mock_func: - getattr(x, method)(foo='bar') - mock_func.assert_called_with(data, foo='bar') + getattr(x, method)(foo="bar") + mock_func.assert_called_with(data, foo="bar") kernel_call_count = 0 @@ -773,17 +775,17 @@ def build_dask_array(name): global kernel_call_count kernel_call_count = 0 return dask.array.Array( - dask={(name, 0): (kernel, name)}, name=name, - chunks=((1,),), dtype=np.int64) + dask={(name, 0): (kernel, name)}, name=name, chunks=((1,),), dtype=np.int64 + ) # test both the perist method and the dask.persist function # the dask.persist function requires a new version of dask -@pytest.mark.parametrize('persist', [lambda x: x.persist(), - lambda x: dask.persist(x)[0]]) +@pytest.mark.parametrize( + "persist", [lambda x: x.persist(), lambda x: dask.persist(x)[0]] +) def test_persist_Dataset(persist): - ds = Dataset({'foo': ('x', range(5)), - 'bar': ('x', range(5))}).chunk() + ds = Dataset({"foo": ("x", range(5)), "bar": ("x", range(5))}).chunk() ds = ds + 1 n = len(ds.foo.data.dask) @@ -793,8 +795,9 @@ def test_persist_Dataset(persist): assert len(ds.foo.data.dask) == n # doesn't mutate in place -@pytest.mark.parametrize('persist', [lambda x: x.persist(), - lambda x: dask.persist(x)[0]]) +@pytest.mark.parametrize( + "persist", [lambda x: x.persist(), lambda x: dask.persist(x)[0]] +) def test_persist_DataArray(persist): x = da.arange(10, chunks=(5,)) y = DataArray(x) @@ -809,49 +812,49 @@ def test_persist_DataArray(persist): def test_dataarray_with_dask_coords(): import toolz - x = xr.Variable('x', da.arange(8, chunks=(4,))) - y = xr.Variable('y', da.arange(8, chunks=(4,)) * 2) + + x = xr.Variable("x", da.arange(8, chunks=(4,))) + y = xr.Variable("y", da.arange(8, chunks=(4,)) * 2) data = da.random.random((8, 8), chunks=(4, 4)) + 1 - array = 
xr.DataArray(data, dims=['x', 'y']) - array.coords['xx'] = x - array.coords['yy'] = y + array = xr.DataArray(data, dims=["x", "y"]) + array.coords["xx"] = x + array.coords["yy"] = y - assert dict(array.__dask_graph__()) == toolz.merge(data.__dask_graph__(), - x.__dask_graph__(), - y.__dask_graph__()) + assert dict(array.__dask_graph__()) == toolz.merge( + data.__dask_graph__(), x.__dask_graph__(), y.__dask_graph__() + ) (array2,) = dask.compute(array) assert not dask.is_dask_collection(array2) - assert all(isinstance(v._variable.data, np.ndarray) - for v in array2.coords.values()) + assert all(isinstance(v._variable.data, np.ndarray) for v in array2.coords.values()) def test_basic_compute(): - ds = Dataset({'foo': ('x', range(5)), - 'bar': ('x', range(5))}).chunk({'x': 2}) - for get in [dask.threaded.get, - dask.multiprocessing.get, - dask.local.get_sync, - None]: - with (dask.config.set(scheduler=get) - if LooseVersion(dask.__version__) >= LooseVersion('0.19.4') - else dask.config.set(scheduler=get) - if LooseVersion(dask.__version__) >= LooseVersion('0.18.0') - else dask.set_options(get=get)): + ds = Dataset({"foo": ("x", range(5)), "bar": ("x", range(5))}).chunk({"x": 2}) + for get in [dask.threaded.get, dask.multiprocessing.get, dask.local.get_sync, None]: + with ( + dask.config.set(scheduler=get) + if LooseVersion(dask.__version__) >= LooseVersion("0.19.4") + else dask.config.set(scheduler=get) + if LooseVersion(dask.__version__) >= LooseVersion("0.18.0") + else dask.set_options(get=get) + ): ds.compute() ds.foo.compute() ds.foo.variable.compute() -@pytest.mark.skipif(LooseVersion(dask.__version__) < LooseVersion('0.20.0'), - reason='needs newer dask') +@pytest.mark.skipif( + LooseVersion(dask.__version__) < LooseVersion("0.20.0"), reason="needs newer dask" +) def test_dask_layers_and_dependencies(): - ds = Dataset({'foo': ('x', range(5)), - 'bar': ('x', range(5))}).chunk() + ds = Dataset({"foo": ("x", range(5)), "bar": ("x", range(5))}).chunk() x = dask.delayed(ds) assert set(x.__dask_graph__().dependencies).issuperset( - ds.__dask_graph__().dependencies) + ds.__dask_graph__().dependencies + ) assert set(x.foo.__dask_graph__().dependencies).issuperset( - ds.__dask_graph__().dependencies) + ds.__dask_graph__().dependencies + ) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 3a19c229fe6..8b63b650dc2 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -10,39 +10,53 @@ import pytest import xarray as xr -from xarray import ( - DataArray, Dataset, IndexVariable, Variable, align, broadcast) +from xarray import DataArray, Dataset, IndexVariable, Variable, align, broadcast from xarray.coding.times import CFDatetimeCoder from xarray.convert import from_cdms2 from xarray.core import dtypes from xarray.core.common import ALL_DIMS, full_like from xarray.tests import ( - LooseVersion, ReturnItem, assert_allclose, assert_array_equal, - assert_equal, assert_identical, raises_regex, requires_bottleneck, - requires_cftime, requires_dask, requires_iris, requires_np113, - requires_numbagg, requires_scipy, source_ndarray) + LooseVersion, + ReturnItem, + assert_allclose, + assert_array_equal, + assert_equal, + assert_identical, + raises_regex, + requires_bottleneck, + requires_cftime, + requires_dask, + requires_iris, + requires_np113, + requires_numbagg, + requires_scipy, + source_ndarray, +) class TestDataArray: @pytest.fixture(autouse=True) def setup(self): - self.attrs = {'attr1': 'value1', 'attr2': 2929} + self.attrs = {"attr1": 
"value1", "attr2": 2929} self.x = np.random.random((10, 20)) - self.v = Variable(['x', 'y'], self.x) - self.va = Variable(['x', 'y'], self.x, self.attrs) - self.ds = Dataset({'foo': self.v}) - self.dv = self.ds['foo'] + self.v = Variable(["x", "y"], self.x) + self.va = Variable(["x", "y"], self.x, self.attrs) + self.ds = Dataset({"foo": self.v}) + self.dv = self.ds["foo"] - self.mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2]], - names=('level_1', 'level_2')) - self.mda = DataArray([0, 1, 2, 3], coords={'x': self.mindex}, dims='x') + self.mindex = pd.MultiIndex.from_product( + [["a", "b"], [1, 2]], names=("level_1", "level_2") + ) + self.mda = DataArray([0, 1, 2, 3], coords={"x": self.mindex}, dims="x") def test_repr(self): - v = Variable(['time', 'x'], [[1, 2, 3], [4, 5, 6]], {'foo': 'bar'}) - coords = OrderedDict([('x', np.arange(3, dtype=np.int64)), - ('other', np.int64(0))]) - data_array = DataArray(v, coords, name='my_variable') - expected = dedent("""\ + v = Variable(["time", "x"], [[1, 2, 3], [4, 5, 6]], {"foo": "bar"}) + coords = OrderedDict( + [("x", np.arange(3, dtype=np.int64)), ("other", np.int64(0))] + ) + data_array = DataArray(v, coords, name="my_variable") + expected = dedent( + """\ array([[1, 2, 3], [4, 5, 6]]) @@ -51,40 +65,44 @@ def test_repr(self): other int64 0 Dimensions without coordinates: time Attributes: - foo: bar""") + foo: bar""" + ) assert expected == repr(data_array) def test_repr_multiindex(self): - expected = dedent("""\ + expected = dedent( + """\ array([0, 1, 2, 3]) Coordinates: * x (x) MultiIndex - level_1 (x) object 'a' 'a' 'b' 'b' - - level_2 (x) int64 1 2 1 2""") + - level_2 (x) int64 1 2 1 2""" + ) assert expected == repr(self.mda) def test_repr_multiindex_long(self): mindex_long = pd.MultiIndex.from_product( - [['a', 'b', 'c', 'd'], [1, 2, 3, 4, 5, 6, 7, 8]], - names=('level_1', 'level_2')) - mda_long = DataArray(list(range(32)), - coords={'x': mindex_long}, dims='x') - expected = dedent("""\ + [["a", "b", "c", "d"], [1, 2, 3, 4, 5, 6, 7, 8]], + names=("level_1", "level_2"), + ) + mda_long = DataArray(list(range(32)), coords={"x": mindex_long}, dims="x") + expected = dedent( + """\ array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]) Coordinates: * x (x) MultiIndex - level_1 (x) object 'a' 'a' 'a' 'a' 'a' 'a' 'a' ... 'd' 'd' 'd' 'd' 'd' 'd' - - level_2 (x) int64 1 2 3 4 5 6 7 8 1 2 3 4 5 6 ... 4 5 6 7 8 1 2 3 4 5 6 7 8""") # noqa: E501 + - level_2 (x) int64 1 2 3 4 5 6 7 8 1 2 3 4 5 6 ... 
4 5 6 7 8 1 2 3 4 5 6 7 8""" # noqa: E501 + ) assert expected == repr(mda_long) def test_properties(self): assert_equal(self.dv.variable, self.v) assert_array_equal(self.dv.values, self.v.values) - for attr in ['dims', 'dtype', 'shape', 'size', 'nbytes', - 'ndim', 'attrs']: + for attr in ["dims", "dtype", "shape", "size", "nbytes", "ndim", "attrs"]: assert getattr(self.dv, attr) == getattr(self.v, attr) assert len(self.dv) == len(self.v) assert_equal(self.dv.variable, self.v) @@ -93,9 +111,9 @@ def test_properties(self): assert_array_equal(v, self.ds.coords[k]) with pytest.raises(AttributeError): self.dv.dataset - assert isinstance(self.ds['x'].to_index(), pd.Index) - with raises_regex(ValueError, 'must be 1-dimensional'): - self.ds['foo'].to_index() + assert isinstance(self.ds["x"].to_index(), pd.Index) + with raises_regex(ValueError, "must be 1-dimensional"): + self.ds["foo"].to_index() with pytest.raises(AttributeError): self.dv.variable = self.v @@ -109,25 +127,24 @@ def test_data_property(self): assert_array_equal(actual.data, actual.values) def test_indexes(self): - array = DataArray(np.zeros((2, 3)), - [('x', [0, 1]), ('y', ['a', 'b', 'c'])]) - expected = OrderedDict([('x', pd.Index([0, 1])), - ('y', pd.Index(['a', 'b', 'c']))]) + array = DataArray(np.zeros((2, 3)), [("x", [0, 1]), ("y", ["a", "b", "c"])]) + expected = OrderedDict( + [("x", pd.Index([0, 1])), ("y", pd.Index(["a", "b", "c"]))] + ) assert array.indexes.keys() == expected.keys() for k in expected: assert array.indexes[k].equals(expected[k]) def test_get_index(self): - array = DataArray(np.zeros((2, 3)), coords={'x': ['a', 'b']}, - dims=['x', 'y']) - assert array.get_index('x').equals(pd.Index(['a', 'b'])) - assert array.get_index('y').equals(pd.Index([0, 1, 2])) + array = DataArray(np.zeros((2, 3)), coords={"x": ["a", "b"]}, dims=["x", "y"]) + assert array.get_index("x").equals(pd.Index(["a", "b"])) + assert array.get_index("y").equals(pd.Index([0, 1, 2])) with pytest.raises(KeyError): - array.get_index('z') + array.get_index("z") def test_get_index_size_zero(self): - array = DataArray(np.zeros((0,)), dims=['x']) - actual = array.get_index('x') + array = DataArray(np.zeros((0,)), dims=["x"]) + actual = array.get_index("x") expected = pd.Index([], dtype=np.int64) assert actual.equals(expected) assert actual.dtype == expected.dtype @@ -141,78 +158,91 @@ def test_struct_array_dims(self): # checking array subraction when dims are the same # note: names need to be in sorted order to align consistently with # pandas < 0.24 and >= 0.24. 
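# The participant coordinate built below is a structured numpy array (fields
# "name" and "height"), so these subtraction cases also exercise alignment on
# record-type index values, including a partially overlapping coordinate and
# one containing np.nan.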
- p_data = np.array([('Abe', 180), ('Stacy', 150), ('Dick', 200)], - dtype=[('name', '|S256'), ('height', object)]) - weights_0 = DataArray([80, 56, 120], dims=['participant'], - coords={'participant': p_data}) - weights_1 = DataArray([81, 52, 115], dims=['participant'], - coords={'participant': p_data}) + p_data = np.array( + [("Abe", 180), ("Stacy", 150), ("Dick", 200)], + dtype=[("name", "|S256"), ("height", object)], + ) + weights_0 = DataArray( + [80, 56, 120], dims=["participant"], coords={"participant": p_data} + ) + weights_1 = DataArray( + [81, 52, 115], dims=["participant"], coords={"participant": p_data} + ) actual = weights_1 - weights_0 - expected = DataArray([1, -4, -5], dims=['participant'], - coords={'participant': p_data}) + expected = DataArray( + [1, -4, -5], dims=["participant"], coords={"participant": p_data} + ) assert_identical(actual, expected) # checking array subraction when dims are not the same - p_data_alt = np.array([('Abe', 180), ('Stacy', 151), ('Dick', 200)], - dtype=[('name', '|S256'), ('height', object)]) - weights_1 = DataArray([81, 52, 115], dims=['participant'], - coords={'participant': p_data_alt}) + p_data_alt = np.array( + [("Abe", 180), ("Stacy", 151), ("Dick", 200)], + dtype=[("name", "|S256"), ("height", object)], + ) + weights_1 = DataArray( + [81, 52, 115], dims=["participant"], coords={"participant": p_data_alt} + ) actual = weights_1 - weights_0 - expected = DataArray([1, -5], dims=['participant'], - coords={'participant': p_data[[0, 2]]}) + expected = DataArray( + [1, -5], dims=["participant"], coords={"participant": p_data[[0, 2]]} + ) assert_identical(actual, expected) # checking array subraction when dims are not the same and one # is np.nan - p_data_nan = np.array([('Abe', 180), ('Stacy', np.nan), ('Dick', 200)], - dtype=[('name', '|S256'), ('height', object)]) - weights_1 = DataArray([81, 52, 115], dims=['participant'], - coords={'participant': p_data_nan}) + p_data_nan = np.array( + [("Abe", 180), ("Stacy", np.nan), ("Dick", 200)], + dtype=[("name", "|S256"), ("height", object)], + ) + weights_1 = DataArray( + [81, 52, 115], dims=["participant"], coords={"participant": p_data_nan} + ) actual = weights_1 - weights_0 - expected = DataArray([1, -5], dims=['participant'], - coords={'participant': p_data[[0, 2]]}) + expected = DataArray( + [1, -5], dims=["participant"], coords={"participant": p_data[[0, 2]]} + ) assert_identical(actual, expected) def test_name(self): arr = self.dv - assert arr.name == 'foo' + assert arr.name == "foo" copied = arr.copy() - arr.name = 'bar' - assert arr.name == 'bar' + arr.name = "bar" + assert arr.name == "bar" assert_equal(copied, arr) - actual = DataArray(IndexVariable('x', [3])) - actual.name = 'y' - expected = DataArray([3], [('x', [3])], name='y') + actual = DataArray(IndexVariable("x", [3])) + actual.name = "y" + expected = DataArray([3], [("x", [3])], name="y") assert_identical(actual, expected) def test_dims(self): arr = self.dv - assert arr.dims == ('x', 'y') + assert arr.dims == ("x", "y") - with raises_regex(AttributeError, 'you cannot assign'): - arr.dims = ('w', 'z') + with raises_regex(AttributeError, "you cannot assign"): + arr.dims = ("w", "z") def test_sizes(self): - array = DataArray(np.zeros((3, 4)), dims=['x', 'y']) - assert array.sizes == {'x': 3, 'y': 4} + array = DataArray(np.zeros((3, 4)), dims=["x", "y"]) + assert array.sizes == {"x": 3, "y": 4} assert tuple(array.sizes) == array.dims with pytest.raises(TypeError): - array.sizes['foo'] = 5 + array.sizes["foo"] = 5 def 
test_encoding(self): - expected = {'foo': 'bar'} - self.dv.encoding['foo'] = 'bar' + expected = {"foo": "bar"} + self.dv.encoding["foo"] = "bar" assert expected == self.dv.encoding - expected = {'baz': 0} + expected = {"baz": 0} self.dv.encoding = expected assert expected is not self.dv.encoding @@ -221,138 +251,146 @@ def test_constructor(self): data = np.random.random((2, 3)) actual = DataArray(data) - expected = Dataset({None: (['dim_0', 'dim_1'], data)})[None] + expected = Dataset({None: (["dim_0", "dim_1"], data)})[None] assert_identical(expected, actual) - actual = DataArray(data, [['a', 'b'], [-1, -2, -3]]) - expected = Dataset({None: (['dim_0', 'dim_1'], data), - 'dim_0': ('dim_0', ['a', 'b']), - 'dim_1': ('dim_1', [-1, -2, -3])})[None] + actual = DataArray(data, [["a", "b"], [-1, -2, -3]]) + expected = Dataset( + { + None: (["dim_0", "dim_1"], data), + "dim_0": ("dim_0", ["a", "b"]), + "dim_1": ("dim_1", [-1, -2, -3]), + } + )[None] assert_identical(expected, actual) - actual = DataArray(data, [pd.Index(['a', 'b'], name='x'), - pd.Index([-1, -2, -3], name='y')]) - expected = Dataset({None: (['x', 'y'], data), - 'x': ('x', ['a', 'b']), - 'y': ('y', [-1, -2, -3])})[None] + actual = DataArray( + data, [pd.Index(["a", "b"], name="x"), pd.Index([-1, -2, -3], name="y")] + ) + expected = Dataset( + {None: (["x", "y"], data), "x": ("x", ["a", "b"]), "y": ("y", [-1, -2, -3])} + )[None] assert_identical(expected, actual) - coords = [['a', 'b'], [-1, -2, -3]] - actual = DataArray(data, coords, ['x', 'y']) + coords = [["a", "b"], [-1, -2, -3]] + actual = DataArray(data, coords, ["x", "y"]) assert_identical(expected, actual) - coords = [pd.Index(['a', 'b'], name='A'), - pd.Index([-1, -2, -3], name='B')] - actual = DataArray(data, coords, ['x', 'y']) + coords = [pd.Index(["a", "b"], name="A"), pd.Index([-1, -2, -3], name="B")] + actual = DataArray(data, coords, ["x", "y"]) assert_identical(expected, actual) - coords = {'x': ['a', 'b'], 'y': [-1, -2, -3]} - actual = DataArray(data, coords, ['x', 'y']) + coords = {"x": ["a", "b"], "y": [-1, -2, -3]} + actual = DataArray(data, coords, ["x", "y"]) assert_identical(expected, actual) - coords = [('x', ['a', 'b']), ('y', [-1, -2, -3])] + coords = [("x", ["a", "b"]), ("y", [-1, -2, -3])] actual = DataArray(data, coords) assert_identical(expected, actual) - expected = Dataset({None: (['x', 'y'], data), - 'x': ('x', ['a', 'b'])})[None] - actual = DataArray(data, {'x': ['a', 'b']}, ['x', 'y']) + expected = Dataset({None: (["x", "y"], data), "x": ("x", ["a", "b"])})[None] + actual = DataArray(data, {"x": ["a", "b"]}, ["x", "y"]) assert_identical(expected, actual) - actual = DataArray(data, dims=['x', 'y']) - expected = Dataset({None: (['x', 'y'], data)})[None] + actual = DataArray(data, dims=["x", "y"]) + expected = Dataset({None: (["x", "y"], data)})[None] assert_identical(expected, actual) - actual = DataArray(data, dims=['x', 'y'], name='foo') - expected = Dataset({'foo': (['x', 'y'], data)})['foo'] + actual = DataArray(data, dims=["x", "y"], name="foo") + expected = Dataset({"foo": (["x", "y"], data)})["foo"] assert_identical(expected, actual) - actual = DataArray(data, name='foo') - expected = Dataset({'foo': (['dim_0', 'dim_1'], data)})['foo'] + actual = DataArray(data, name="foo") + expected = Dataset({"foo": (["dim_0", "dim_1"], data)})["foo"] assert_identical(expected, actual) - actual = DataArray(data, dims=['x', 'y'], attrs={'bar': 2}) - expected = Dataset({None: (['x', 'y'], data, {'bar': 2})})[None] + actual = DataArray(data, dims=["x", "y"], 
attrs={"bar": 2}) + expected = Dataset({None: (["x", "y"], data, {"bar": 2})})[None] assert_identical(expected, actual) - actual = DataArray(data, dims=['x', 'y']) - expected = Dataset({None: (['x', 'y'], data, {}, {'bar': 2})})[None] + actual = DataArray(data, dims=["x", "y"]) + expected = Dataset({None: (["x", "y"], data, {}, {"bar": 2})})[None] assert_identical(expected, actual) def test_constructor_invalid(self): data = np.random.randn(3, 2) - with raises_regex(ValueError, 'coords is not dict-like'): - DataArray(data, [[0, 1, 2]], ['x', 'y']) + with raises_regex(ValueError, "coords is not dict-like"): + DataArray(data, [[0, 1, 2]], ["x", "y"]) - with raises_regex(ValueError, 'not a subset of the .* dim'): - DataArray(data, {'x': [0, 1, 2]}, ['a', 'b']) - with raises_regex(ValueError, 'not a subset of the .* dim'): - DataArray(data, {'x': [0, 1, 2]}) + with raises_regex(ValueError, "not a subset of the .* dim"): + DataArray(data, {"x": [0, 1, 2]}, ["a", "b"]) + with raises_regex(ValueError, "not a subset of the .* dim"): + DataArray(data, {"x": [0, 1, 2]}) - with raises_regex(TypeError, 'is not a string'): - DataArray(data, dims=['x', None]) + with raises_regex(TypeError, "is not a string"): + DataArray(data, dims=["x", None]) - with raises_regex(ValueError, 'conflicting sizes for dim'): - DataArray([1, 2, 3], coords=[('x', [0, 1])]) - with raises_regex(ValueError, 'conflicting sizes for dim'): - DataArray([1, 2], coords={'x': [0, 1], 'y': ('x', [1])}, dims='x') + with raises_regex(ValueError, "conflicting sizes for dim"): + DataArray([1, 2, 3], coords=[("x", [0, 1])]) + with raises_regex(ValueError, "conflicting sizes for dim"): + DataArray([1, 2], coords={"x": [0, 1], "y": ("x", [1])}, dims="x") - with raises_regex(ValueError, 'conflicting MultiIndex'): - DataArray(np.random.rand(4, 4), - [('x', self.mindex), ('y', self.mindex)]) - with raises_regex(ValueError, 'conflicting MultiIndex'): - DataArray(np.random.rand(4, 4), - [('x', self.mindex), ('level_1', range(4))]) + with raises_regex(ValueError, "conflicting MultiIndex"): + DataArray(np.random.rand(4, 4), [("x", self.mindex), ("y", self.mindex)]) + with raises_regex(ValueError, "conflicting MultiIndex"): + DataArray(np.random.rand(4, 4), [("x", self.mindex), ("level_1", range(4))]) - with raises_regex(ValueError, 'matching the dimension size'): - DataArray(data, coords={'x': 0}, dims=['x', 'y']) + with raises_regex(ValueError, "matching the dimension size"): + DataArray(data, coords={"x": 0}, dims=["x", "y"]) def test_constructor_from_self_described(self): data = [[-0.1, 21], [0, 2]] - expected = DataArray(data, - coords={'x': ['a', 'b'], 'y': [-1, -2]}, - dims=['x', 'y'], name='foobar', - attrs={'bar': 2}) + expected = DataArray( + data, + coords={"x": ["a", "b"], "y": [-1, -2]}, + dims=["x", "y"], + name="foobar", + attrs={"bar": 2}, + ) actual = DataArray(expected) assert_identical(expected, actual) actual = DataArray(expected.values, actual.coords) assert_equal(expected, actual) - frame = pd.DataFrame(data, index=pd.Index(['a', 'b'], name='x'), - columns=pd.Index([-1, -2], name='y')) + frame = pd.DataFrame( + data, + index=pd.Index(["a", "b"], name="x"), + columns=pd.Index([-1, -2], name="y"), + ) actual = DataArray(frame) assert_equal(expected, actual) - series = pd.Series(data[0], index=pd.Index([-1, -2], name='y')) + series = pd.Series(data[0], index=pd.Index([-1, -2], name="y")) actual = DataArray(series) - assert_equal(expected[0].reset_coords('x', drop=True), actual) + assert_equal(expected[0].reset_coords("x", 
drop=True), actual) - if LooseVersion(pd.__version__) < '0.25.0': + if LooseVersion(pd.__version__) < "0.25.0": with warnings.catch_warnings(): - warnings.filterwarnings('ignore', r'\W*Panel is deprecated') + warnings.filterwarnings("ignore", r"\W*Panel is deprecated") panel = pd.Panel({0: frame}) actual = DataArray(panel) - expected = DataArray([data], expected.coords, ['dim_0', 'x', 'y']) - expected['dim_0'] = [0] + expected = DataArray([data], expected.coords, ["dim_0", "x", "y"]) + expected["dim_0"] = [0] assert_identical(expected, actual) - expected = DataArray(data, - coords={'x': ['a', 'b'], 'y': [-1, -2], - 'a': 0, 'z': ('x', [-0.5, 0.5])}, - dims=['x', 'y']) + expected = DataArray( + data, + coords={"x": ["a", "b"], "y": [-1, -2], "a": 0, "z": ("x", [-0.5, 0.5])}, + dims=["x", "y"], + ) actual = DataArray(expected) assert_identical(expected, actual) actual = DataArray(expected.values, expected.coords) assert_identical(expected, actual) - expected = Dataset({'foo': ('foo', ['a', 'b'])})['foo'] - actual = DataArray(pd.Index(['a', 'b'], name='foo')) + expected = Dataset({"foo": ("foo", ["a", "b"])})["foo"] + actual = DataArray(pd.Index(["a", "b"], name="foo")) assert_identical(expected, actual) - actual = DataArray(IndexVariable('foo', ['a', 'b'])) + actual = DataArray(IndexVariable("foo", ["a", "b"])) assert_identical(expected, actual) def test_constructor_from_0d(self): @@ -367,37 +405,35 @@ def test_constructor_dask_coords(self): coord = da.arange(8, chunks=(4,)) data = da.random.random((8, 8), chunks=(4, 4)) + 1 - actual = DataArray(data, coords={'x': coord, 'y': coord}, - dims=['x', 'y']) + actual = DataArray(data, coords={"x": coord, "y": coord}, dims=["x", "y"]) ecoord = np.arange(8) - expected = DataArray(data, coords={'x': ecoord, 'y': ecoord}, - dims=['x', 'y']) + expected = DataArray(data, coords={"x": ecoord, "y": ecoord}, dims=["x", "y"]) assert_equal(actual, expected) def test_equals_and_identical(self): - orig = DataArray(np.arange(5.0), {'a': 42}, dims='x') + orig = DataArray(np.arange(5.0), {"a": 42}, dims="x") expected = orig actual = orig.copy() assert expected.equals(actual) assert expected.identical(actual) - actual = expected.rename('baz') + actual = expected.rename("baz") assert expected.equals(actual) assert not expected.identical(actual) - actual = expected.rename({'x': 'xxx'}) + actual = expected.rename({"x": "xxx"}) assert not expected.equals(actual) assert not expected.identical(actual) actual = expected.copy() - actual.attrs['foo'] = 'bar' + actual.attrs["foo"] = "bar" assert expected.equals(actual) assert not expected.identical(actual) actual = expected.copy() - actual['x'] = ('x', -np.arange(5)) + actual["x"] = ("x", -np.arange(5)) assert not expected.equals(actual) assert not expected.identical(actual) @@ -416,60 +452,83 @@ def test_equals_and_identical(self): assert not expected.identical(actual) actual = expected.copy() - actual['a'] = 100000 + actual["a"] = 100000 assert not expected.equals(actual) assert not expected.identical(actual) def test_equals_failures(self): - orig = DataArray(np.arange(5.0), {'a': 42}, dims='x') + orig = DataArray(np.arange(5.0), {"a": 42}, dims="x") assert not orig.equals(np.arange(5)) assert not orig.identical(123) assert not orig.broadcast_equals({1: 2}) def test_broadcast_equals(self): - a = DataArray([0, 0], {'y': 0}, dims='x') - b = DataArray([0, 0], {'y': ('x', [0, 0])}, dims='x') + a = DataArray([0, 0], {"y": 0}, dims="x") + b = DataArray([0, 0], {"y": ("x", [0, 0])}, dims="x") assert a.broadcast_equals(b) assert 
b.broadcast_equals(a) assert not a.equals(b) assert not a.identical(b) - c = DataArray([0], coords={'x': 0}, dims='y') + c = DataArray([0], coords={"x": 0}, dims="y") assert not a.broadcast_equals(c) assert not c.broadcast_equals(a) def test_getitem(self): # strings pull out dataarrays - assert_identical(self.dv, self.ds['foo']) - x = self.dv['x'] - y = self.dv['y'] - assert_identical(self.ds['x'], x) - assert_identical(self.ds['y'], y) + assert_identical(self.dv, self.ds["foo"]) + x = self.dv["x"] + y = self.dv["y"] + assert_identical(self.ds["x"], x) + assert_identical(self.ds["y"], y) I = ReturnItem() # noqa - for i in [I[:], I[...], I[x.values], I[x.variable], I[x], I[x, y], - I[x.values > -1], I[x.variable > -1], I[x > -1], - I[x > -1, y > -1]]: + for i in [ + I[:], + I[...], + I[x.values], + I[x.variable], + I[x], + I[x, y], + I[x.values > -1], + I[x.variable > -1], + I[x > -1], + I[x > -1, y > -1], + ]: assert_equal(self.dv, self.dv[i]) - for i in [I[0], I[:, 0], I[:3, :2], - I[x.values[:3]], I[x.variable[:3]], - I[x[:3]], I[x[:3], y[:4]], - I[x.values > 3], I[x.variable > 3], - I[x > 3], I[x > 3, y > 3]]: + for i in [ + I[0], + I[:, 0], + I[:3, :2], + I[x.values[:3]], + I[x.variable[:3]], + I[x[:3]], + I[x[:3], y[:4]], + I[x.values > 3], + I[x.variable > 3], + I[x > 3], + I[x > 3, y > 3], + ]: assert_array_equal(self.v[i], self.dv[i]) def test_getitem_dict(self): - actual = self.dv[{'x': slice(3), 'y': 0}] + actual = self.dv[{"x": slice(3), "y": 0}] expected = self.dv.isel(x=slice(3), y=0) assert_identical(expected, actual) def test_getitem_coords(self): - orig = DataArray([[10], [20]], - {'x': [1, 2], 'y': [3], 'z': 4, - 'x2': ('x', ['a', 'b']), - 'y2': ('y', ['c']), - 'xy': (['y', 'x'], [['d', 'e']])}, - dims=['x', 'y']) + orig = DataArray( + [[10], [20]], + { + "x": [1, 2], + "y": [3], + "z": 4, + "x2": ("x", ["a", "b"]), + "y2": ("y", ["c"]), + "xy": (["y", "x"], [["d", "e"]]), + }, + dims=["x", "y"], + ) assert_identical(orig, orig[:]) assert_identical(orig, orig[:, :]) @@ -479,155 +538,198 @@ def test_getitem_coords(self): actual = orig[0, 0] expected = DataArray( - 10, {'x': 1, 'y': 3, 'z': 4, 'x2': 'a', 'y2': 'c', 'xy': 'd'}) + 10, {"x": 1, "y": 3, "z": 4, "x2": "a", "y2": "c", "xy": "d"} + ) assert_identical(expected, actual) actual = orig[0, :] expected = DataArray( - [10], {'x': 1, 'y': [3], 'z': 4, 'x2': 'a', 'y2': ('y', ['c']), - 'xy': ('y', ['d'])}, - dims='y') + [10], + { + "x": 1, + "y": [3], + "z": 4, + "x2": "a", + "y2": ("y", ["c"]), + "xy": ("y", ["d"]), + }, + dims="y", + ) assert_identical(expected, actual) actual = orig[:, 0] expected = DataArray( - [10, 20], {'x': [1, 2], 'y': 3, 'z': 4, 'x2': ('x', ['a', 'b']), - 'y2': 'c', 'xy': ('x', ['d', 'e'])}, - dims='x') + [10, 20], + { + "x": [1, 2], + "y": 3, + "z": 4, + "x2": ("x", ["a", "b"]), + "y2": "c", + "xy": ("x", ["d", "e"]), + }, + dims="x", + ) assert_identical(expected, actual) def test_getitem_dataarray(self): # It should not conflict - da = DataArray(np.arange(12).reshape((3, 4)), dims=['x', 'y']) - ind = DataArray([[0, 1], [0, 1]], dims=['x', 'z']) + da = DataArray(np.arange(12).reshape((3, 4)), dims=["x", "y"]) + ind = DataArray([[0, 1], [0, 1]], dims=["x", "z"]) actual = da[ind] assert_array_equal(actual, da.values[[[0, 1], [0, 1]], :]) - da = DataArray(np.arange(12).reshape((3, 4)), dims=['x', 'y'], - coords={'x': [0, 1, 2], 'y': ['a', 'b', 'c', 'd']}) - ind = xr.DataArray([[0, 1], [0, 1]], dims=['X', 'Y']) + da = DataArray( + np.arange(12).reshape((3, 4)), + dims=["x", "y"], + coords={"x": [0, 1, 
2], "y": ["a", "b", "c", "d"]}, + ) + ind = xr.DataArray([[0, 1], [0, 1]], dims=["X", "Y"]) actual = da[ind] expected = da.values[[[0, 1], [0, 1]], :] assert_array_equal(actual, expected) - assert actual.dims == ('X', 'Y', 'y') + assert actual.dims == ("X", "Y", "y") # boolean indexing - ind = xr.DataArray([True, True, False], dims=['x']) + ind = xr.DataArray([True, True, False], dims=["x"]) assert_equal(da[ind], da[[0, 1], :]) assert_equal(da[ind], da[[0, 1]]) assert_equal(da[ind], da[ind.values]) def test_getitem_empty_index(self): - da = DataArray(np.arange(12).reshape((3, 4)), dims=['x', 'y']) - assert_identical(da[{'x': []}], - DataArray(np.zeros((0, 4)), dims=['x', 'y'])) - assert_identical(da.loc[{'y': []}], - DataArray(np.zeros((3, 0)), dims=['x', 'y'])) - assert_identical(da[[]], DataArray(np.zeros((0, 4)), dims=['x', 'y'])) + da = DataArray(np.arange(12).reshape((3, 4)), dims=["x", "y"]) + assert_identical(da[{"x": []}], DataArray(np.zeros((0, 4)), dims=["x", "y"])) + assert_identical( + da.loc[{"y": []}], DataArray(np.zeros((3, 0)), dims=["x", "y"]) + ) + assert_identical(da[[]], DataArray(np.zeros((0, 4)), dims=["x", "y"])) def test_setitem(self): # basic indexing should work as numpy's indexing - tuples = [(0, 0), (0, slice(None, None)), - (slice(None, None), slice(None, None)), - (slice(None, None), 0), - ([1, 0], slice(None, None)), - (slice(None, None), [1, 0])] + tuples = [ + (0, 0), + (0, slice(None, None)), + (slice(None, None), slice(None, None)), + (slice(None, None), 0), + ([1, 0], slice(None, None)), + (slice(None, None), [1, 0]), + ] for t in tuples: expected = np.arange(6).reshape(3, 2) - orig = DataArray(np.arange(6).reshape(3, 2), - {'x': [1, 2, 3], 'y': ['a', 'b'], 'z': 4, - 'x2': ('x', ['a', 'b', 'c']), - 'y2': ('y', ['d', 'e'])}, - dims=['x', 'y']) + orig = DataArray( + np.arange(6).reshape(3, 2), + { + "x": [1, 2, 3], + "y": ["a", "b"], + "z": 4, + "x2": ("x", ["a", "b", "c"]), + "y2": ("y", ["d", "e"]), + }, + dims=["x", "y"], + ) orig[t] = 1 expected[t] = 1 assert_array_equal(orig.values, expected) def test_setitem_fancy(self): # vectorized indexing - da = DataArray(np.ones((3, 2)), dims=['x', 'y']) - ind = Variable(['a'], [0, 1]) + da = DataArray(np.ones((3, 2)), dims=["x", "y"]) + ind = Variable(["a"], [0, 1]) da[dict(x=ind, y=ind)] = 0 - expected = DataArray([[0, 1], [1, 0], [1, 1]], dims=['x', 'y']) + expected = DataArray([[0, 1], [1, 0], [1, 1]], dims=["x", "y"]) assert_identical(expected, da) # assign another 0d-variable da[dict(x=ind, y=ind)] = Variable((), 0) - expected = DataArray([[0, 1], [1, 0], [1, 1]], dims=['x', 'y']) + expected = DataArray([[0, 1], [1, 0], [1, 1]], dims=["x", "y"]) assert_identical(expected, da) # assign another 1d-variable - da[dict(x=ind, y=ind)] = Variable(['a'], [2, 3]) - expected = DataArray([[2, 1], [1, 3], [1, 1]], dims=['x', 'y']) + da[dict(x=ind, y=ind)] = Variable(["a"], [2, 3]) + expected = DataArray([[2, 1], [1, 3], [1, 1]], dims=["x", "y"]) assert_identical(expected, da) # 2d-vectorized indexing - da = DataArray(np.ones((3, 2)), dims=['x', 'y']) - ind_x = DataArray([[0, 1]], dims=['a', 'b']) - ind_y = DataArray([[1, 0]], dims=['a', 'b']) + da = DataArray(np.ones((3, 2)), dims=["x", "y"]) + ind_x = DataArray([[0, 1]], dims=["a", "b"]) + ind_y = DataArray([[1, 0]], dims=["a", "b"]) da[dict(x=ind_x, y=ind_y)] = 0 - expected = DataArray([[1, 0], [0, 1], [1, 1]], dims=['x', 'y']) + expected = DataArray([[1, 0], [0, 1], [1, 1]], dims=["x", "y"]) assert_identical(expected, da) - da = DataArray(np.ones((3, 2)), 
dims=['x', 'y']) - ind = Variable(['a'], [0, 1]) + da = DataArray(np.ones((3, 2)), dims=["x", "y"]) + ind = Variable(["a"], [0, 1]) da[ind] = 0 - expected = DataArray([[0, 0], [0, 0], [1, 1]], dims=['x', 'y']) + expected = DataArray([[0, 0], [0, 0], [1, 1]], dims=["x", "y"]) assert_identical(expected, da) def test_setitem_dataarray(self): def get_data(): - return DataArray(np.ones((4, 3, 2)), dims=['x', 'y', 'z'], - coords={'x': np.arange(4), 'y': ['a', 'b', 'c'], - 'non-dim': ('x', [1, 3, 4, 2])}) + return DataArray( + np.ones((4, 3, 2)), + dims=["x", "y", "z"], + coords={ + "x": np.arange(4), + "y": ["a", "b", "c"], + "non-dim": ("x", [1, 3, 4, 2]), + }, + ) da = get_data() # indexer with inconsistent coordinates. - ind = DataArray(np.arange(1, 4), dims=['x'], - coords={'x': np.random.randn(3)}) + ind = DataArray(np.arange(1, 4), dims=["x"], coords={"x": np.random.randn(3)}) with raises_regex(IndexError, "dimension coordinate 'x'"): da[dict(x=ind)] = 0 # indexer with consistent coordinates. - ind = DataArray(np.arange(1, 4), dims=['x'], - coords={'x': np.arange(1, 4)}) + ind = DataArray(np.arange(1, 4), dims=["x"], coords={"x": np.arange(1, 4)}) da[dict(x=ind)] = 0 # should not raise assert np.allclose(da[dict(x=ind)].values, 0) - assert_identical(da['x'], get_data()['x']) - assert_identical(da['non-dim'], get_data()['non-dim']) + assert_identical(da["x"], get_data()["x"]) + assert_identical(da["non-dim"], get_data()["non-dim"]) da = get_data() # conflict in the assigning values - value = xr.DataArray(np.zeros((3, 3, 2)), dims=['x', 'y', 'z'], - coords={'x': [0, 1, 2], - 'non-dim': ('x', [0, 2, 4])}) + value = xr.DataArray( + np.zeros((3, 3, 2)), + dims=["x", "y", "z"], + coords={"x": [0, 1, 2], "non-dim": ("x", [0, 2, 4])}, + ) with raises_regex(IndexError, "dimension coordinate 'x'"): da[dict(x=ind)] = value # consistent coordinate in the assigning values - value = xr.DataArray(np.zeros((3, 3, 2)), dims=['x', 'y', 'z'], - coords={'x': [1, 2, 3], - 'non-dim': ('x', [0, 2, 4])}) + value = xr.DataArray( + np.zeros((3, 3, 2)), + dims=["x", "y", "z"], + coords={"x": [1, 2, 3], "non-dim": ("x", [0, 2, 4])}, + ) da[dict(x=ind)] = value assert np.allclose(da[dict(x=ind)].values, 0) - assert_identical(da['x'], get_data()['x']) - assert_identical(da['non-dim'], get_data()['non-dim']) + assert_identical(da["x"], get_data()["x"]) + assert_identical(da["non-dim"], get_data()["non-dim"]) # Conflict in the non-dimension coordinate - value = xr.DataArray(np.zeros((3, 3, 2)), dims=['x', 'y', 'z'], - coords={'x': [1, 2, 3], - 'non-dim': ('x', [0, 2, 4])}) + value = xr.DataArray( + np.zeros((3, 3, 2)), + dims=["x", "y", "z"], + coords={"x": [1, 2, 3], "non-dim": ("x", [0, 2, 4])}, + ) da[dict(x=ind)] = value # should not raise # conflict in the assigning values - value = xr.DataArray(np.zeros((3, 3, 2)), dims=['x', 'y', 'z'], - coords={'x': [0, 1, 2], - 'non-dim': ('x', [0, 2, 4])}) + value = xr.DataArray( + np.zeros((3, 3, 2)), + dims=["x", "y", "z"], + coords={"x": [0, 1, 2], "non-dim": ("x", [0, 2, 4])}, + ) with raises_regex(IndexError, "dimension coordinate 'x'"): da[dict(x=ind)] = value # consistent coordinate in the assigning values - value = xr.DataArray(np.zeros((3, 3, 2)), dims=['x', 'y', 'z'], - coords={'x': [1, 2, 3], - 'non-dim': ('x', [0, 2, 4])}) + value = xr.DataArray( + np.zeros((3, 3, 2)), + dims=["x", "y", "z"], + coords={"x": [1, 2, 3], "non-dim": ("x", [0, 2, 4])}, + ) da[dict(x=ind)] = value # should not raise def test_contains(self): @@ -641,7 +743,7 @@ def 
test_attr_sources_multiindex(self): assert isinstance(self.mda.level_1, DataArray) def test_pickle(self): - data = DataArray(np.random.random((3, 3)), dims=('id', 'time')) + data = DataArray(np.random.random((3, 3)), dims=("id", "time")) roundtripped = pickle.loads(pickle.dumps(data)) assert_identical(data, roundtripped) @@ -663,9 +765,10 @@ def test_chunk(self): # Check that kwargs are passed import dask.array as da - blocked = unblocked.chunk(name_prefix='testname_') + + blocked = unblocked.chunk(name_prefix="testname_") assert isinstance(blocked.data, da.Array) - assert 'testname_' in blocked.data.name + assert "testname_" in blocked.data.name def test_isel(self): assert_identical(self.dv[0], self.dv.isel(x=0)) @@ -675,190 +778,196 @@ def test_isel(self): def test_isel_types(self): # regression test for #1405 - da = DataArray([1, 2, 3], dims='x') + da = DataArray([1, 2, 3], dims="x") # uint64 - assert_identical(da.isel(x=np.array([0], dtype="uint64")), - da.isel(x=np.array([0]))) + assert_identical( + da.isel(x=np.array([0], dtype="uint64")), da.isel(x=np.array([0])) + ) # uint32 - assert_identical(da.isel(x=np.array([0], dtype="uint32")), - da.isel(x=np.array([0]))) + assert_identical( + da.isel(x=np.array([0], dtype="uint32")), da.isel(x=np.array([0])) + ) # int64 - assert_identical(da.isel(x=np.array([0], dtype="int64")), - da.isel(x=np.array([0]))) + assert_identical( + da.isel(x=np.array([0], dtype="int64")), da.isel(x=np.array([0])) + ) - @pytest.mark.filterwarnings('ignore::DeprecationWarning') + @pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_isel_fancy(self): shape = (10, 7, 6) np_array = np.random.random(shape) - da = DataArray(np_array, dims=['time', 'y', 'x'], - coords={'time': np.arange(0, 100, 10)}) + da = DataArray( + np_array, dims=["time", "y", "x"], coords={"time": np.arange(0, 100, 10)} + ) y = [1, 3] x = [3, 0] expected = da.values[:, y, x] - actual = da.isel(y=(('test_coord', ), y), x=(('test_coord', ), x)) - assert actual.coords['test_coord'].shape == (len(y), ) - assert list(actual.coords) == ['time'] - assert actual.dims == ('time', 'test_coord') + actual = da.isel(y=(("test_coord",), y), x=(("test_coord",), x)) + assert actual.coords["test_coord"].shape == (len(y),) + assert list(actual.coords) == ["time"] + assert actual.dims == ("time", "test_coord") np.testing.assert_equal(actual, expected) # a few corner cases - da.isel(time=(('points',), [1, 2]), x=(('points',), [2, 2]), - y=(('points',), [3, 4])) + da.isel( + time=(("points",), [1, 2]), x=(("points",), [2, 2]), y=(("points",), [3, 4]) + ) np.testing.assert_allclose( - da.isel(time=(('p',), [1]), - x=(('p',), [2]), - y=(('p',), [4])).values.squeeze(), - np_array[1, 4, 2].squeeze()) - da.isel(time=(('points', ), [1, 2])) + da.isel( + time=(("p",), [1]), x=(("p",), [2]), y=(("p",), [4]) + ).values.squeeze(), + np_array[1, 4, 2].squeeze(), + ) + da.isel(time=(("points",), [1, 2])) y = [-1, 0] x = [-2, 2] expected = da.values[:, y, x] - actual = da.isel(x=(('points', ), x), y=(('points', ), y)).values + actual = da.isel(x=(("points",), x), y=(("points",), y)).values np.testing.assert_equal(actual, expected) # test that the order of the indexers doesn't matter assert_identical( - da.isel(y=(('points', ), y), x=(('points', ), x)), - da.isel(x=(('points', ), x), y=(('points', ), y))) + da.isel(y=(("points",), y), x=(("points",), x)), + da.isel(x=(("points",), x), y=(("points",), y)), + ) # make sure we're raising errors in the right places - with raises_regex(IndexError, - 'Dimensions of 
indexers mismatch'): - da.isel(y=(('points', ), [1, 2]), x=(('points', ), [1, 2, 3])) + with raises_regex(IndexError, "Dimensions of indexers mismatch"): + da.isel(y=(("points",), [1, 2]), x=(("points",), [1, 2, 3])) # tests using index or DataArray as indexers stations = Dataset() - stations['station'] = (('station', ), ['A', 'B', 'C']) - stations['dim1s'] = (('station', ), [1, 2, 3]) - stations['dim2s'] = (('station', ), [4, 5, 1]) + stations["station"] = (("station",), ["A", "B", "C"]) + stations["dim1s"] = (("station",), [1, 2, 3]) + stations["dim2s"] = (("station",), [4, 5, 1]) - actual = da.isel(x=stations['dim1s'], y=stations['dim2s']) - assert 'station' in actual.coords - assert 'station' in actual.dims - assert_identical(actual['station'], stations['station']) + actual = da.isel(x=stations["dim1s"], y=stations["dim2s"]) + assert "station" in actual.coords + assert "station" in actual.dims + assert_identical(actual["station"], stations["station"]) - with raises_regex(ValueError, 'conflicting values for '): - da.isel(x=DataArray([0, 1, 2], dims='station', - coords={'station': [0, 1, 2]}), - y=DataArray([0, 1, 2], dims='station', - coords={'station': [0, 1, 3]})) + with raises_regex(ValueError, "conflicting values for "): + da.isel( + x=DataArray([0, 1, 2], dims="station", coords={"station": [0, 1, 2]}), + y=DataArray([0, 1, 2], dims="station", coords={"station": [0, 1, 3]}), + ) # multi-dimensional selection stations = Dataset() - stations['a'] = (('a', ), ['A', 'B', 'C']) - stations['b'] = (('b', ), [0, 1]) - stations['dim1s'] = (('a', 'b'), [[1, 2], [2, 3], [3, 4]]) - stations['dim2s'] = (('a', ), [4, 5, 1]) - - actual = da.isel(x=stations['dim1s'], y=stations['dim2s']) - assert 'a' in actual.coords - assert 'a' in actual.dims - assert 'b' in actual.coords - assert 'b' in actual.dims - assert_identical(actual['a'], stations['a']) - assert_identical(actual['b'], stations['b']) - expected = da.variable[:, stations['dim2s'].variable, - stations['dim1s'].variable] + stations["a"] = (("a",), ["A", "B", "C"]) + stations["b"] = (("b",), [0, 1]) + stations["dim1s"] = (("a", "b"), [[1, 2], [2, 3], [3, 4]]) + stations["dim2s"] = (("a",), [4, 5, 1]) + + actual = da.isel(x=stations["dim1s"], y=stations["dim2s"]) + assert "a" in actual.coords + assert "a" in actual.dims + assert "b" in actual.coords + assert "b" in actual.dims + assert_identical(actual["a"], stations["a"]) + assert_identical(actual["b"], stations["b"]) + expected = da.variable[ + :, stations["dim2s"].variable, stations["dim1s"].variable + ] assert_array_equal(actual, expected) def test_sel(self): - self.ds['x'] = ('x', np.array(list('abcdefghij'))) - da = self.ds['foo'] + self.ds["x"] = ("x", np.array(list("abcdefghij"))) + da = self.ds["foo"] assert_identical(da, da.sel(x=slice(None))) - assert_identical(da[1], da.sel(x='b')) - assert_identical(da[:3], da.sel(x=slice('c'))) - assert_identical(da[:3], da.sel(x=['a', 'b', 'c'])) - assert_identical(da[:, :4], da.sel(y=(self.ds['y'] < 4))) + assert_identical(da[1], da.sel(x="b")) + assert_identical(da[:3], da.sel(x=slice("c"))) + assert_identical(da[:3], da.sel(x=["a", "b", "c"])) + assert_identical(da[:, :4], da.sel(y=(self.ds["y"] < 4))) # verify that indexing with a dataarray works - b = DataArray('b') + b = DataArray("b") assert_identical(da[1], da.sel(x=b)) assert_identical(da[[1]], da.sel(x=slice(b, b))) def test_sel_dataarray(self): # indexing with DataArray - self.ds['x'] = ('x', np.array(list('abcdefghij'))) - da = self.ds['foo'] + self.ds["x"] = ("x", 
np.array(list("abcdefghij"))) + da = self.ds["foo"] - ind = DataArray(['a', 'b', 'c'], dims=['x']) + ind = DataArray(["a", "b", "c"], dims=["x"]) actual = da.sel(x=ind) assert_identical(actual, da.isel(x=[0, 1, 2])) # along new dimension - ind = DataArray(['a', 'b', 'c'], dims=['new_dim']) + ind = DataArray(["a", "b", "c"], dims=["new_dim"]) actual = da.sel(x=ind) assert_array_equal(actual, da.isel(x=[0, 1, 2])) - assert 'new_dim' in actual.dims + assert "new_dim" in actual.dims # with coordinate - ind = DataArray(['a', 'b', 'c'], dims=['new_dim'], - coords={'new_dim': [0, 1, 2]}) + ind = DataArray( + ["a", "b", "c"], dims=["new_dim"], coords={"new_dim": [0, 1, 2]} + ) actual = da.sel(x=ind) assert_array_equal(actual, da.isel(x=[0, 1, 2])) - assert 'new_dim' in actual.dims - assert 'new_dim' in actual.coords - assert_equal(actual['new_dim'].drop('x'), ind['new_dim']) + assert "new_dim" in actual.dims + assert "new_dim" in actual.coords + assert_equal(actual["new_dim"].drop("x"), ind["new_dim"]) def test_sel_invalid_slice(self): - array = DataArray(np.arange(10), [('x', np.arange(10))]) - with raises_regex(ValueError, 'cannot use non-scalar arrays'): + array = DataArray(np.arange(10), [("x", np.arange(10))]) + with raises_regex(ValueError, "cannot use non-scalar arrays"): array.sel(x=slice(array.x)) def test_sel_dataarray_datetime(self): # regression test for GH1240 - times = pd.date_range('2000-01-01', freq='D', periods=365) - array = DataArray(np.arange(365), [('time', times)]) + times = pd.date_range("2000-01-01", freq="D", periods=365) + array = DataArray(np.arange(365), [("time", times)]) result = array.sel(time=slice(array.time[0], array.time[-1])) assert_equal(result, array) - array = DataArray(np.arange(365), [('delta', times - times[0])]) + array = DataArray(np.arange(365), [("delta", times - times[0])]) result = array.sel(delta=slice(array.delta[0], array.delta[-1])) assert_equal(result, array) def test_sel_no_index(self): - array = DataArray(np.arange(10), dims='x') + array = DataArray(np.arange(10), dims="x") assert_identical(array[0], array.sel(x=0)) assert_identical(array[:5], array.sel(x=slice(5))) assert_identical(array[[0, -1]], array.sel(x=[0, -1])) - assert_identical( - array[array < 5], array.sel(x=(array < 5))) + assert_identical(array[array < 5], array.sel(x=(array < 5))) def test_sel_method(self): - data = DataArray(np.random.randn(3, 4), - [('x', [0, 1, 2]), ('y', list('abcd'))]) + data = DataArray(np.random.randn(3, 4), [("x", [0, 1, 2]), ("y", list("abcd"))]) - expected = data.sel(y=['a', 'b']) - actual = data.sel(y=['ab', 'ba'], method='pad') + expected = data.sel(y=["a", "b"]) + actual = data.sel(y=["ab", "ba"], method="pad") assert_identical(expected, actual) expected = data.sel(x=[1, 2]) - actual = data.sel(x=[0.9, 1.9], method='backfill', tolerance=1) + actual = data.sel(x=[0.9, 1.9], method="backfill", tolerance=1) assert_identical(expected, actual) def test_sel_drop(self): - data = DataArray([1, 2, 3], [('x', [0, 1, 2])]) + data = DataArray([1, 2, 3], [("x", [0, 1, 2])]) expected = DataArray(1) selected = data.sel(x=0, drop=True) assert_identical(expected, selected) - expected = DataArray(1, {'x': 0}) + expected = DataArray(1, {"x": 0}) selected = data.sel(x=0, drop=False) assert_identical(expected, selected) - data = DataArray([1, 2, 3], dims=['x']) + data = DataArray([1, 2, 3], dims=["x"]) expected = DataArray(1) selected = data.sel(x=0, drop=True) assert_identical(expected, selected) def test_isel_drop(self): - data = DataArray([1, 2, 3], [('x', [0, 1, 
2])]) + data = DataArray([1, 2, 3], [("x", [0, 1, 2])]) expected = DataArray(1) selected = data.isel(x=0, drop=True) assert_identical(expected, selected) - expected = DataArray(1, {'x': 0}) + expected = DataArray(1, {"x": 0}) selected = data.isel(x=0, drop=False) assert_identical(expected, selected) @@ -866,20 +975,21 @@ def test_isel_drop(self): def test_isel_points(self): shape = (10, 5, 6) np_array = np.random.random(shape) - da = DataArray(np_array, dims=['time', 'y', 'x'], - coords={'time': np.arange(0, 100, 10)}) + da = DataArray( + np_array, dims=["time", "y", "x"], coords={"time": np.arange(0, 100, 10)} + ) y = [1, 3] x = [3, 0] expected = da.values[:, y, x] - actual = da.isel_points(y=y, x=x, dim='test_coord') - assert actual.coords['test_coord'].shape == (len(y), ) - assert list(actual.coords) == ['time'] - assert actual.dims == ('test_coord', 'time') + actual = da.isel_points(y=y, x=x, dim="test_coord") + assert actual.coords["test_coord"].shape == (len(y),) + assert list(actual.coords) == ["time"] + assert actual.dims == ("test_coord", "time") actual = da.isel_points(y=y, x=x) - assert 'points' in actual.dims + assert "points" in actual.dims # Note that because xarray always concatenates along the first # dimension, We must transpose the result to match the numpy style of # concatenation. @@ -889,7 +999,8 @@ def test_isel_points(self): da.isel_points(time=[1, 2], x=[2, 2], y=[3, 4]) np.testing.assert_allclose( da.isel_points(time=[1], x=[2], y=[4]).values.squeeze(), - np_array[1, 4, 2].squeeze()) + np_array[1, 4, 2].squeeze(), + ) da.isel_points(time=[1, 2]) y = [-1, 0] x = [-2, 2] @@ -898,103 +1009,105 @@ def test_isel_points(self): np.testing.assert_equal(actual.T, expected) # test that the order of the indexers doesn't matter - assert_identical( - da.isel_points(y=y, x=x), - da.isel_points(x=x, y=y)) + assert_identical(da.isel_points(y=y, x=x), da.isel_points(x=x, y=y)) # make sure we're raising errors in the right places - with raises_regex(ValueError, - 'All indexers must be the same length'): + with raises_regex(ValueError, "All indexers must be the same length"): da.isel_points(y=[1, 2], x=[1, 2, 3]) - with raises_regex(ValueError, - 'dimension bad_key does not exist'): + with raises_regex(ValueError, "dimension bad_key does not exist"): da.isel_points(bad_key=[1, 2]) - with raises_regex(TypeError, 'Indexers must be integers'): + with raises_regex(TypeError, "Indexers must be integers"): da.isel_points(y=[1.5, 2.2]) - with raises_regex(TypeError, 'Indexers must be integers'): + with raises_regex(TypeError, "Indexers must be integers"): da.isel_points(x=[1, 2, 3], y=slice(3)) - with raises_regex(ValueError, - 'Indexers must be 1 dimensional'): + with raises_regex(ValueError, "Indexers must be 1 dimensional"): da.isel_points(y=1, x=2) - with raises_regex(ValueError, - 'Existing dimension names are not'): - da.isel_points(y=[1, 2], x=[1, 2], dim='x') + with raises_regex(ValueError, "Existing dimension names are not"): + da.isel_points(y=[1, 2], x=[1, 2], dim="x") # using non string dims - actual = da.isel_points(y=[1, 2], x=[1, 2], dim=['A', 'B']) - assert 'points' in actual.coords + actual = da.isel_points(y=[1, 2], x=[1, 2], dim=["A", "B"]) + assert "points" in actual.coords def test_loc(self): - self.ds['x'] = ('x', np.array(list('abcdefghij'))) - da = self.ds['foo'] - assert_identical(da[:3], da.loc[:'c']) - assert_identical(da[1], da.loc['b']) - assert_identical(da[1], da.loc[{'x': 'b'}]) - assert_identical(da[1], da.loc['b', ...]) - assert_identical(da[:3], 
da.loc[['a', 'b', 'c']]) - assert_identical(da[:3, :4], da.loc[['a', 'b', 'c'], np.arange(4)]) - assert_identical(da[:, :4], da.loc[:, self.ds['y'] < 4]) + self.ds["x"] = ("x", np.array(list("abcdefghij"))) + da = self.ds["foo"] + assert_identical(da[:3], da.loc[:"c"]) + assert_identical(da[1], da.loc["b"]) + assert_identical(da[1], da.loc[{"x": "b"}]) + assert_identical(da[1], da.loc["b", ...]) + assert_identical(da[:3], da.loc[["a", "b", "c"]]) + assert_identical(da[:3, :4], da.loc[["a", "b", "c"], np.arange(4)]) + assert_identical(da[:, :4], da.loc[:, self.ds["y"] < 4]) def test_loc_assign(self): - self.ds['x'] = ('x', np.array(list('abcdefghij'))) - da = self.ds['foo'] + self.ds["x"] = ("x", np.array(list("abcdefghij"))) + da = self.ds["foo"] # assignment - da.loc['a':'j'] = 0 + da.loc["a":"j"] = 0 assert np.all(da.values == 0) - da.loc[{'x': slice('a', 'j')}] = 2 + da.loc[{"x": slice("a", "j")}] = 2 assert np.all(da.values == 2) - da.loc[{'x': slice('a', 'j')}] = 2 + da.loc[{"x": slice("a", "j")}] = 2 assert np.all(da.values == 2) # Multi dimensional case - da = DataArray(np.arange(12).reshape(3, 4), dims=['x', 'y']) + da = DataArray(np.arange(12).reshape(3, 4), dims=["x", "y"]) da.loc[0, 0] = 0 assert da.values[0, 0] == 0 assert da.values[0, 1] != 0 - da = DataArray(np.arange(12).reshape(3, 4), dims=['x', 'y']) + da = DataArray(np.arange(12).reshape(3, 4), dims=["x", "y"]) da.loc[0] = 0 assert np.all(da.values[0] == np.zeros(4)) assert da.values[1, 0] != 0 def test_loc_assign_dataarray(self): def get_data(): - return DataArray(np.ones((4, 3, 2)), dims=['x', 'y', 'z'], - coords={'x': np.arange(4), 'y': ['a', 'b', 'c'], - 'non-dim': ('x', [1, 3, 4, 2])}) + return DataArray( + np.ones((4, 3, 2)), + dims=["x", "y", "z"], + coords={ + "x": np.arange(4), + "y": ["a", "b", "c"], + "non-dim": ("x", [1, 3, 4, 2]), + }, + ) da = get_data() # indexer with inconsistent coordinates. - ind = DataArray(np.arange(1, 4), dims=['y'], - coords={'y': np.random.randn(3)}) + ind = DataArray(np.arange(1, 4), dims=["y"], coords={"y": np.random.randn(3)}) with raises_regex(IndexError, "dimension coordinate 'y'"): da.loc[dict(x=ind)] = 0 # indexer with consistent coordinates. 
- ind = DataArray(np.arange(1, 4), dims=['x'], - coords={'x': np.arange(1, 4)}) + ind = DataArray(np.arange(1, 4), dims=["x"], coords={"x": np.arange(1, 4)}) da.loc[dict(x=ind)] = 0 # should not raise assert np.allclose(da[dict(x=ind)].values, 0) - assert_identical(da['x'], get_data()['x']) - assert_identical(da['non-dim'], get_data()['non-dim']) + assert_identical(da["x"], get_data()["x"]) + assert_identical(da["non-dim"], get_data()["non-dim"]) da = get_data() # conflict in the assigning values - value = xr.DataArray(np.zeros((3, 3, 2)), dims=['x', 'y', 'z'], - coords={'x': [0, 1, 2], - 'non-dim': ('x', [0, 2, 4])}) + value = xr.DataArray( + np.zeros((3, 3, 2)), + dims=["x", "y", "z"], + coords={"x": [0, 1, 2], "non-dim": ("x", [0, 2, 4])}, + ) with raises_regex(IndexError, "dimension coordinate 'x'"): da.loc[dict(x=ind)] = value # consistent coordinate in the assigning values - value = xr.DataArray(np.zeros((3, 3, 2)), dims=['x', 'y', 'z'], - coords={'x': [1, 2, 3], - 'non-dim': ('x', [0, 2, 4])}) + value = xr.DataArray( + np.zeros((3, 3, 2)), + dims=["x", "y", "z"], + coords={"x": [1, 2, 3], "non-dim": ("x", [0, 2, 4])}, + ) da.loc[dict(x=ind)] = value assert np.allclose(da[dict(x=ind)].values, 0) - assert_identical(da['x'], get_data()['x']) - assert_identical(da['non-dim'], get_data()['non-dim']) + assert_identical(da["x"], get_data()["x"]) + assert_identical(da["non-dim"], get_data()["non-dim"]) def test_loc_single_boolean(self): data = DataArray([0, 1], coords=[[True, False]]) @@ -1002,12 +1115,12 @@ def test_loc_single_boolean(self): assert data.loc[False] == 1 def test_selection_multiindex(self): - mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2], [-1, -2]], - names=('one', 'two', 'three')) - mdata = DataArray(range(8), [('x', mindex)]) + mindex = pd.MultiIndex.from_product( + [["a", "b"], [1, 2], [-1, -2]], names=("one", "two", "three") + ) + mdata = DataArray(range(8), [("x", mindex)]) - def test_sel(lab_indexer, pos_indexer, replaced_idx=False, - renamed_dim=None): + def test_sel(lab_indexer, pos_indexer, replaced_idx=False, renamed_dim=None): da = mdata.sel(x=lab_indexer) expected_da = mdata.isel(x=pos_indexer) if not replaced_idx: @@ -1015,238 +1128,250 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False, else: if renamed_dim: assert da.dims[0] == renamed_dim - da = da.rename({renamed_dim: 'x'}) + da = da.rename({renamed_dim: "x"}) assert_identical(da.variable, expected_da.variable) - assert not da['x'].equals(expected_da['x']) - - test_sel(('a', 1, -1), 0) - test_sel(('b', 2, -2), -1) - test_sel(('a', 1), [0, 1], replaced_idx=True, renamed_dim='three') - test_sel(('a',), range(4), replaced_idx=True) - test_sel('a', range(4), replaced_idx=True) - test_sel([('a', 1, -1), ('b', 2, -2)], [0, 7]) - test_sel(slice('a', 'b'), range(8)) - test_sel(slice(('a', 1), ('b', 1)), range(6)) - test_sel({'one': 'a', 'two': 1, 'three': -1}, 0) - test_sel({'one': 'a', 'two': 1}, [0, 1], replaced_idx=True, - renamed_dim='three') - test_sel({'one': 'a'}, range(4), replaced_idx=True) - - assert_identical(mdata.loc['a'], mdata.sel(x='a')) - assert_identical(mdata.loc[('a', 1), ...], mdata.sel(x=('a', 1))) - assert_identical(mdata.loc[{'one': 'a'}, ...], - mdata.sel(x={'one': 'a'})) + assert not da["x"].equals(expected_da["x"]) + + test_sel(("a", 1, -1), 0) + test_sel(("b", 2, -2), -1) + test_sel(("a", 1), [0, 1], replaced_idx=True, renamed_dim="three") + test_sel(("a",), range(4), replaced_idx=True) + test_sel("a", range(4), replaced_idx=True) + test_sel([("a", 1, -1), ("b", 2, 
-2)], [0, 7]) + test_sel(slice("a", "b"), range(8)) + test_sel(slice(("a", 1), ("b", 1)), range(6)) + test_sel({"one": "a", "two": 1, "three": -1}, 0) + test_sel({"one": "a", "two": 1}, [0, 1], replaced_idx=True, renamed_dim="three") + test_sel({"one": "a"}, range(4), replaced_idx=True) + + assert_identical(mdata.loc["a"], mdata.sel(x="a")) + assert_identical(mdata.loc[("a", 1), ...], mdata.sel(x=("a", 1))) + assert_identical(mdata.loc[{"one": "a"}, ...], mdata.sel(x={"one": "a"})) with pytest.raises(IndexError): - mdata.loc[('a', 1)] + mdata.loc[("a", 1)] - assert_identical(mdata.sel(x={'one': 'a', 'two': 1}), - mdata.sel(one='a', two=1)) + assert_identical(mdata.sel(x={"one": "a", "two": 1}), mdata.sel(one="a", two=1)) def test_selection_multiindex_remove_unused(self): # GH2619. For MultiIndex, we need to call remove_unused. - ds = xr.DataArray(np.arange(40).reshape(8, 5), dims=['x', 'y'], - coords={'x': np.arange(8), 'y': np.arange(5)}) - ds = ds.stack(xy=['x', 'y']) - ds_isel = ds.isel(xy=ds['x'] < 4) + ds = xr.DataArray( + np.arange(40).reshape(8, 5), + dims=["x", "y"], + coords={"x": np.arange(8), "y": np.arange(5)}, + ) + ds = ds.stack(xy=["x", "y"]) + ds_isel = ds.isel(xy=ds["x"] < 4) with pytest.raises(KeyError): ds_isel.sel(x=5) actual = ds_isel.unstack() - expected = ds.reset_index('xy').isel(xy=ds['x'] < 4) - expected = expected.set_index(xy=['x', 'y']).unstack() + expected = ds.reset_index("xy").isel(xy=ds["x"] < 4) + expected = expected.set_index(xy=["x", "y"]).unstack() assert_identical(expected, actual) def test_virtual_default_coords(self): - array = DataArray(np.zeros((5,)), dims='x') - expected = DataArray(range(5), dims='x', name='x') - assert_identical(expected, array['x']) - assert_identical(expected, array.coords['x']) + array = DataArray(np.zeros((5,)), dims="x") + expected = DataArray(range(5), dims="x", name="x") + assert_identical(expected, array["x"]) + assert_identical(expected, array.coords["x"]) def test_virtual_time_components(self): - dates = pd.date_range('2000-01-01', periods=10) - da = DataArray(np.arange(1, 11), [('time', dates)]) + dates = pd.date_range("2000-01-01", periods=10) + da = DataArray(np.arange(1, 11), [("time", dates)]) - assert_array_equal(da['time.dayofyear'], da.values) - assert_array_equal(da.coords['time.dayofyear'], da.values) + assert_array_equal(da["time.dayofyear"], da.values) + assert_array_equal(da.coords["time.dayofyear"], da.values) def test_coords(self): # use int64 to ensure repr() consistency on windows - coords = [IndexVariable('x', np.array([-1, -2], 'int64')), - IndexVariable('y', np.array([0, 1, 2], 'int64'))] - da = DataArray(np.random.randn(2, 3), coords, name='foo') + coords = [ + IndexVariable("x", np.array([-1, -2], "int64")), + IndexVariable("y", np.array([0, 1, 2], "int64")), + ] + da = DataArray(np.random.randn(2, 3), coords, name="foo") assert 2 == len(da.coords) - assert ['x', 'y'] == list(da.coords) + assert ["x", "y"] == list(da.coords) - assert coords[0].identical(da.coords['x']) - assert coords[1].identical(da.coords['y']) + assert coords[0].identical(da.coords["x"]) + assert coords[1].identical(da.coords["y"]) - assert 'x' in da.coords + assert "x" in da.coords assert 0 not in da.coords - assert 'foo' not in da.coords + assert "foo" not in da.coords with pytest.raises(KeyError): da.coords[0] with pytest.raises(KeyError): - da.coords['foo'] + da.coords["foo"] - expected = dedent("""\ + expected = dedent( + """\ Coordinates: * x (x) int64 -1 -2 - * y (y) int64 0 1 2""") + * y (y) int64 0 1 2""" + ) actual = 
repr(da.coords) assert expected == actual - del da.coords['x'] - expected = DataArray(da.values, {'y': [0, 1, 2]}, dims=['x', 'y'], - name='foo') + del da.coords["x"] + expected = DataArray(da.values, {"y": [0, 1, 2]}, dims=["x", "y"], name="foo") assert_identical(da, expected) - with raises_regex(ValueError, 'conflicting MultiIndex'): - self.mda['level_1'] = np.arange(4) - self.mda.coords['level_1'] = np.arange(4) + with raises_regex(ValueError, "conflicting MultiIndex"): + self.mda["level_1"] = np.arange(4) + self.mda.coords["level_1"] = np.arange(4) def test_coords_to_index(self): - da = DataArray(np.zeros((2, 3)), [('x', [1, 2]), ('y', list('abc'))]) + da = DataArray(np.zeros((2, 3)), [("x", [1, 2]), ("y", list("abc"))]) - with raises_regex(ValueError, 'no valid index'): + with raises_regex(ValueError, "no valid index"): da[0, 0].coords.to_index() - expected = pd.Index(['a', 'b', 'c'], name='y') + expected = pd.Index(["a", "b", "c"], name="y") actual = da[0].coords.to_index() assert expected.equals(actual) - expected = pd.MultiIndex.from_product([[1, 2], ['a', 'b', 'c']], - names=['x', 'y']) + expected = pd.MultiIndex.from_product( + [[1, 2], ["a", "b", "c"]], names=["x", "y"] + ) actual = da.coords.to_index() assert expected.equals(actual) - expected = pd.MultiIndex.from_product([['a', 'b', 'c'], [1, 2]], - names=['y', 'x']) - actual = da.coords.to_index(['y', 'x']) + expected = pd.MultiIndex.from_product( + [["a", "b", "c"], [1, 2]], names=["y", "x"] + ) + actual = da.coords.to_index(["y", "x"]) assert expected.equals(actual) - with raises_regex(ValueError, 'ordered_dims must match'): - da.coords.to_index(['x']) + with raises_regex(ValueError, "ordered_dims must match"): + da.coords.to_index(["x"]) def test_coord_coords(self): - orig = DataArray([10, 20], - {'x': [1, 2], 'x2': ('x', ['a', 'b']), 'z': 4}, - dims='x') - - actual = orig.coords['x'] - expected = DataArray([1, 2], {'z': 4, 'x2': ('x', ['a', 'b']), - 'x': [1, 2]}, - dims='x', name='x') + orig = DataArray( + [10, 20], {"x": [1, 2], "x2": ("x", ["a", "b"]), "z": 4}, dims="x" + ) + + actual = orig.coords["x"] + expected = DataArray( + [1, 2], {"z": 4, "x2": ("x", ["a", "b"]), "x": [1, 2]}, dims="x", name="x" + ) assert_identical(expected, actual) - del actual.coords['x2'] - assert_identical( - expected.reset_coords('x2', drop=True), actual) + del actual.coords["x2"] + assert_identical(expected.reset_coords("x2", drop=True), actual) - actual.coords['x3'] = ('x', ['a', 'b']) - expected = DataArray([1, 2], {'z': 4, 'x3': ('x', ['a', 'b']), - 'x': [1, 2]}, - dims='x', name='x') + actual.coords["x3"] = ("x", ["a", "b"]) + expected = DataArray( + [1, 2], {"z": 4, "x3": ("x", ["a", "b"]), "x": [1, 2]}, dims="x", name="x" + ) assert_identical(expected, actual) def test_reset_coords(self): - data = DataArray(np.zeros((3, 4)), - {'bar': ('x', ['a', 'b', 'c']), - 'baz': ('y', range(4)), - 'y': range(4)}, - dims=['x', 'y'], - name='foo') + data = DataArray( + np.zeros((3, 4)), + {"bar": ("x", ["a", "b", "c"]), "baz": ("y", range(4)), "y": range(4)}, + dims=["x", "y"], + name="foo", + ) actual = data.reset_coords() - expected = Dataset({'foo': (['x', 'y'], np.zeros((3, 4))), - 'bar': ('x', ['a', 'b', 'c']), - 'baz': ('y', range(4)), - 'y': range(4)}) + expected = Dataset( + { + "foo": (["x", "y"], np.zeros((3, 4))), + "bar": ("x", ["a", "b", "c"]), + "baz": ("y", range(4)), + "y": range(4), + } + ) assert_identical(actual, expected) - actual = data.reset_coords(['bar', 'baz']) + actual = data.reset_coords(["bar", "baz"]) 
assert_identical(actual, expected) - actual = data.reset_coords('bar') - expected = Dataset({'foo': (['x', 'y'], np.zeros((3, 4))), - 'bar': ('x', ['a', 'b', 'c'])}, - {'baz': ('y', range(4)), 'y': range(4)}) + actual = data.reset_coords("bar") + expected = Dataset( + {"foo": (["x", "y"], np.zeros((3, 4))), "bar": ("x", ["a", "b", "c"])}, + {"baz": ("y", range(4)), "y": range(4)}, + ) assert_identical(actual, expected) - actual = data.reset_coords(['bar']) + actual = data.reset_coords(["bar"]) assert_identical(actual, expected) actual = data.reset_coords(drop=True) - expected = DataArray(np.zeros((3, 4)), coords={'y': range(4)}, - dims=['x', 'y'], name='foo') + expected = DataArray( + np.zeros((3, 4)), coords={"y": range(4)}, dims=["x", "y"], name="foo" + ) assert_identical(actual, expected) actual = data.copy() actual = actual.reset_coords(drop=True) assert_identical(actual, expected) - actual = data.reset_coords('bar', drop=True) - expected = DataArray(np.zeros((3, 4)), - {'baz': ('y', range(4)), 'y': range(4)}, - dims=['x', 'y'], name='foo') + actual = data.reset_coords("bar", drop=True) + expected = DataArray( + np.zeros((3, 4)), + {"baz": ("y", range(4)), "y": range(4)}, + dims=["x", "y"], + name="foo", + ) assert_identical(actual, expected) - with pytest.warns(FutureWarning, match='The inplace argument'): - with raises_regex(ValueError, 'cannot reset coord'): + with pytest.warns(FutureWarning, match="The inplace argument"): + with raises_regex(ValueError, "cannot reset coord"): data = data.reset_coords(inplace=True) - with raises_regex(ValueError, 'cannot be found'): - data.reset_coords('foo', drop=True) - with raises_regex(ValueError, 'cannot be found'): - data.reset_coords('not_found') - with raises_regex(ValueError, 'cannot remove index'): - data.reset_coords('y') + with raises_regex(ValueError, "cannot be found"): + data.reset_coords("foo", drop=True) + with raises_regex(ValueError, "cannot be found"): + data.reset_coords("not_found") + with raises_regex(ValueError, "cannot remove index"): + data.reset_coords("y") def test_assign_coords(self): array = DataArray(10) actual = array.assign_coords(c=42) - expected = DataArray(10, {'c': 42}) + expected = DataArray(10, {"c": 42}) assert_identical(actual, expected) - array = DataArray([1, 2, 3, 4], {'c': ('x', [0, 0, 1, 1])}, dims='x') - actual = array.groupby('c').assign_coords(d=lambda a: a.mean()) + array = DataArray([1, 2, 3, 4], {"c": ("x", [0, 0, 1, 1])}, dims="x") + actual = array.groupby("c").assign_coords(d=lambda a: a.mean()) expected = array.copy() - expected.coords['d'] = ('x', [1.5, 1.5, 3.5, 3.5]) + expected.coords["d"] = ("x", [1.5, 1.5, 3.5, 3.5]) assert_identical(actual, expected) - with raises_regex(ValueError, 'conflicting MultiIndex'): + with raises_regex(ValueError, "conflicting MultiIndex"): self.mda.assign_coords(level_1=range(4)) # GH: 2112 - da = xr.DataArray([0, 1, 2], dims='x') + da = xr.DataArray([0, 1, 2], dims="x") with pytest.raises(ValueError): - da['x'] = [0, 1, 2, 3] # size conflict + da["x"] = [0, 1, 2, 3] # size conflict with pytest.raises(ValueError): - da.coords['x'] = [0, 1, 2, 3] # size conflict + da.coords["x"] = [0, 1, 2, 3] # size conflict def test_coords_alignment(self): - lhs = DataArray([1, 2, 3], [('x', [0, 1, 2])]) - rhs = DataArray([2, 3, 4], [('x', [1, 2, 3])]) - lhs.coords['rhs'] = rhs - - expected = DataArray([1, 2, 3], - coords={'rhs': ('x', [np.nan, 2, 3]), - 'x': [0, 1, 2]}, - dims='x') + lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])]) + rhs = DataArray([2, 3, 4], [("x", [1, 2, 3])]) 
+ lhs.coords["rhs"] = rhs + + expected = DataArray( + [1, 2, 3], coords={"rhs": ("x", [np.nan, 2, 3]), "x": [0, 1, 2]}, dims="x" + ) assert_identical(lhs, expected) def test_set_coords_update_index(self): - actual = DataArray([1, 2, 3], [('x', [1, 2, 3])]) - actual.coords['x'] = ['a', 'b', 'c'] - assert actual.indexes['x'].equals(pd.Index(['a', 'b', 'c'])) + actual = DataArray([1, 2, 3], [("x", [1, 2, 3])]) + actual.coords["x"] = ["a", "b", "c"] + assert actual.indexes["x"].equals(pd.Index(["a", "b", "c"])) def test_coords_replacement_alignment(self): # regression test for GH725 - arr = DataArray([0, 1, 2], dims=['abc']) - new_coord = DataArray([1, 2, 3], dims=['abc'], coords=[[1, 2, 3]]) - arr['abc'] = new_coord - expected = DataArray([0, 1, 2], coords=[('abc', [1, 2, 3])]) + arr = DataArray([0, 1, 2], dims=["abc"]) + new_coord = DataArray([1, 2, 3], dims=["abc"], coords=[[1, 2, 3]]) + arr["abc"] = new_coord + expected = DataArray([0, 1, 2], coords=[("abc", [1, 2, 3])]) assert_identical(arr, expected) def test_coords_non_string(self): @@ -1256,10 +1381,16 @@ def test_coords_non_string(self): assert_identical(actual, expected) def test_broadcast_like(self): - arr1 = DataArray(np.ones((2, 3)), dims=['x', 'y'], - coords={'x': ['a', 'b'], 'y': ['a', 'b', 'c']}) - arr2 = DataArray(np.ones((3, 2)), dims=['x', 'y'], - coords={'x': ['a', 'b', 'c'], 'y': ['a', 'b']}) + arr1 = DataArray( + np.ones((2, 3)), + dims=["x", "y"], + coords={"x": ["a", "b"], "y": ["a", "b", "c"]}, + ) + arr2 = DataArray( + np.ones((3, 2)), + dims=["x", "y"], + coords={"x": ["a", "b", "c"], "y": ["a", "b"]}, + ) orig1, orig2 = broadcast(arr1, arr2) new1 = arr1.broadcast_like(arr2) new2 = arr2.broadcast_like(arr1) @@ -1267,16 +1398,15 @@ def test_broadcast_like(self): assert orig1.identical(new1) assert orig2.identical(new2) - orig3 = DataArray(np.random.randn(5), [('x', range(5))]) - orig4 = DataArray(np.random.randn(6), [('y', range(6))]) + orig3 = DataArray(np.random.randn(5), [("x", range(5))]) + orig4 = DataArray(np.random.randn(6), [("y", range(6))]) new3, new4 = broadcast(orig3, orig4) - assert_identical(orig3.broadcast_like(orig4), new3.transpose('y', 'x')) + assert_identical(orig3.broadcast_like(orig4), new3.transpose("y", "x")) assert_identical(orig4.broadcast_like(orig3), new4) def test_reindex_like(self): - foo = DataArray(np.random.randn(5, 6), - [('x', range(5)), ('y', range(6))]) + foo = DataArray(np.random.randn(5, 6), [("x", range(5)), ("y", range(6))]) bar = foo[:2, :2] assert_identical(foo.reindex_like(bar), bar) @@ -1286,15 +1416,14 @@ def test_reindex_like(self): assert_identical(bar.reindex_like(foo), expected) def test_reindex_like_no_index(self): - foo = DataArray(np.random.randn(5, 6), dims=['x', 'y']) + foo = DataArray(np.random.randn(5, 6), dims=["x", "y"]) assert_identical(foo, foo.reindex_like(foo)) bar = foo[:4] - with raises_regex( - ValueError, 'different size for unlabeled'): + with raises_regex(ValueError, "different size for unlabeled"): foo.reindex_like(bar) - @pytest.mark.filterwarnings('ignore:Indexer has dimensions') + @pytest.mark.filterwarnings("ignore:Indexer has dimensions") def test_reindex_regressions(self): # regression test for #279 expected = DataArray(np.random.randn(5), coords=[("time", range(5))]) @@ -1306,87 +1435,91 @@ def test_reindex_regressions(self): x = np.array([1, 2, 3], dtype=np.complex) x = DataArray(x, coords=[[0.1, 0.2, 0.3]]) y = DataArray([2, 5, 6, 7, 8], coords=[[-1.1, 0.21, 0.31, 0.41, 0.51]]) - re_dtype = x.reindex_like(y, method='pad').dtype + re_dtype = 
x.reindex_like(y, method="pad").dtype assert x.dtype == re_dtype def test_reindex_method(self): - x = DataArray([10, 20], dims='y', coords={'y': [0, 1]}) + x = DataArray([10, 20], dims="y", coords={"y": [0, 1]}) y = [-0.1, 0.5, 1.1] - actual = x.reindex(y=y, method='backfill', tolerance=0.2) - expected = DataArray([10, np.nan, np.nan], coords=[('y', y)]) + actual = x.reindex(y=y, method="backfill", tolerance=0.2) + expected = DataArray([10, np.nan, np.nan], coords=[("y", y)]) assert_identical(expected, actual) - alt = Dataset({'y': y}) - actual = x.reindex_like(alt, method='backfill') - expected = DataArray([10, 20, np.nan], coords=[('y', y)]) + alt = Dataset({"y": y}) + actual = x.reindex_like(alt, method="backfill") + expected = DataArray([10, 20, np.nan], coords=[("y", y)]) assert_identical(expected, actual) - @pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0]) + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) def test_reindex_fill_value(self, fill_value): - x = DataArray([10, 20], dims='y', coords={'y': [0, 1]}) + x = DataArray([10, 20], dims="y", coords={"y": [0, 1]}) y = [0, 1, 2] if fill_value == dtypes.NA: # if we supply the default, we expect the missing value for a # float array fill_value = np.nan actual = x.reindex(y=y, fill_value=fill_value) - expected = DataArray([10, 20, fill_value], coords=[('y', y)]) + expected = DataArray([10, 20, fill_value], coords=[("y", y)]) assert_identical(expected, actual) def test_rename(self): - renamed = self.dv.rename('bar') - assert_identical( - renamed.to_dataset(), self.ds.rename({'foo': 'bar'})) - assert renamed.name == 'bar' + renamed = self.dv.rename("bar") + assert_identical(renamed.to_dataset(), self.ds.rename({"foo": "bar"})) + assert renamed.name == "bar" - renamed = self.dv.x.rename({'x': 'z'}).rename('z') - assert_identical( - renamed, self.ds.rename({'x': 'z'}).z) - assert renamed.name == 'z' - assert renamed.dims == ('z',) + renamed = self.dv.x.rename({"x": "z"}).rename("z") + assert_identical(renamed, self.ds.rename({"x": "z"}).z) + assert renamed.name == "z" + assert renamed.dims == ("z",) - renamed_kwargs = self.dv.x.rename(x='z').rename('z') + renamed_kwargs = self.dv.x.rename(x="z").rename("z") assert_identical(renamed, renamed_kwargs) def test_swap_dims(self): - array = DataArray(np.random.randn(3), {'y': ('x', list('abc'))}, 'x') - expected = DataArray(array.values, {'y': list('abc')}, dims='y') - actual = array.swap_dims({'x': 'y'}) + array = DataArray(np.random.randn(3), {"y": ("x", list("abc"))}, "x") + expected = DataArray(array.values, {"y": list("abc")}, dims="y") + actual = array.swap_dims({"x": "y"}) assert_identical(expected, actual) def test_expand_dims_error(self): - array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'], - coords={'x': np.linspace(0.0, 1.0, 3)}, - attrs={'key': 'entry'}) + array = DataArray( + np.random.randn(3, 4), + dims=["x", "dim_0"], + coords={"x": np.linspace(0.0, 1.0, 3)}, + attrs={"key": "entry"}, + ) - with raises_regex(TypeError, 'dim should be hashable or'): + with raises_regex(TypeError, "dim should be hashable or"): array.expand_dims(0) - with raises_regex(ValueError, 'lengths of dim and axis'): + with raises_regex(ValueError, "lengths of dim and axis"): # dims and axis argument should be the same length - array.expand_dims(dim=['a', 'b'], axis=[1, 2, 3]) - with raises_regex(ValueError, 'Dimension x already'): + array.expand_dims(dim=["a", "b"], axis=[1, 2, 3]) + with raises_regex(ValueError, "Dimension x already"): # Should not pass the already existing 
dimension. - array.expand_dims(dim=['x']) + array.expand_dims(dim=["x"]) # raise if duplicate - with raises_regex(ValueError, 'duplicate values.'): - array.expand_dims(dim=['y', 'y']) - with raises_regex(ValueError, 'duplicate values.'): - array.expand_dims(dim=['y', 'z'], axis=[1, 1]) - with raises_regex(ValueError, 'duplicate values.'): - array.expand_dims(dim=['y', 'z'], axis=[2, -2]) + with raises_regex(ValueError, "duplicate values."): + array.expand_dims(dim=["y", "y"]) + with raises_regex(ValueError, "duplicate values."): + array.expand_dims(dim=["y", "z"], axis=[1, 1]) + with raises_regex(ValueError, "duplicate values."): + array.expand_dims(dim=["y", "z"], axis=[2, -2]) # out of bounds error, axis must be in [-4, 3] with pytest.raises(IndexError): - array.expand_dims(dim=['y', 'z'], axis=[2, 4]) + array.expand_dims(dim=["y", "z"], axis=[2, 4]) with pytest.raises(IndexError): - array.expand_dims(dim=['y', 'z'], axis=[2, -5]) + array.expand_dims(dim=["y", "z"], axis=[2, -5]) # Does not raise an IndexError - array.expand_dims(dim=['y', 'z'], axis=[2, -4]) - array.expand_dims(dim=['y', 'z'], axis=[2, 3]) + array.expand_dims(dim=["y", "z"], axis=[2, -4]) + array.expand_dims(dim=["y", "z"], axis=[2, 3]) - array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'], - coords={'x': np.linspace(0.0, 1.0, 3)}, - attrs={'key': 'entry'}) + array = DataArray( + np.random.randn(3, 4), + dims=["x", "dim_0"], + coords={"x": np.linspace(0.0, 1.0, 3)}, + attrs={"key": "entry"}, + ) with pytest.raises(TypeError): array.expand_dims(OrderedDict((("new_dim", 3.2),))) @@ -1395,86 +1528,109 @@ def test_expand_dims_error(self): array.expand_dims(OrderedDict((("d", 4),)), e=4) def test_expand_dims(self): - array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'], - coords={'x': np.linspace(0.0, 1.0, 3)}, - attrs={'key': 'entry'}) + array = DataArray( + np.random.randn(3, 4), + dims=["x", "dim_0"], + coords={"x": np.linspace(0.0, 1.0, 3)}, + attrs={"key": "entry"}, + ) # pass only dim label - actual = array.expand_dims(dim='y') - expected = DataArray(np.expand_dims(array.values, 0), - dims=['y', 'x', 'dim_0'], - coords={'x': np.linspace(0.0, 1.0, 3)}, - attrs={'key': 'entry'}) + actual = array.expand_dims(dim="y") + expected = DataArray( + np.expand_dims(array.values, 0), + dims=["y", "x", "dim_0"], + coords={"x": np.linspace(0.0, 1.0, 3)}, + attrs={"key": "entry"}, + ) assert_identical(expected, actual) - roundtripped = actual.squeeze('y', drop=True) + roundtripped = actual.squeeze("y", drop=True) assert_identical(array, roundtripped) # pass multiple dims - actual = array.expand_dims(dim=['y', 'z']) - expected = DataArray(np.expand_dims(np.expand_dims(array.values, 0), - 0), - dims=['y', 'z', 'x', 'dim_0'], - coords={'x': np.linspace(0.0, 1.0, 3)}, - attrs={'key': 'entry'}) + actual = array.expand_dims(dim=["y", "z"]) + expected = DataArray( + np.expand_dims(np.expand_dims(array.values, 0), 0), + dims=["y", "z", "x", "dim_0"], + coords={"x": np.linspace(0.0, 1.0, 3)}, + attrs={"key": "entry"}, + ) assert_identical(expected, actual) - roundtripped = actual.squeeze(['y', 'z'], drop=True) + roundtripped = actual.squeeze(["y", "z"], drop=True) assert_identical(array, roundtripped) # pass multiple dims and axis. 
Axis is out of order - actual = array.expand_dims(dim=['z', 'y'], axis=[2, 1]) - expected = DataArray(np.expand_dims(np.expand_dims(array.values, 1), - 2), - dims=['x', 'y', 'z', 'dim_0'], - coords={'x': np.linspace(0.0, 1.0, 3)}, - attrs={'key': 'entry'}) + actual = array.expand_dims(dim=["z", "y"], axis=[2, 1]) + expected = DataArray( + np.expand_dims(np.expand_dims(array.values, 1), 2), + dims=["x", "y", "z", "dim_0"], + coords={"x": np.linspace(0.0, 1.0, 3)}, + attrs={"key": "entry"}, + ) assert_identical(expected, actual) # make sure the attrs are tracked - assert actual.attrs['key'] == 'entry' - roundtripped = actual.squeeze(['z', 'y'], drop=True) + assert actual.attrs["key"] == "entry" + roundtripped = actual.squeeze(["z", "y"], drop=True) assert_identical(array, roundtripped) # Negative axis and they are out of order - actual = array.expand_dims(dim=['y', 'z'], axis=[-1, -2]) - expected = DataArray(np.expand_dims(np.expand_dims(array.values, -1), - -1), - dims=['x', 'dim_0', 'z', 'y'], - coords={'x': np.linspace(0.0, 1.0, 3)}, - attrs={'key': 'entry'}) + actual = array.expand_dims(dim=["y", "z"], axis=[-1, -2]) + expected = DataArray( + np.expand_dims(np.expand_dims(array.values, -1), -1), + dims=["x", "dim_0", "z", "y"], + coords={"x": np.linspace(0.0, 1.0, 3)}, + attrs={"key": "entry"}, + ) assert_identical(expected, actual) - assert actual.attrs['key'] == 'entry' - roundtripped = actual.squeeze(['y', 'z'], drop=True) + assert actual.attrs["key"] == "entry" + roundtripped = actual.squeeze(["y", "z"], drop=True) assert_identical(array, roundtripped) def test_expand_dims_with_scalar_coordinate(self): - array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'], - coords={'x': np.linspace(0.0, 1.0, 3), 'z': 1.0}, - attrs={'key': 'entry'}) - actual = array.expand_dims(dim='z') - expected = DataArray(np.expand_dims(array.values, 0), - dims=['z', 'x', 'dim_0'], - coords={'x': np.linspace(0.0, 1.0, 3), - 'z': np.ones(1)}, - attrs={'key': 'entry'}) + array = DataArray( + np.random.randn(3, 4), + dims=["x", "dim_0"], + coords={"x": np.linspace(0.0, 1.0, 3), "z": 1.0}, + attrs={"key": "entry"}, + ) + actual = array.expand_dims(dim="z") + expected = DataArray( + np.expand_dims(array.values, 0), + dims=["z", "x", "dim_0"], + coords={"x": np.linspace(0.0, 1.0, 3), "z": np.ones(1)}, + attrs={"key": "entry"}, + ) assert_identical(expected, actual) - roundtripped = actual.squeeze(['z'], drop=False) + roundtripped = actual.squeeze(["z"], drop=False) assert_identical(array, roundtripped) def test_expand_dims_with_greater_dim_size(self): - array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'], - coords={'x': np.linspace(0.0, 1.0, 3), 'z': 1.0}, - attrs={'key': 'entry'}) + array = DataArray( + np.random.randn(3, 4), + dims=["x", "dim_0"], + coords={"x": np.linspace(0.0, 1.0, 3), "z": 1.0}, + attrs={"key": "entry"}, + ) # For python 3.5 and earlier this has to be an ordered dict, to # maintain insertion order. 
actual = array.expand_dims( - OrderedDict((('y', 2), ('z', 1), ('dim_1', ['a', 'b', 'c'])))) - - expected_coords = OrderedDict(( - ('y', [0, 1]), ('z', [1.0]), ('dim_1', ['a', 'b', 'c']), - ('x', np.linspace(0, 1, 3)), ('dim_0', range(4)))) - expected = DataArray(array.values * np.ones([2, 1, 3, 3, 4]), - coords=expected_coords, - dims=list(expected_coords.keys()), - attrs={'key': 'entry'} - ).drop(['y', 'dim_0']) + OrderedDict((("y", 2), ("z", 1), ("dim_1", ["a", "b", "c"]))) + ) + + expected_coords = OrderedDict( + ( + ("y", [0, 1]), + ("z", [1.0]), + ("dim_1", ["a", "b", "c"]), + ("x", np.linspace(0, 1, 3)), + ("dim_0", range(4)), + ) + ) + expected = DataArray( + array.values * np.ones([2, 1, 3, 3, 4]), + coords=expected_coords, + dims=list(expected_coords.keys()), + attrs={"key": "entry"}, + ).drop(["y", "dim_0"]) assert_identical(expected, actual) # Test with kwargs instead of passing dict to dim arg. @@ -1483,15 +1639,19 @@ def test_expand_dims_with_greater_dim_size(self): # is no longer supported. python36_plus = sys.version_info[0] == 3 and sys.version_info[1] > 5 if python36_plus: - other_way = array.expand_dims(dim_1=['a', 'b', 'c']) + other_way = array.expand_dims(dim_1=["a", "b", "c"]) other_way_expected = DataArray( array.values * np.ones([3, 3, 4]), - coords={'dim_1': ['a', 'b', 'c'], - 'x': np.linspace(0, 1, 3), - 'dim_0': range(4), 'z': 1.0}, - dims=['dim_1', 'x', 'dim_0'], - attrs={'key': 'entry'}).drop('dim_0') + coords={ + "dim_1": ["a", "b", "c"], + "x": np.linspace(0, 1, 3), + "dim_0": range(4), + "z": 1.0, + }, + dims=["dim_1", "x", "dim_0"], + attrs={"key": "entry"}, + ).drop("dim_0") assert_identical(other_way_expected, other_way) else: # In python 3.5, using dim_kwargs should raise a ValueError. @@ -1500,84 +1660,86 @@ def test_expand_dims_with_greater_dim_size(self): def test_set_index(self): indexes = [self.mindex.get_level_values(n) for n in self.mindex.names] - coords = {idx.name: ('x', idx) for idx in indexes} - array = DataArray(self.mda.values, coords=coords, dims='x') + coords = {idx.name: ("x", idx) for idx in indexes} + array = DataArray(self.mda.values, coords=coords, dims="x") expected = self.mda.copy() - level_3 = ('x', [1, 2, 3, 4]) - array['level_3'] = level_3 - expected['level_3'] = level_3 + level_3 = ("x", [1, 2, 3, 4]) + array["level_3"] = level_3 + expected["level_3"] = level_3 obj = array.set_index(x=self.mindex.names) assert_identical(obj, expected) - obj = obj.set_index(x='level_3', append=True) - expected = array.set_index(x=['level_1', 'level_2', 'level_3']) + obj = obj.set_index(x="level_3", append=True) + expected = array.set_index(x=["level_1", "level_2", "level_3"]) assert_identical(obj, expected) - array = array.set_index(x=['level_1', 'level_2', 'level_3']) + array = array.set_index(x=["level_1", "level_2", "level_3"]) assert_identical(array, expected) - array2d = DataArray(np.random.rand(2, 2), - coords={'x': ('x', [0, 1]), - 'level': ('y', [1, 2])}, - dims=('x', 'y')) - with raises_regex(ValueError, 'dimension mismatch'): - array2d.set_index(x='level') + array2d = DataArray( + np.random.rand(2, 2), + coords={"x": ("x", [0, 1]), "level": ("y", [1, 2])}, + dims=("x", "y"), + ) + with raises_regex(ValueError, "dimension mismatch"): + array2d.set_index(x="level") def test_reset_index(self): indexes = [self.mindex.get_level_values(n) for n in self.mindex.names] - coords = {idx.name: ('x', idx) for idx in indexes} - expected = DataArray(self.mda.values, coords=coords, dims='x') + coords = {idx.name: ("x", idx) for idx in indexes} + 
expected = DataArray(self.mda.values, coords=coords, dims="x") - obj = self.mda.reset_index('x') + obj = self.mda.reset_index("x") assert_identical(obj, expected) obj = self.mda.reset_index(self.mindex.names) assert_identical(obj, expected) - obj = self.mda.reset_index(['x', 'level_1']) + obj = self.mda.reset_index(["x", "level_1"]) assert_identical(obj, expected) - coords = {'x': ('x', self.mindex.droplevel('level_1')), - 'level_1': ('x', self.mindex.get_level_values('level_1'))} - expected = DataArray(self.mda.values, coords=coords, dims='x') - obj = self.mda.reset_index(['level_1']) + coords = { + "x": ("x", self.mindex.droplevel("level_1")), + "level_1": ("x", self.mindex.get_level_values("level_1")), + } + expected = DataArray(self.mda.values, coords=coords, dims="x") + obj = self.mda.reset_index(["level_1"]) assert_identical(obj, expected) - expected = DataArray(self.mda.values, dims='x') - obj = self.mda.reset_index('x', drop=True) + expected = DataArray(self.mda.values, dims="x") + obj = self.mda.reset_index("x", drop=True) assert_identical(obj, expected) array = self.mda.copy() - array = array.reset_index(['x'], drop=True) + array = array.reset_index(["x"], drop=True) assert_identical(array, expected) # single index - array = DataArray([1, 2], coords={'x': ['a', 'b']}, dims='x') - expected = DataArray([1, 2], coords={'x_': ('x', ['a', 'b'])}, - dims='x') - assert_identical(array.reset_index('x'), expected) + array = DataArray([1, 2], coords={"x": ["a", "b"]}, dims="x") + expected = DataArray([1, 2], coords={"x_": ("x", ["a", "b"])}, dims="x") + assert_identical(array.reset_index("x"), expected) def test_reorder_levels(self): - midx = self.mindex.reorder_levels(['level_2', 'level_1']) - expected = DataArray(self.mda.values, coords={'x': midx}, dims='x') + midx = self.mindex.reorder_levels(["level_2", "level_1"]) + expected = DataArray(self.mda.values, coords={"x": midx}, dims="x") - obj = self.mda.reorder_levels(x=['level_2', 'level_1']) + obj = self.mda.reorder_levels(x=["level_2", "level_1"]) assert_identical(obj, expected) - with pytest.warns(FutureWarning, match='The inplace argument'): + with pytest.warns(FutureWarning, match="The inplace argument"): array = self.mda.copy() - array.reorder_levels(x=['level_2', 'level_1'], inplace=True) + array.reorder_levels(x=["level_2", "level_1"], inplace=True) assert_identical(array, expected) - array = DataArray([1, 2], dims='x') + array = DataArray([1, 2], dims="x") with pytest.raises(KeyError): - array.reorder_levels(x=['level_1', 'level_2']) + array.reorder_levels(x=["level_1", "level_2"]) - array['x'] = [0, 1] - with raises_regex(ValueError, 'has no MultiIndex'): - array.reorder_levels(x=['level_1', 'level_2']) + array["x"] = [0, 1] + with raises_regex(ValueError, "has no MultiIndex"): + array.reorder_levels(x=["level_1", "level_2"]) def test_dataset_getitem(self): - dv = self.ds['foo'] + dv = self.ds["foo"] assert_identical(dv, self.dv) def test_array_interface(self): @@ -1588,18 +1750,17 @@ def test_array_interface(self): assert_array_equal(self.dv.clip(2, 3), self.v.clip(2, 3)) # test ufuncs expected = deepcopy(self.ds) - expected['foo'][:] = np.sin(self.x) - assert_equal(expected['foo'], np.sin(self.dv)) + expected["foo"][:] = np.sin(self.x) + assert_equal(expected["foo"], np.sin(self.dv)) assert_array_equal(self.dv, np.maximum(self.v, self.dv)) - bar = Variable(['x', 'y'], np.zeros((10, 20))) + bar = Variable(["x", "y"], np.zeros((10, 20))) assert_equal(self.dv, np.maximum(self.dv, bar)) def test_is_null(self): x = 
np.random.RandomState(42).randn(5, 6) x[x < 0] = np.nan - original = DataArray(x, [-np.arange(5), np.arange(6)], ['x', 'y']) - expected = DataArray(pd.isnull(x), [-np.arange(5), np.arange(6)], - ['x', 'y']) + original = DataArray(x, [-np.arange(5), np.arange(6)], ["x", "y"]) + expected = DataArray(pd.isnull(x), [-np.arange(5), np.arange(6)], ["x", "y"]) assert_identical(expected, original.isnull()) assert_identical(~expected, original.notnull()) @@ -1620,22 +1781,22 @@ def test_math(self): assert_equal(a, 0 * a + a) def test_math_automatic_alignment(self): - a = DataArray(range(5), [('x', range(5))]) - b = DataArray(range(5), [('x', range(1, 6))]) - expected = DataArray(np.ones(4), [('x', [1, 2, 3, 4])]) + a = DataArray(range(5), [("x", range(5))]) + b = DataArray(range(5), [("x", range(1, 6))]) + expected = DataArray(np.ones(4), [("x", [1, 2, 3, 4])]) assert_identical(a - b, expected) def test_non_overlapping_dataarrays_return_empty_result(self): - a = DataArray(range(5), [('x', range(5))]) + a = DataArray(range(5), [("x", range(5))]) result = a.isel(x=slice(2)) + a.isel(x=slice(2, None)) - assert len(result['x']) == 0 + assert len(result["x"]) == 0 def test_empty_dataarrays_return_empty_result(self): a = DataArray(data=[]) result = a * a - assert len(result['dim_0']) == 0 + assert len(result["dim_0"]) == 0 def test_inplace_math_basics(self): x = self.x @@ -1649,8 +1810,8 @@ def test_inplace_math_basics(self): assert source_ndarray(b.values) is x def test_inplace_math_automatic_alignment(self): - a = DataArray(range(5), [('x', range(5))]) - b = DataArray(range(1, 6), [('x', range(1, 6))]) + a = DataArray(range(5), [("x", range(5))]) + b = DataArray(range(1, 6), [("x", range(1, 6))]) with pytest.raises(xr.MergeError): a += b with pytest.raises(xr.MergeError): @@ -1662,20 +1823,23 @@ def test_math_name(self): # the other object has the same name or no name attribute and this # object isn't a coordinate; otherwise reset to None. 
a = self.dv - assert (+a).name == 'foo' - assert (a + 0).name == 'foo' + assert (+a).name == "foo" + assert (a + 0).name == "foo" assert (a + a.rename(None)).name is None - assert (a + a.rename('bar')).name is None - assert (a + a).name == 'foo' - assert (+a['x']).name == 'x' - assert (a['x'] + 0).name == 'x' - assert (a + a['x']).name is None + assert (a + a.rename("bar")).name is None + assert (a + a).name == "foo" + assert (+a["x"]).name == "x" + assert (a["x"] + 0).name == "x" + assert (a + a["x"]).name is None def test_math_with_coords(self): - coords = {'x': [-1, -2], 'y': ['ab', 'cd', 'ef'], - 'lat': (['x', 'y'], [[1, 2, 3], [-1, -2, -3]]), - 'c': -999} - orig = DataArray(np.random.randn(2, 3), coords, dims=['x', 'y']) + coords = { + "x": [-1, -2], + "y": ["ab", "cd", "ef"], + "lat": (["x", "y"], [[1, 2, 3], [-1, -2, -3]]), + "c": -999, + } + orig = DataArray(np.random.randn(2, 3), coords, dims=["x", "y"]) actual = orig + 1 expected = DataArray(orig.values + 1, orig.coords) @@ -1685,22 +1849,22 @@ def test_math_with_coords(self): assert_identical(expected, actual) actual = orig + orig[0, 0] - exp_coords = {k: v for k, v in coords.items() if k != 'lat'} - expected = DataArray(orig.values + orig.values[0, 0], - exp_coords, dims=['x', 'y']) + exp_coords = {k: v for k, v in coords.items() if k != "lat"} + expected = DataArray( + orig.values + orig.values[0, 0], exp_coords, dims=["x", "y"] + ) assert_identical(expected, actual) actual = orig[0, 0] + orig assert_identical(expected, actual) actual = orig[0, 0] + orig[-1, -1] - expected = DataArray(orig.values[0, 0] + orig.values[-1, -1], - {'c': -999}) + expected = DataArray(orig.values[0, 0] + orig.values[-1, -1], {"c": -999}) assert_identical(expected, actual) actual = orig[:, 0] + orig[0, :] exp_values = orig[:, 0].values[:, None] + orig[0, :].values[None, :] - expected = DataArray(exp_values, exp_coords, dims=['x', 'y']) + expected = DataArray(exp_values, exp_coords, dims=["x", "y"]) assert_identical(expected, actual) actual = orig[0, :] + orig[:, 0] @@ -1713,25 +1877,25 @@ def test_math_with_coords(self): actual = orig.transpose(transpose_coords=True) - orig assert_identical(expected.transpose(transpose_coords=True), actual) - alt = DataArray([1, 1], {'x': [-1, -2], 'c': 'foo', 'd': 555}, 'x') + alt = DataArray([1, 1], {"x": [-1, -2], "c": "foo", "d": 555}, "x") actual = orig + alt expected = orig + 1 - expected.coords['d'] = 555 - del expected.coords['c'] + expected.coords["d"] = 555 + del expected.coords["c"] assert_identical(expected, actual) actual = alt + orig assert_identical(expected, actual) def test_index_math(self): - orig = DataArray(range(3), dims='x', name='x') + orig = DataArray(range(3), dims="x", name="x") actual = orig + 1 - expected = DataArray(1 + np.arange(3), dims='x', name='x') + expected = DataArray(1 + np.arange(3), dims="x", name="x") assert_identical(expected, actual) # regression tests for #254 actual = orig[0] < orig - expected = DataArray([False, True, True], dims='x', name='x') + expected = DataArray([False, True, True], dims="x", name="x") assert_identical(expected, actual) actual = orig > orig[0] @@ -1739,125 +1903,147 @@ def test_index_math(self): def test_dataset_math(self): # more comprehensive tests with multiple dataset variables - obs = Dataset({'tmin': ('x', np.arange(5)), - 'tmax': ('x', 10 + np.arange(5))}, - {'x': ('x', 0.5 * np.arange(5)), - 'loc': ('x', range(-2, 3))}) + obs = Dataset( + {"tmin": ("x", np.arange(5)), "tmax": ("x", 10 + np.arange(5))}, + {"x": ("x", 0.5 * np.arange(5)), 
"loc": ("x", range(-2, 3))}, + ) - actual = 2 * obs['tmax'] - expected = DataArray(2 * (10 + np.arange(5)), obs.coords, name='tmax') + actual = 2 * obs["tmax"] + expected = DataArray(2 * (10 + np.arange(5)), obs.coords, name="tmax") assert_identical(actual, expected) - actual = obs['tmax'] - obs['tmin'] + actual = obs["tmax"] - obs["tmin"] expected = DataArray(10 * np.ones(5), obs.coords) assert_identical(actual, expected) - sim = Dataset({'tmin': ('x', 1 + np.arange(5)), - 'tmax': ('x', 11 + np.arange(5)), - # does *not* include 'loc' as a coordinate - 'x': ('x', 0.5 * np.arange(5))}) + sim = Dataset( + { + "tmin": ("x", 1 + np.arange(5)), + "tmax": ("x", 11 + np.arange(5)), + # does *not* include 'loc' as a coordinate + "x": ("x", 0.5 * np.arange(5)), + } + ) - actual = sim['tmin'] - obs['tmin'] - expected = DataArray(np.ones(5), obs.coords, name='tmin') + actual = sim["tmin"] - obs["tmin"] + expected = DataArray(np.ones(5), obs.coords, name="tmin") assert_identical(actual, expected) - actual = -obs['tmin'] + sim['tmin'] + actual = -obs["tmin"] + sim["tmin"] assert_identical(actual, expected) - actual = sim['tmin'].copy() - actual -= obs['tmin'] + actual = sim["tmin"].copy() + actual -= obs["tmin"] assert_identical(actual, expected) actual = sim.copy() - actual['tmin'] = sim['tmin'] - obs['tmin'] - expected = Dataset({'tmin': ('x', np.ones(5)), - 'tmax': ('x', sim['tmax'].values)}, - obs.coords) + actual["tmin"] = sim["tmin"] - obs["tmin"] + expected = Dataset( + {"tmin": ("x", np.ones(5)), "tmax": ("x", sim["tmax"].values)}, obs.coords + ) assert_identical(actual, expected) actual = sim.copy() - actual['tmin'] -= obs['tmin'] + actual["tmin"] -= obs["tmin"] assert_identical(actual, expected) def test_stack_unstack(self): - orig = DataArray([[0, 1], [2, 3]], dims=['x', 'y'], attrs={'foo': 2}) + orig = DataArray([[0, 1], [2, 3]], dims=["x", "y"], attrs={"foo": 2}) assert_identical(orig, orig.unstack()) # test GH3000 - a = orig[:0, :1].stack(dim=('x', 'y')).dim.to_index() - if pd.__version__ < '0.24.0': - b = pd.MultiIndex(levels=[pd.Int64Index([]), pd.Int64Index([0])], - labels=[[], []], names=['x', 'y']) + a = orig[:0, :1].stack(dim=("x", "y")).dim.to_index() + if pd.__version__ < "0.24.0": + b = pd.MultiIndex( + levels=[pd.Int64Index([]), pd.Int64Index([0])], + labels=[[], []], + names=["x", "y"], + ) else: - b = pd.MultiIndex(levels=[pd.Int64Index([]), pd.Int64Index([0])], - codes=[[], []], names=['x', 'y']) + b = pd.MultiIndex( + levels=[pd.Int64Index([]), pd.Int64Index([0])], + codes=[[], []], + names=["x", "y"], + ) pd.util.testing.assert_index_equal(a, b) - actual = orig.stack(z=['x', 'y']).unstack('z').drop(['x', 'y']) + actual = orig.stack(z=["x", "y"]).unstack("z").drop(["x", "y"]) assert_identical(orig, actual) - dims = ['a', 'b', 'c', 'd', 'e'] + dims = ["a", "b", "c", "d", "e"] orig = xr.DataArray(np.random.rand(1, 2, 3, 2, 1), dims=dims) - stacked = orig.stack(ab=['a', 'b'], cd=['c', 'd']) + stacked = orig.stack(ab=["a", "b"], cd=["c", "d"]) - unstacked = stacked.unstack(['ab', 'cd']) - roundtripped = unstacked.drop(['a', 'b', 'c', 'd']).transpose(*dims) + unstacked = stacked.unstack(["ab", "cd"]) + roundtripped = unstacked.drop(["a", "b", "c", "d"]).transpose(*dims) assert_identical(orig, roundtripped) unstacked = stacked.unstack() - roundtripped = unstacked.drop(['a', 'b', 'c', 'd']).transpose(*dims) + roundtripped = unstacked.drop(["a", "b", "c", "d"]).transpose(*dims) assert_identical(orig, roundtripped) def test_stack_unstack_decreasing_coordinate(self): # regression 
test for GH980 - orig = DataArray(np.random.rand(3, 4), dims=('y', 'x'), - coords={'x': np.arange(4), - 'y': np.arange(3, 0, -1)}) - stacked = orig.stack(allpoints=['y', 'x']) - actual = stacked.unstack('allpoints') + orig = DataArray( + np.random.rand(3, 4), + dims=("y", "x"), + coords={"x": np.arange(4), "y": np.arange(3, 0, -1)}, + ) + stacked = orig.stack(allpoints=["y", "x"]) + actual = stacked.unstack("allpoints") assert_identical(orig, actual) def test_unstack_pandas_consistency(self): - df = pd.DataFrame({'foo': range(3), - 'x': ['a', 'b', 'b'], - 'y': [0, 0, 1]}) - s = df.set_index(['x', 'y'])['foo'] - expected = DataArray(s.unstack(), name='foo') - actual = DataArray(s, dims='z').unstack('z') + df = pd.DataFrame({"foo": range(3), "x": ["a", "b", "b"], "y": [0, 0, 1]}) + s = df.set_index(["x", "y"])["foo"] + expected = DataArray(s.unstack(), name="foo") + actual = DataArray(s, dims="z").unstack("z") assert_identical(expected, actual) def test_stack_nonunique_consistency(self): - orig = DataArray([[0, 1], [2, 3]], dims=['x', 'y'], - coords={'x': [0, 1], 'y': [0, 0]}) - actual = orig.stack(z=['x', 'y']) - expected = DataArray(orig.to_pandas().stack(), dims='z') + orig = DataArray( + [[0, 1], [2, 3]], dims=["x", "y"], coords={"x": [0, 1], "y": [0, 0]} + ) + actual = orig.stack(z=["x", "y"]) + expected = DataArray(orig.to_pandas().stack(), dims="z") assert_identical(expected, actual) def test_to_unstacked_dataset_raises_value_error(self): - data = DataArray([0, 1], dims='x', coords={'x': [0, 1]}) - with pytest.raises( - ValueError, match="'x' is not a stacked coordinate"): - data.to_unstacked_dataset('x', 0) + data = DataArray([0, 1], dims="x", coords={"x": [0, 1]}) + with pytest.raises(ValueError, match="'x' is not a stacked coordinate"): + data.to_unstacked_dataset("x", 0) def test_transpose(self): - da = DataArray(np.random.randn(3, 4, 5), dims=('x', 'y', 'z'), - coords={'x': range(3), 'y': range(4), 'z': range(5), - 'xy': (('x', 'y'), np.random.randn(3, 4))}) + da = DataArray( + np.random.randn(3, 4, 5), + dims=("x", "y", "z"), + coords={ + "x": range(3), + "y": range(4), + "z": range(5), + "xy": (("x", "y"), np.random.randn(3, 4)), + }, + ) actual = da.transpose(transpose_coords=False) - expected = DataArray(da.values.T, dims=('z', 'y', 'x'), - coords=da.coords) + expected = DataArray(da.values.T, dims=("z", "y", "x"), coords=da.coords) assert_equal(expected, actual) - actual = da.transpose('z', 'y', 'x', transpose_coords=True) - expected = DataArray(da.values.T, dims=('z', 'y', 'x'), - coords={'x': da.x.values, 'y': da.y.values, - 'z': da.z.values, - 'xy': (('y', 'x'), da.xy.values.T)}) + actual = da.transpose("z", "y", "x", transpose_coords=True) + expected = DataArray( + da.values.T, + dims=("z", "y", "x"), + coords={ + "x": da.x.values, + "y": da.y.values, + "z": da.z.values, + "xy": (("y", "x"), da.xy.values.T), + }, + ) assert_equal(expected, actual) with pytest.raises(ValueError): - da.transpose('x', 'y') + da.transpose("x", "y") with pytest.warns(FutureWarning): da.transpose() @@ -1866,204 +2052,209 @@ def test_squeeze(self): assert_equal(self.dv.variable.squeeze(), self.dv.squeeze().variable) def test_squeeze_drop(self): - array = DataArray([1], [('x', [0])]) + array = DataArray([1], [("x", [0])]) expected = DataArray(1) actual = array.squeeze(drop=True) assert_identical(expected, actual) - expected = DataArray(1, {'x': 0}) + expected = DataArray(1, {"x": 0}) actual = array.squeeze(drop=False) assert_identical(expected, actual) - array = DataArray([[[0., 1.]]], 
dims=['dim_0', 'dim_1', 'dim_2']) - expected = DataArray([[0., 1.]], dims=['dim_1', 'dim_2']) + array = DataArray([[[0.0, 1.0]]], dims=["dim_0", "dim_1", "dim_2"]) + expected = DataArray([[0.0, 1.0]], dims=["dim_1", "dim_2"]) actual = array.squeeze(axis=0) assert_identical(expected, actual) - array = DataArray([[[[0., 1.]]]], dims=[ - 'dim_0', 'dim_1', 'dim_2', 'dim_3']) - expected = DataArray([[0., 1.]], dims=['dim_1', 'dim_3']) + array = DataArray([[[[0.0, 1.0]]]], dims=["dim_0", "dim_1", "dim_2", "dim_3"]) + expected = DataArray([[0.0, 1.0]], dims=["dim_1", "dim_3"]) actual = array.squeeze(axis=(0, 2)) assert_identical(expected, actual) - array = DataArray([[[0., 1.]]], dims=['dim_0', 'dim_1', 'dim_2']) + array = DataArray([[[0.0, 1.0]]], dims=["dim_0", "dim_1", "dim_2"]) with pytest.raises(ValueError): - array.squeeze(axis=0, dim='dim_1') + array.squeeze(axis=0, dim="dim_1") def test_drop_coordinates(self): - expected = DataArray(np.random.randn(2, 3), dims=['x', 'y']) + expected = DataArray(np.random.randn(2, 3), dims=["x", "y"]) arr = expected.copy() - arr.coords['z'] = 2 - actual = arr.drop('z') + arr.coords["z"] = 2 + actual = arr.drop("z") assert_identical(expected, actual) with pytest.raises(ValueError): - arr.drop('not found') + arr.drop("not found") - actual = expected.drop('not found', errors='ignore') + actual = expected.drop("not found", errors="ignore") assert_identical(actual, expected) - with raises_regex(ValueError, 'cannot be found'): - arr.drop('w') + with raises_regex(ValueError, "cannot be found"): + arr.drop("w") - actual = expected.drop('w', errors='ignore') + actual = expected.drop("w", errors="ignore") assert_identical(actual, expected) - renamed = arr.rename('foo') - with raises_regex(ValueError, 'cannot be found'): - renamed.drop('foo') + renamed = arr.rename("foo") + with raises_regex(ValueError, "cannot be found"): + renamed.drop("foo") - actual = renamed.drop('foo', errors='ignore') + actual = renamed.drop("foo", errors="ignore") assert_identical(actual, renamed) def test_drop_index_labels(self): - arr = DataArray(np.random.randn(2, 3), coords={'y': [0, 1, 2]}, - dims=['x', 'y']) - actual = arr.drop([0, 1], dim='y') + arr = DataArray(np.random.randn(2, 3), coords={"y": [0, 1, 2]}, dims=["x", "y"]) + actual = arr.drop([0, 1], dim="y") expected = arr[:, 2:] assert_identical(actual, expected) - with raises_regex((KeyError, ValueError), 'not .* in axis'): - actual = arr.drop([0, 1, 3], dim='y') + with raises_regex((KeyError, ValueError), "not .* in axis"): + actual = arr.drop([0, 1, 3], dim="y") - actual = arr.drop([0, 1, 3], dim='y', errors='ignore') + actual = arr.drop([0, 1, 3], dim="y", errors="ignore") assert_identical(actual, expected) def test_dropna(self): x = np.random.randn(4, 4) x[::2, 0] = np.nan - arr = DataArray(x, dims=['a', 'b']) + arr = DataArray(x, dims=["a", "b"]) - actual = arr.dropna('a') + actual = arr.dropna("a") expected = arr[1::2] assert_identical(actual, expected) - actual = arr.dropna('b', how='all') + actual = arr.dropna("b", how="all") assert_identical(actual, arr) - actual = arr.dropna('a', thresh=1) + actual = arr.dropna("a", thresh=1) assert_identical(actual, arr) - actual = arr.dropna('b', thresh=3) + actual = arr.dropna("b", thresh=3) expected = arr[:, 1:] assert_identical(actual, expected) def test_where(self): - arr = DataArray(np.arange(4), dims='x') + arr = DataArray(np.arange(4), dims="x") expected = arr.sel(x=slice(2)) actual = arr.where(arr.x < 2, drop=True) assert_identical(actual, expected) def test_where_string(self): 
- array = DataArray(['a', 'b']) - expected = DataArray(np.array(['a', np.nan], dtype=object)) + array = DataArray(["a", "b"]) + expected = DataArray(np.array(["a", np.nan], dtype=object)) actual = array.where([True, False]) assert_identical(actual, expected) def test_cumops(self): - coords = {'x': [-1, -2], 'y': ['ab', 'cd', 'ef'], - 'lat': (['x', 'y'], [[1, 2, 3], [-1, -2, -3]]), - 'c': -999} - orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, - dims=['x', 'y']) + coords = { + "x": [-1, -2], + "y": ["ab", "cd", "ef"], + "lat": (["x", "y"], [[1, 2, 3], [-1, -2, -3]]), + "c": -999, + } + orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=["x", "y"]) actual = orig.cumsum() - expected = DataArray([[-1, -1, 0], [-4, -4, 0]], coords, - dims=['x', 'y']) + expected = DataArray([[-1, -1, 0], [-4, -4, 0]], coords, dims=["x", "y"]) assert_identical(expected, actual) - actual = orig.cumsum('x') - expected = DataArray([[-1, 0, 1], [-4, 0, 4]], coords, - dims=['x', 'y']) + actual = orig.cumsum("x") + expected = DataArray([[-1, 0, 1], [-4, 0, 4]], coords, dims=["x", "y"]) assert_identical(expected, actual) - actual = orig.cumsum('y') - expected = DataArray([[-1, -1, 0], [-3, -3, 0]], coords, - dims=['x', 'y']) + actual = orig.cumsum("y") + expected = DataArray([[-1, -1, 0], [-3, -3, 0]], coords, dims=["x", "y"]) assert_identical(expected, actual) - actual = orig.cumprod('x') - expected = DataArray([[-1, 0, 1], [3, 0, 3]], coords, - dims=['x', 'y']) + actual = orig.cumprod("x") + expected = DataArray([[-1, 0, 1], [3, 0, 3]], coords, dims=["x", "y"]) assert_identical(expected, actual) - actual = orig.cumprod('y') - expected = DataArray([[-1, 0, 0], [-3, 0, 0]], coords, dims=['x', 'y']) + actual = orig.cumprod("y") + expected = DataArray([[-1, 0, 0], [-3, 0, 0]], coords, dims=["x", "y"]) assert_identical(expected, actual) def test_reduce(self): - coords = {'x': [-1, -2], 'y': ['ab', 'cd', 'ef'], - 'lat': (['x', 'y'], [[1, 2, 3], [-1, -2, -3]]), - 'c': -999} - orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=['x', 'y']) + coords = { + "x": [-1, -2], + "y": ["ab", "cd", "ef"], + "lat": (["x", "y"], [[1, 2, 3], [-1, -2, -3]]), + "c": -999, + } + orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=["x", "y"]) actual = orig.mean() - expected = DataArray(0, {'c': -999}) + expected = DataArray(0, {"c": -999}) assert_identical(expected, actual) - actual = orig.mean(['x', 'y']) + actual = orig.mean(["x", "y"]) assert_identical(expected, actual) - actual = orig.mean('x') - expected = DataArray([-2, 0, 2], {'y': coords['y'], 'c': -999}, 'y') + actual = orig.mean("x") + expected = DataArray([-2, 0, 2], {"y": coords["y"], "c": -999}, "y") assert_identical(expected, actual) - actual = orig.mean(['x']) + actual = orig.mean(["x"]) assert_identical(expected, actual) - actual = orig.mean('y') - expected = DataArray([0, 0], {'x': coords['x'], 'c': -999}, 'x') + actual = orig.mean("y") + expected = DataArray([0, 0], {"x": coords["x"], "c": -999}, "x") assert_identical(expected, actual) - assert_equal(self.dv.reduce(np.mean, 'x').variable, - self.v.reduce(np.mean, 'x')) + assert_equal(self.dv.reduce(np.mean, "x").variable, self.v.reduce(np.mean, "x")) - orig = DataArray([[1, 0, np.nan], [3, 0, 3]], coords, dims=['x', 'y']) + orig = DataArray([[1, 0, np.nan], [3, 0, 3]], coords, dims=["x", "y"]) actual = orig.count() - expected = DataArray(5, {'c': -999}) + expected = DataArray(5, {"c": -999}) assert_identical(expected, actual) # uint support - orig = DataArray(np.arange(6).reshape(3, 2).astype('uint'), - 
dims=['x', 'y']) - assert orig.dtype.kind == 'u' - actual = orig.mean(dim='x', skipna=True) - expected = DataArray(orig.values.astype(int), - dims=['x', 'y']).mean('x') + orig = DataArray(np.arange(6).reshape(3, 2).astype("uint"), dims=["x", "y"]) + assert orig.dtype.kind == "u" + actual = orig.mean(dim="x", skipna=True) + expected = DataArray(orig.values.astype(int), dims=["x", "y"]).mean("x") assert_equal(actual, expected) def test_reduce_keepdims(self): - coords = {'x': [-1, -2], 'y': ['ab', 'cd', 'ef'], - 'lat': (['x', 'y'], [[1, 2, 3], [-1, -2, -3]]), - 'c': -999} - orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=['x', 'y']) + coords = { + "x": [-1, -2], + "y": ["ab", "cd", "ef"], + "lat": (["x", "y"], [[1, 2, 3], [-1, -2, -3]]), + "c": -999, + } + orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=["x", "y"]) # Mean on all axes loses non-constant coordinates actual = orig.mean(keepdims=True) - expected = DataArray(orig.data.mean(keepdims=True), dims=orig.dims, - coords={k: v for k, v in coords.items() - if k in ['c']}) + expected = DataArray( + orig.data.mean(keepdims=True), + dims=orig.dims, + coords={k: v for k, v in coords.items() if k in ["c"]}, + ) assert_equal(actual, expected) - assert actual.sizes['x'] == 1 - assert actual.sizes['y'] == 1 + assert actual.sizes["x"] == 1 + assert actual.sizes["y"] == 1 # Mean on specific axes loses coordinates not involving that axis - actual = orig.mean('y', keepdims=True) - expected = DataArray(orig.data.mean(axis=1, keepdims=True), - dims=orig.dims, - coords={k: v for k, v in coords.items() - if k not in ['y', 'lat']}) + actual = orig.mean("y", keepdims=True) + expected = DataArray( + orig.data.mean(axis=1, keepdims=True), + dims=orig.dims, + coords={k: v for k, v in coords.items() if k not in ["y", "lat"]}, + ) assert_equal(actual, expected) @requires_bottleneck def test_reduce_keepdims_bottleneck(self): import bottleneck - coords = {'x': [-1, -2], 'y': ['ab', 'cd', 'ef'], - 'lat': (['x', 'y'], [[1, 2, 3], [-1, -2, -3]]), - 'c': -999} - orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=['x', 'y']) + coords = { + "x": [-1, -2], + "y": ["ab", "cd", "ef"], + "lat": (["x", "y"], [[1, 2, 3], [-1, -2, -3]]), + "c": -999, + } + orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=["x", "y"]) # Bottleneck does not have its own keepdims implementation actual = orig.reduce(bottleneck.nanmean, keepdims=True) @@ -2071,19 +2262,25 @@ def test_reduce_keepdims_bottleneck(self): assert_equal(actual, expected) def test_reduce_dtype(self): - coords = {'x': [-1, -2], 'y': ['ab', 'cd', 'ef'], - 'lat': (['x', 'y'], [[1, 2, 3], [-1, -2, -3]]), - 'c': -999} - orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=['x', 'y']) + coords = { + "x": [-1, -2], + "y": ["ab", "cd", "ef"], + "lat": (["x", "y"], [[1, 2, 3], [-1, -2, -3]]), + "c": -999, + } + orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=["x", "y"]) for dtype in [np.float16, np.float32, np.float64]: assert orig.astype(float).mean(dtype=dtype).dtype == dtype def test_reduce_out(self): - coords = {'x': [-1, -2], 'y': ['ab', 'cd', 'ef'], - 'lat': (['x', 'y'], [[1, 2, 3], [-1, -2, -3]]), - 'c': -999} - orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=['x', 'y']) + coords = { + "x": [-1, -2], + "y": ["ab", "cd", "ef"], + "lat": (["x", "y"], [[1, 2, 3], [-1, -2, -3]]), + "c": -999, + } + orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=["x", "y"]) with pytest.raises(TypeError): orig.mean(out=np.ones(orig.shape)) @@ -2091,11 +2288,13 @@ def test_reduce_out(self): # 
skip due to bug in older versions of numpy.nanpercentile def test_quantile(self): for q in [0.25, [0.50], [0.25, 0.75]]: - for axis, dim in zip([None, 0, [0], [0, 1]], - [None, 'x', ['x'], ['x', 'y']]): + for axis, dim in zip( + [None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]] + ): actual = self.dv.quantile(q, dim=dim) - expected = np.nanpercentile(self.dv.values, np.array(q) * 100, - axis=axis) + expected = np.nanpercentile( + self.dv.values, np.array(q) * 100, axis=axis + ) np.testing.assert_allclose(actual.values, expected) def test_reduce_keep_attrs(self): @@ -2111,25 +2310,25 @@ def test_reduce_keep_attrs(self): def test_assign_attrs(self): expected = DataArray([], attrs=dict(a=1, b=2)) - expected.attrs['a'] = 1 - expected.attrs['b'] = 2 + expected.attrs["a"] = 1 + expected.attrs["b"] = 2 new = DataArray([]) actual = DataArray([]).assign_attrs(a=1, b=2) assert_identical(actual, expected) assert new.attrs == {} - expected.attrs['c'] = 3 - new_actual = actual.assign_attrs({'c': 3}) + expected.attrs["c"] = 3 + new_actual = actual.assign_attrs({"c": 3}) assert_identical(new_actual, expected) - assert actual.attrs == {'a': 1, 'b': 2} + assert actual.attrs == {"a": 1, "b": 2} def test_fillna(self): - a = DataArray([np.nan, 1, np.nan, 3], coords={'x': range(4)}, dims='x') + a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x") actual = a.fillna(-1) - expected = DataArray([-1, 1, -1, 3], coords={'x': range(4)}, dims='x') + expected = DataArray([-1, 1, -1, 3], coords={"x": range(4)}, dims="x") assert_identical(expected, actual) - b = DataArray(range(4), coords={'x': range(4)}, dims='x') + b = DataArray(range(4), coords={"x": range(4)}, dims="x") actual = a.fillna(b) expected = b.copy() assert_identical(expected, actual) @@ -2143,41 +2342,43 @@ def test_fillna(self): actual = a.fillna(b[:0]) assert_identical(a, actual) - with raises_regex(TypeError, 'fillna on a DataArray'): + with raises_regex(TypeError, "fillna on a DataArray"): a.fillna({0: 0}) - with raises_regex(ValueError, 'broadcast'): + with raises_regex(ValueError, "broadcast"): a.fillna([1, 2]) - fill_value = DataArray([0, 1], dims='y') + fill_value = DataArray([0, 1], dims="y") actual = a.fillna(fill_value) - expected = DataArray([[0, 1], [1, 1], [0, 1], [3, 3]], - coords={'x': range(4)}, dims=('x', 'y')) + expected = DataArray( + [[0, 1], [1, 1], [0, 1], [3, 3]], coords={"x": range(4)}, dims=("x", "y") + ) assert_identical(expected, actual) expected = b.copy() for target in [a, expected]: - target.coords['b'] = ('x', [0, 0, 1, 1]) - actual = a.groupby('b').fillna(DataArray([0, 2], dims='b')) + target.coords["b"] = ("x", [0, 0, 1, 1]) + actual = a.groupby("b").fillna(DataArray([0, 2], dims="b")) assert_identical(expected, actual) def test_groupby_iter(self): - for ((act_x, act_dv), (exp_x, exp_ds)) in \ - zip(self.dv.groupby('y'), self.ds.groupby('y')): + for ((act_x, act_dv), (exp_x, exp_ds)) in zip( + self.dv.groupby("y"), self.ds.groupby("y") + ): assert exp_x == act_x - assert_identical(exp_ds['foo'], act_dv) - for ((_, exp_dv), act_dv) in zip(self.dv.groupby('x'), self.dv): + assert_identical(exp_ds["foo"], act_dv) + for ((_, exp_dv), act_dv) in zip(self.dv.groupby("x"), self.dv): assert_identical(exp_dv, act_dv) def make_groupby_example_array(self): da = self.dv.copy() - da.coords['abc'] = ('y', np.array(['a'] * 9 + ['c'] + ['b'] * 10)) - da.coords['y'] = 20 + 100 * da['y'] + da.coords["abc"] = ("y", np.array(["a"] * 9 + ["c"] + ["b"] * 10)) + da.coords["y"] = 20 + 100 * da["y"] return da def 
test_groupby_properties(self): - grouped = self.make_groupby_example_array().groupby('abc') - expected_groups = {'a': range(0, 9), 'c': [9], 'b': range(10, 20)} + grouped = self.make_groupby_example_array().groupby("abc") + expected_groups = {"a": range(0, 9), "c": [9], "b": range(10, 20)} assert expected_groups.keys() == grouped.groups.keys() for key in expected_groups: assert_array_equal(expected_groups[key], grouped.groups[key]) @@ -2185,12 +2386,12 @@ def test_groupby_properties(self): def test_groupby_apply_identity(self): expected = self.make_groupby_example_array() - idx = expected.coords['y'] + idx = expected.coords["y"] def identity(x): return x - for g in ['x', 'y', 'abc', idx]: + for g in ["x", "y", "abc", idx]: for shortcut in [False, True]: for squeeze in [False, True]: grouped = expected.groupby(g, squeeze=squeeze) @@ -2199,75 +2400,113 @@ def identity(x): def test_groupby_sum(self): array = self.make_groupby_example_array() - grouped = array.groupby('abc') + grouped = array.groupby("abc") expected_sum_all = Dataset( - {'foo': Variable(['abc'], np.array([self.x[:, :9].sum(), - self.x[:, 10:].sum(), - self.x[:, 9:10].sum()]).T), - 'abc': Variable(['abc'], np.array(['a', 'b', 'c']))})['foo'] + { + "foo": Variable( + ["abc"], + np.array( + [ + self.x[:, :9].sum(), + self.x[:, 10:].sum(), + self.x[:, 9:10].sum(), + ] + ).T, + ), + "abc": Variable(["abc"], np.array(["a", "b", "c"])), + } + )["foo"] assert_allclose(expected_sum_all, grouped.reduce(np.sum, dim=ALL_DIMS)) assert_allclose(expected_sum_all, grouped.sum(ALL_DIMS)) - expected = DataArray([array['y'].values[idx].sum() for idx - in [slice(9), slice(10, None), slice(9, 10)]], - [['a', 'b', 'c']], ['abc']) - actual = array['y'].groupby('abc').apply(np.sum) + expected = DataArray( + [ + array["y"].values[idx].sum() + for idx in [slice(9), slice(10, None), slice(9, 10)] + ], + [["a", "b", "c"]], + ["abc"], + ) + actual = array["y"].groupby("abc").apply(np.sum) assert_allclose(expected, actual) - actual = array['y'].groupby('abc').sum(ALL_DIMS) + actual = array["y"].groupby("abc").sum(ALL_DIMS) assert_allclose(expected, actual) expected_sum_axis1 = Dataset( - {'foo': (['x', 'abc'], np.array([self.x[:, :9].sum(1), - self.x[:, 10:].sum(1), - self.x[:, 9:10].sum(1)]).T), - 'abc': Variable(['abc'], np.array(['a', 'b', 'c']))})['foo'] - assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, 'y')) - assert_allclose(expected_sum_axis1, grouped.sum('y')) + { + "foo": ( + ["x", "abc"], + np.array( + [ + self.x[:, :9].sum(1), + self.x[:, 10:].sum(1), + self.x[:, 9:10].sum(1), + ] + ).T, + ), + "abc": Variable(["abc"], np.array(["a", "b", "c"])), + } + )["foo"] + assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, "y")) + assert_allclose(expected_sum_axis1, grouped.sum("y")) def test_groupby_warning(self): array = self.make_groupby_example_array() - grouped = array.groupby('y') + grouped = array.groupby("y") with pytest.warns(FutureWarning): grouped.sum() - @pytest.mark.skipif(LooseVersion(xr.__version__) < LooseVersion('0.13'), - reason="not to forget the behavior change") + @pytest.mark.skipif( + LooseVersion(xr.__version__) < LooseVersion("0.13"), + reason="not to forget the behavior change", + ) def test_groupby_sum_default(self): array = self.make_groupby_example_array() - grouped = array.groupby('abc') + grouped = array.groupby("abc") expected_sum_all = Dataset( - {'foo': Variable(['x', 'abc'], - np.array([self.x[:, :9].sum(axis=-1), - self.x[:, 10:].sum(axis=-1), - self.x[:, 9:10].sum(axis=-1)]).T), - 'abc': 
Variable(['abc'], np.array(['a', 'b', 'c']))})['foo'] + { + "foo": Variable( + ["x", "abc"], + np.array( + [ + self.x[:, :9].sum(axis=-1), + self.x[:, 10:].sum(axis=-1), + self.x[:, 9:10].sum(axis=-1), + ] + ).T, + ), + "abc": Variable(["abc"], np.array(["a", "b", "c"])), + } + )["foo"] assert_allclose(expected_sum_all, grouped.sum()) def test_groupby_count(self): array = DataArray( [0, 0, np.nan, np.nan, 0, 0], - coords={'cat': ('x', ['a', 'b', 'b', 'c', 'c', 'c'])}, - dims='x') - actual = array.groupby('cat').count() - expected = DataArray([1, 1, 2], coords=[('cat', ['a', 'b', 'c'])]) + coords={"cat": ("x", ["a", "b", "b", "c", "c", "c"])}, + dims="x", + ) + actual = array.groupby("cat").count() + expected = DataArray([1, 1, 2], coords=[("cat", ["a", "b", "c"])]) assert_identical(actual, expected) - @pytest.mark.skip('needs to be fixed for shortcut=False, keep_attrs=False') + @pytest.mark.skip("needs to be fixed for shortcut=False, keep_attrs=False") def test_groupby_reduce_attrs(self): array = self.make_groupby_example_array() - array.attrs['foo'] = 'bar' + array.attrs["foo"] = "bar" for shortcut in [True, False]: for keep_attrs in [True, False]: - print('shortcut=%s, keep_attrs=%s' % (shortcut, keep_attrs)) - actual = array.groupby('abc').reduce( - np.mean, keep_attrs=keep_attrs, shortcut=shortcut) - expected = array.groupby('abc').mean() + print("shortcut=%s, keep_attrs=%s" % (shortcut, keep_attrs)) + actual = array.groupby("abc").reduce( + np.mean, keep_attrs=keep_attrs, shortcut=shortcut + ) + expected = array.groupby("abc").mean() if keep_attrs: - expected.attrs['foo'] = 'bar' + expected.attrs["foo"] = "bar" assert_identical(expected, actual) def test_groupby_apply_center(self): @@ -2275,31 +2514,31 @@ def center(x): return x - np.mean(x) array = self.make_groupby_example_array() - grouped = array.groupby('abc') + grouped = array.groupby("abc") expected_ds = array.to_dataset() - exp_data = np.hstack([center(self.x[:, :9]), - center(self.x[:, 9:10]), - center(self.x[:, 10:])]) - expected_ds['foo'] = (['x', 'y'], exp_data) - expected_centered = expected_ds['foo'] + exp_data = np.hstack( + [center(self.x[:, :9]), center(self.x[:, 9:10]), center(self.x[:, 10:])] + ) + expected_ds["foo"] = (["x", "y"], exp_data) + expected_centered = expected_ds["foo"] assert_allclose(expected_centered, grouped.apply(center)) def test_groupby_apply_ndarray(self): # regression test for #326 array = self.make_groupby_example_array() - grouped = array.groupby('abc') + grouped = array.groupby("abc") actual = grouped.apply(np.asarray) assert_equal(array, actual) def test_groupby_apply_changes_metadata(self): def change_metadata(x): - x.coords['x'] = x.coords['x'] * 2 - x.attrs['fruit'] = 'lemon' + x.coords["x"] = x.coords["x"] * 2 + x.attrs["fruit"] = "lemon" return x array = self.make_groupby_example_array() - grouped = array.groupby('abc') + grouped = array.groupby("abc") actual = grouped.apply(change_metadata) expected = array.copy() expected = change_metadata(expected) @@ -2308,16 +2547,16 @@ def change_metadata(x): def test_groupby_math(self): array = self.make_groupby_example_array() for squeeze in [True, False]: - grouped = array.groupby('x', squeeze=squeeze) + grouped = array.groupby("x", squeeze=squeeze) - expected = array + array.coords['x'] - actual = grouped + array.coords['x'] + expected = array + array.coords["x"] + actual = grouped + array.coords["x"] assert_identical(expected, actual) - actual = array.coords['x'] + grouped + actual = array.coords["x"] + grouped assert_identical(expected, 
actual) - ds = array.coords['x'].to_dataset(name='X') + ds = array.coords["x"].to_dataset(name="X") expected = array + ds actual = grouped + ds assert_identical(expected, actual) @@ -2325,136 +2564,153 @@ def test_groupby_math(self): actual = ds + grouped assert_identical(expected, actual) - grouped = array.groupby('abc') + grouped = array.groupby("abc") expected_agg = (grouped.mean(ALL_DIMS) - np.arange(3)).rename(None) - actual = grouped - DataArray(range(3), [('abc', ['a', 'b', 'c'])]) - actual_agg = actual.groupby('abc').mean(ALL_DIMS) + actual = grouped - DataArray(range(3), [("abc", ["a", "b", "c"])]) + actual_agg = actual.groupby("abc").mean(ALL_DIMS) assert_allclose(expected_agg, actual_agg) - with raises_regex(TypeError, 'only support binary ops'): + with raises_regex(TypeError, "only support binary ops"): grouped + 1 - with raises_regex(TypeError, 'only support binary ops'): + with raises_regex(TypeError, "only support binary ops"): grouped + grouped - with raises_regex(TypeError, 'in-place operations'): + with raises_regex(TypeError, "in-place operations"): array += grouped def test_groupby_math_not_aligned(self): - array = DataArray(range(4), {'b': ('x', [0, 0, 1, 1]), - 'x': [0, 1, 2, 3]}, - dims='x') - other = DataArray([10], coords={'b': [0]}, dims='b') - actual = array.groupby('b') + other + array = DataArray( + range(4), {"b": ("x", [0, 0, 1, 1]), "x": [0, 1, 2, 3]}, dims="x" + ) + other = DataArray([10], coords={"b": [0]}, dims="b") + actual = array.groupby("b") + other expected = DataArray([10, 11, np.nan, np.nan], array.coords) assert_identical(expected, actual) - other = DataArray([10], coords={'c': 123, 'b': [0]}, dims='b') - actual = array.groupby('b') + other - expected.coords['c'] = (['x'], [123] * 2 + [np.nan] * 2) + other = DataArray([10], coords={"c": 123, "b": [0]}, dims="b") + actual = array.groupby("b") + other + expected.coords["c"] = (["x"], [123] * 2 + [np.nan] * 2) assert_identical(expected, actual) - other = Dataset({'a': ('b', [10])}, {'b': [0]}) - actual = array.groupby('b') + other - expected = Dataset({'a': ('x', [10, 11, np.nan, np.nan])}, - array.coords) + other = Dataset({"a": ("b", [10])}, {"b": [0]}) + actual = array.groupby("b") + other + expected = Dataset({"a": ("x", [10, 11, np.nan, np.nan])}, array.coords) assert_identical(expected, actual) def test_groupby_restore_dim_order(self): - array = DataArray(np.random.randn(5, 3), - coords={'a': ('x', range(5)), 'b': ('y', range(3))}, - dims=['x', 'y']) - for by, expected_dims in [('x', ('x', 'y')), - ('y', ('x', 'y')), - ('a', ('a', 'y')), - ('b', ('x', 'b'))]: + array = DataArray( + np.random.randn(5, 3), + coords={"a": ("x", range(5)), "b": ("y", range(3))}, + dims=["x", "y"], + ) + for by, expected_dims in [ + ("x", ("x", "y")), + ("y", ("x", "y")), + ("a", ("a", "y")), + ("b", ("x", "b")), + ]: result = array.groupby(by).apply(lambda x: x.squeeze()) assert result.dims == expected_dims def test_groupby_restore_coord_dims(self): - array = DataArray(np.random.randn(5, 3), - coords={'a': ('x', range(5)), 'b': ('y', range(3)), - 'c': (('x', 'y'), np.random.randn(5, 3))}, - dims=['x', 'y']) - - for by, expected_dims in [('x', ('x', 'y')), - ('y', ('x', 'y')), - ('a', ('a', 'y')), - ('b', ('x', 'b'))]: + array = DataArray( + np.random.randn(5, 3), + coords={ + "a": ("x", range(5)), + "b": ("y", range(3)), + "c": (("x", "y"), np.random.randn(5, 3)), + }, + dims=["x", "y"], + ) + + for by, expected_dims in [ + ("x", ("x", "y")), + ("y", ("x", "y")), + ("a", ("a", "y")), + ("b", ("x", "b")), + ]: 
result = array.groupby(by, restore_coord_dims=True).apply( - lambda x: x.squeeze())['c'] + lambda x: x.squeeze() + )["c"] assert result.dims == expected_dims with pytest.warns(FutureWarning): - array.groupby('x').apply(lambda x: x.squeeze()) + array.groupby("x").apply(lambda x: x.squeeze()) def test_groupby_first_and_last(self): - array = DataArray([1, 2, 3, 4, 5], dims='x') - by = DataArray(['a'] * 2 + ['b'] * 3, dims='x', name='ab') + array = DataArray([1, 2, 3, 4, 5], dims="x") + by = DataArray(["a"] * 2 + ["b"] * 3, dims="x", name="ab") - expected = DataArray([1, 3], [('ab', ['a', 'b'])]) + expected = DataArray([1, 3], [("ab", ["a", "b"])]) actual = array.groupby(by).first() assert_identical(expected, actual) - expected = DataArray([2, 5], [('ab', ['a', 'b'])]) + expected = DataArray([2, 5], [("ab", ["a", "b"])]) actual = array.groupby(by).last() assert_identical(expected, actual) - array = DataArray(np.random.randn(5, 3), dims=['x', 'y']) - expected = DataArray(array[[0, 2]], {'ab': ['a', 'b']}, ['ab', 'y']) + array = DataArray(np.random.randn(5, 3), dims=["x", "y"]) + expected = DataArray(array[[0, 2]], {"ab": ["a", "b"]}, ["ab", "y"]) actual = array.groupby(by).first() assert_identical(expected, actual) - actual = array.groupby('x').first() + actual = array.groupby("x").first() expected = array # should be a no-op assert_identical(expected, actual) def make_groupby_multidim_example_array(self): - return DataArray([[[0, 1], [2, 3]], [[5, 10], [15, 20]]], - coords={'lon': (['ny', 'nx'], [[30, 40], [40, 50]]), - 'lat': (['ny', 'nx'], [[10, 10], [20, 20]])}, - dims=['time', 'ny', 'nx']) + return DataArray( + [[[0, 1], [2, 3]], [[5, 10], [15, 20]]], + coords={ + "lon": (["ny", "nx"], [[30, 40], [40, 50]]), + "lat": (["ny", "nx"], [[10, 10], [20, 20]]), + }, + dims=["time", "ny", "nx"], + ) def test_groupby_multidim(self): array = self.make_groupby_multidim_example_array() for dim, expected_sum in [ - ('lon', DataArray([5, 28, 23], - coords=[('lon', [30., 40., 50.])])), - ('lat', DataArray([16, 40], coords=[('lat', [10., 20.])]))]: + ("lon", DataArray([5, 28, 23], coords=[("lon", [30.0, 40.0, 50.0])])), + ("lat", DataArray([16, 40], coords=[("lat", [10.0, 20.0])])), + ]: actual_sum = array.groupby(dim).sum(ALL_DIMS) assert_identical(expected_sum, actual_sum) def test_groupby_multidim_apply(self): array = self.make_groupby_multidim_example_array() - actual = array.groupby('lon').apply(lambda x: x - x.mean()) - expected = DataArray([[[-2.5, -6.], [-5., -8.5]], - [[2.5, 3.], [8., 8.5]]], - coords=array.coords, dims=array.dims) + actual = array.groupby("lon").apply(lambda x: x - x.mean()) + expected = DataArray( + [[[-2.5, -6.0], [-5.0, -8.5]], [[2.5, 3.0], [8.0, 8.5]]], + coords=array.coords, + dims=array.dims, + ) assert_identical(expected, actual) def test_groupby_bins(self): - array = DataArray(np.arange(4), dims='dim_0') + array = DataArray(np.arange(4), dims="dim_0") # the first value should not be part of any group ("right" binning) array[0] = 99 # bins follow conventions for pandas.cut # http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html bins = [0, 1.5, 5] - bin_coords = pd.cut(array['dim_0'], bins).categories - expected = DataArray([1, 5], dims='dim_0_bins', - coords={'dim_0_bins': bin_coords}) + bin_coords = pd.cut(array["dim_0"], bins).categories + expected = DataArray( + [1, 5], dims="dim_0_bins", coords={"dim_0_bins": bin_coords} + ) # the problem with this is that it overwrites the dimensions of array! 
# actual = array.groupby('dim_0', bins=bins).sum() - actual = array.groupby_bins('dim_0', bins).apply(lambda x: x.sum()) + actual = array.groupby_bins("dim_0", bins).apply(lambda x: x.sum()) assert_identical(expected, actual) # make sure original array dims are unchanged assert len(array.dim_0) == 4 def test_groupby_bins_empty(self): - array = DataArray(np.arange(4), [('x', range(4))]) + array = DataArray(np.arange(4), [("x", range(4))]) # one of these bins will be empty bins = [0, 4, 5] - bin_coords = pd.cut(array['x'], bins).categories - actual = array.groupby_bins('x', bins).sum() - expected = DataArray([6, np.nan], dims='x_bins', - coords={'x_bins': bin_coords}) + bin_coords = pd.cut(array["x"], bins).categories + actual = array.groupby_bins("x", bins).sum() + expected = DataArray([6, np.nan], dims="x_bins", coords={"x_bins": bin_coords}) assert_identical(expected, actual) # make sure original array is unchanged # (was a problem in earlier versions) @@ -2463,284 +2719,282 @@ def test_groupby_bins_empty(self): def test_groupby_bins_multidim(self): array = self.make_groupby_multidim_example_array() bins = [0, 15, 20] - bin_coords = pd.cut(array['lat'].values.flat, bins).categories - expected = DataArray([16, 40], dims='lat_bins', - coords={'lat_bins': bin_coords}) - actual = array.groupby_bins('lat', bins).apply(lambda x: x.sum()) + bin_coords = pd.cut(array["lat"].values.flat, bins).categories + expected = DataArray([16, 40], dims="lat_bins", coords={"lat_bins": bin_coords}) + actual = array.groupby_bins("lat", bins).apply(lambda x: x.sum()) assert_identical(expected, actual) # modify the array coordinates to be non-monotonic after unstacking - array['lat'].data = np.array([[10., 20.], [20., 10.]]) - expected = DataArray([28, 28], dims='lat_bins', - coords={'lat_bins': bin_coords}) - actual = array.groupby_bins('lat', bins).apply(lambda x: x.sum()) + array["lat"].data = np.array([[10.0, 20.0], [20.0, 10.0]]) + expected = DataArray([28, 28], dims="lat_bins", coords={"lat_bins": bin_coords}) + actual = array.groupby_bins("lat", bins).apply(lambda x: x.sum()) assert_identical(expected, actual) def test_groupby_bins_sort(self): data = xr.DataArray( - np.arange(100), dims='x', - coords={'x': np.linspace(-100, 100, num=100)}) - binned_mean = data.groupby_bins('x', bins=11).mean() + np.arange(100), dims="x", coords={"x": np.linspace(-100, 100, num=100)} + ) + binned_mean = data.groupby_bins("x", bins=11).mean() assert binned_mean.to_index().is_monotonic def test_resample(self): - times = pd.date_range('2000-01-01', freq='6H', periods=10) - array = DataArray(np.arange(10), [('time', times)]) + times = pd.date_range("2000-01-01", freq="6H", periods=10) + array = DataArray(np.arange(10), [("time", times)]) - actual = array.resample(time='24H').mean() - expected = DataArray(array.to_series().resample('24H').mean()) + actual = array.resample(time="24H").mean() + expected = DataArray(array.to_series().resample("24H").mean()) assert_identical(expected, actual) - actual = array.resample(time='24H').reduce(np.mean) + actual = array.resample(time="24H").reduce(np.mean) assert_identical(expected, actual) - actual = array.resample(time='24H', loffset='-12H').mean() - expected = DataArray(array.to_series().resample('24H', loffset='-12H') - .mean()) + actual = array.resample(time="24H", loffset="-12H").mean() + expected = DataArray(array.to_series().resample("24H", loffset="-12H").mean()) assert_identical(expected, actual) - with raises_regex(ValueError, 'index must be monotonic'): - array[[2, 0, 
1]].resample(time='1D') + with raises_regex(ValueError, "index must be monotonic"): + array[[2, 0, 1]].resample(time="1D") def test_da_resample_func_args(self): + def func(arg1, arg2, arg3=0.0): + return arg1.mean("time") + arg2 + arg3 - def func(arg1, arg2, arg3=0.): - return arg1.mean('time') + arg2 + arg3 - - times = pd.date_range('2000', periods=3, freq='D') - da = xr.DataArray([1., 1., 1.], coords=[times], dims=['time']) - expected = xr.DataArray([3., 3., 3.], coords=[times], dims=['time']) - actual = da.resample(time='D').apply(func, args=(1.,), arg3=1.) + times = pd.date_range("2000", periods=3, freq="D") + da = xr.DataArray([1.0, 1.0, 1.0], coords=[times], dims=["time"]) + expected = xr.DataArray([3.0, 3.0, 3.0], coords=[times], dims=["time"]) + actual = da.resample(time="D").apply(func, args=(1.0,), arg3=1.0) assert_identical(actual, expected) def test_resample_first(self): - times = pd.date_range('2000-01-01', freq='6H', periods=10) - array = DataArray(np.arange(10), [('time', times)]) + times = pd.date_range("2000-01-01", freq="6H", periods=10) + array = DataArray(np.arange(10), [("time", times)]) - actual = array.resample(time='1D').first() - expected = DataArray([0, 4, 8], [('time', times[::4])]) + actual = array.resample(time="1D").first() + expected = DataArray([0, 4, 8], [("time", times[::4])]) assert_identical(expected, actual) # verify that labels don't use the first value - actual = array.resample(time='24H').first() - expected = DataArray(array.to_series().resample('24H').first()) + actual = array.resample(time="24H").first() + expected = DataArray(array.to_series().resample("24H").first()) assert_identical(expected, actual) # missing values array = array.astype(float) array[:2] = np.nan - actual = array.resample(time='1D').first() - expected = DataArray([2, 4, 8], [('time', times[::4])]) + actual = array.resample(time="1D").first() + expected = DataArray([2, 4, 8], [("time", times[::4])]) assert_identical(expected, actual) - actual = array.resample(time='1D').first(skipna=False) - expected = DataArray([np.nan, 4, 8], [('time', times[::4])]) + actual = array.resample(time="1D").first(skipna=False) + expected = DataArray([np.nan, 4, 8], [("time", times[::4])]) assert_identical(expected, actual) # regression test for http://stackoverflow.com/questions/33158558/ - array = Dataset({'time': times})['time'] - actual = array.resample(time='1D').last() - expected_times = pd.to_datetime(['2000-01-01T18', '2000-01-02T18', - '2000-01-03T06']) - expected = DataArray(expected_times, [('time', times[::4])], - name='time') + array = Dataset({"time": times})["time"] + actual = array.resample(time="1D").last() + expected_times = pd.to_datetime( + ["2000-01-01T18", "2000-01-02T18", "2000-01-03T06"] + ) + expected = DataArray(expected_times, [("time", times[::4])], name="time") assert_identical(expected, actual) def test_resample_bad_resample_dim(self): - times = pd.date_range('2000-01-01', freq='6H', periods=10) - array = DataArray(np.arange(10), [('__resample_dim__', times)]) - with raises_regex(ValueError, 'Proxy resampling dimension'): - array.resample(**{'__resample_dim__': '1D'}).first() + times = pd.date_range("2000-01-01", freq="6H", periods=10) + array = DataArray(np.arange(10), [("__resample_dim__", times)]) + with raises_regex(ValueError, "Proxy resampling dimension"): + array.resample(**{"__resample_dim__": "1D"}).first() @requires_scipy def test_resample_drop_nondim_coords(self): xs = np.arange(6) ys = np.arange(3) - times = pd.date_range('2000-01-01', freq='6H', periods=5) + 
times = pd.date_range("2000-01-01", freq="6H", periods=5) data = np.tile(np.arange(5), (6, 3, 1)) xx, yy = np.meshgrid(xs * 5, ys * 2.5) tt = np.arange(len(times), dtype=int) - array = DataArray(data, - {'time': times, 'x': xs, 'y': ys}, - ('x', 'y', 'time')) - xcoord = DataArray(xx.T, {'x': xs, 'y': ys}, ('x', 'y')) - ycoord = DataArray(yy.T, {'x': xs, 'y': ys}, ('x', 'y')) - tcoord = DataArray(tt, {'time': times}, ('time', )) - ds = Dataset({'data': array, 'xc': xcoord, - 'yc': ycoord, 'tc': tcoord}) - ds = ds.set_coords(['xc', 'yc', 'tc']) + array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) + xcoord = DataArray(xx.T, {"x": xs, "y": ys}, ("x", "y")) + ycoord = DataArray(yy.T, {"x": xs, "y": ys}, ("x", "y")) + tcoord = DataArray(tt, {"time": times}, ("time",)) + ds = Dataset({"data": array, "xc": xcoord, "yc": ycoord, "tc": tcoord}) + ds = ds.set_coords(["xc", "yc", "tc"]) # Select the data now, with the auxiliary coordinates in place - array = ds['data'] + array = ds["data"] # Re-sample - actual = array.resample( - time="12H", restore_coord_dims=True).mean('time') - assert 'tc' not in actual.coords + actual = array.resample(time="12H", restore_coord_dims=True).mean("time") + assert "tc" not in actual.coords # Up-sample - filling - actual = array.resample( - time="1H", restore_coord_dims=True).ffill() - assert 'tc' not in actual.coords + actual = array.resample(time="1H", restore_coord_dims=True).ffill() + assert "tc" not in actual.coords # Up-sample - interpolation - actual = array.resample( - time="1H", restore_coord_dims=True).interpolate('linear') - assert 'tc' not in actual.coords + actual = array.resample(time="1H", restore_coord_dims=True).interpolate( + "linear" + ) + assert "tc" not in actual.coords def test_resample_keep_attrs(self): - times = pd.date_range('2000-01-01', freq='6H', periods=10) - array = DataArray(np.ones(10), [('time', times)]) - array.attrs['meta'] = 'data' + times = pd.date_range("2000-01-01", freq="6H", periods=10) + array = DataArray(np.ones(10), [("time", times)]) + array.attrs["meta"] = "data" - result = array.resample(time='1D').mean(keep_attrs=True) - expected = DataArray([1, 1, 1], [('time', times[::4])], - attrs=array.attrs) + result = array.resample(time="1D").mean(keep_attrs=True) + expected = DataArray([1, 1, 1], [("time", times[::4])], attrs=array.attrs) assert_identical(result, expected) def test_resample_skipna(self): - times = pd.date_range('2000-01-01', freq='6H', periods=10) - array = DataArray(np.ones(10), [('time', times)]) + times = pd.date_range("2000-01-01", freq="6H", periods=10) + array = DataArray(np.ones(10), [("time", times)]) array[1] = np.nan - result = array.resample(time='1D').mean(skipna=False) - expected = DataArray([np.nan, 1, 1], [('time', times[::4])]) + result = array.resample(time="1D").mean(skipna=False) + expected = DataArray([np.nan, 1, 1], [("time", times[::4])]) assert_identical(result, expected) def test_upsample(self): - times = pd.date_range('2000-01-01', freq='6H', periods=5) - array = DataArray(np.arange(5), [('time', times)]) + times = pd.date_range("2000-01-01", freq="6H", periods=5) + array = DataArray(np.arange(5), [("time", times)]) # Forward-fill - actual = array.resample(time='3H').ffill() - expected = DataArray(array.to_series().resample('3H').ffill()) + actual = array.resample(time="3H").ffill() + expected = DataArray(array.to_series().resample("3H").ffill()) assert_identical(expected, actual) # Backward-fill - actual = array.resample(time='3H').bfill() - expected = 
DataArray(array.to_series().resample('3H').bfill()) + actual = array.resample(time="3H").bfill() + expected = DataArray(array.to_series().resample("3H").bfill()) assert_identical(expected, actual) # As frequency - actual = array.resample(time='3H').asfreq() - expected = DataArray(array.to_series().resample('3H').asfreq()) + actual = array.resample(time="3H").asfreq() + expected = DataArray(array.to_series().resample("3H").asfreq()) assert_identical(expected, actual) # Pad - actual = array.resample(time='3H').pad() - expected = DataArray(array.to_series().resample('3H').pad()) + actual = array.resample(time="3H").pad() + expected = DataArray(array.to_series().resample("3H").pad()) assert_identical(expected, actual) # Nearest - rs = array.resample(time='3H') + rs = array.resample(time="3H") actual = rs.nearest() new_times = rs._full_index - expected = DataArray( - array.reindex(time=new_times, method='nearest') - ) + expected = DataArray(array.reindex(time=new_times, method="nearest")) assert_identical(expected, actual) def test_upsample_nd(self): # Same as before, but now we try on multi-dimensional DataArrays. xs = np.arange(6) ys = np.arange(3) - times = pd.date_range('2000-01-01', freq='6H', periods=5) + times = pd.date_range("2000-01-01", freq="6H", periods=5) data = np.tile(np.arange(5), (6, 3, 1)) - array = DataArray(data, - {'time': times, 'x': xs, 'y': ys}, - ('x', 'y', 'time')) + array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) # Forward-fill - actual = array.resample(time='3H').ffill() + actual = array.resample(time="3H").ffill() expected_data = np.repeat(data, 2, axis=-1) - expected_times = times.to_series().resample('3H').asfreq().index - expected_data = expected_data[..., :len(expected_times)] - expected = DataArray(expected_data, - {'time': expected_times, 'x': xs, 'y': ys}, - ('x', 'y', 'time')) + expected_times = times.to_series().resample("3H").asfreq().index + expected_data = expected_data[..., : len(expected_times)] + expected = DataArray( + expected_data, + {"time": expected_times, "x": xs, "y": ys}, + ("x", "y", "time"), + ) assert_identical(expected, actual) # Backward-fill - actual = array.resample(time='3H').ffill() + actual = array.resample(time="3H").ffill() expected_data = np.repeat(np.flipud(data.T).T, 2, axis=-1) expected_data = np.flipud(expected_data.T).T - expected_times = times.to_series().resample('3H').asfreq().index - expected_data = expected_data[..., :len(expected_times)] - expected = DataArray(expected_data, - {'time': expected_times, 'x': xs, 'y': ys}, - ('x', 'y', 'time')) + expected_times = times.to_series().resample("3H").asfreq().index + expected_data = expected_data[..., : len(expected_times)] + expected = DataArray( + expected_data, + {"time": expected_times, "x": xs, "y": ys}, + ("x", "y", "time"), + ) assert_identical(expected, actual) # As frequency - actual = array.resample(time='3H').asfreq() + actual = array.resample(time="3H").asfreq() expected_data = np.repeat(data, 2, axis=-1).astype(float)[..., :-1] expected_data[..., 1::2] = np.nan - expected_times = times.to_series().resample('3H').asfreq().index - expected = DataArray(expected_data, - {'time': expected_times, 'x': xs, 'y': ys}, - ('x', 'y', 'time')) + expected_times = times.to_series().resample("3H").asfreq().index + expected = DataArray( + expected_data, + {"time": expected_times, "x": xs, "y": ys}, + ("x", "y", "time"), + ) assert_identical(expected, actual) # Pad - actual = array.resample(time='3H').pad() + actual = array.resample(time="3H").pad() 
expected_data = np.repeat(data, 2, axis=-1) expected_data[..., 1::2] = expected_data[..., ::2] expected_data = expected_data[..., :-1] - expected_times = times.to_series().resample('3H').asfreq().index - expected = DataArray(expected_data, - {'time': expected_times, 'x': xs, 'y': ys}, - ('x', 'y', 'time')) + expected_times = times.to_series().resample("3H").asfreq().index + expected = DataArray( + expected_data, + {"time": expected_times, "x": xs, "y": ys}, + ("x", "y", "time"), + ) assert_identical(expected, actual) def test_upsample_tolerance(self): # Test tolerance keyword for upsample methods bfill, pad, nearest - times = pd.date_range('2000-01-01', freq='1D', periods=2) - times_upsampled = pd.date_range('2000-01-01', freq='6H', periods=5) - array = DataArray(np.arange(2), [('time', times)]) + times = pd.date_range("2000-01-01", freq="1D", periods=2) + times_upsampled = pd.date_range("2000-01-01", freq="6H", periods=5) + array = DataArray(np.arange(2), [("time", times)]) # Forward fill - actual = array.resample(time='6H').ffill(tolerance='12H') - expected = DataArray([0., 0., 0., np.nan, 1.], - [('time', times_upsampled)]) + actual = array.resample(time="6H").ffill(tolerance="12H") + expected = DataArray([0.0, 0.0, 0.0, np.nan, 1.0], [("time", times_upsampled)]) assert_identical(expected, actual) # Backward fill - actual = array.resample(time='6H').bfill(tolerance='12H') - expected = DataArray([0., np.nan, 1., 1., 1.], - [('time', times_upsampled)]) + actual = array.resample(time="6H").bfill(tolerance="12H") + expected = DataArray([0.0, np.nan, 1.0, 1.0, 1.0], [("time", times_upsampled)]) assert_identical(expected, actual) # Nearest - actual = array.resample(time='6H').nearest(tolerance='6H') - expected = DataArray([0, 0, np.nan, 1, 1], - [('time', times_upsampled)]) + actual = array.resample(time="6H").nearest(tolerance="6H") + expected = DataArray([0, 0, np.nan, 1, 1], [("time", times_upsampled)]) assert_identical(expected, actual) @requires_scipy def test_upsample_interpolate(self): from scipy.interpolate import interp1d + xs = np.arange(6) ys = np.arange(3) - times = pd.date_range('2000-01-01', freq='6H', periods=5) + times = pd.date_range("2000-01-01", freq="6H", periods=5) - z = np.arange(5)**2 + z = np.arange(5) ** 2 data = np.tile(z, (6, 3, 1)) - array = DataArray(data, - {'time': times, 'x': xs, 'y': ys}, - ('x', 'y', 'time')) + array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) - expected_times = times.to_series().resample('1H').asfreq().index + expected_times = times.to_series().resample("1H").asfreq().index # Split the times into equal sub-intervals to simulate the 6 hour # to 1 hour up-sampling new_times_idx = np.linspace(0, len(times) - 1, len(times) * 5) - for kind in ['linear', 'nearest', 'zero', 'slinear', 'quadratic', - 'cubic']: - actual = array.resample(time='1H').interpolate(kind) - f = interp1d(np.arange(len(times)), data, kind=kind, axis=-1, - bounds_error=True, assume_sorted=True) + for kind in ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]: + actual = array.resample(time="1H").interpolate(kind) + f = interp1d( + np.arange(len(times)), + data, + kind=kind, + axis=-1, + bounds_error=True, + assume_sorted=True, + ) expected_data = f(new_times_idx) - expected = DataArray(expected_data, - {'time': expected_times, 'x': xs, 'y': ys}, - ('x', 'y', 'time')) + expected = DataArray( + expected_data, + {"time": expected_times, "x": xs, "y": ys}, + ("x", "y", "time"), + ) # Use AllClose because there are some small differences in 
how # we upsample timeseries versus the integer indexing as I've # done here due to floating point arithmetic @@ -2748,52 +3002,61 @@ def test_upsample_interpolate(self): @requires_scipy def test_upsample_interpolate_bug_2197(self): - dates = pd.date_range('2007-02-01', '2007-03-01', freq='D') - da = xr.DataArray(np.arange(len(dates)), [('time', dates)]) - result = da.resample(time='M').interpolate('linear') - expected_times = np.array([np.datetime64('2007-02-28'), - np.datetime64('2007-03-31')]) - expected = xr.DataArray([27., np.nan], [('time', expected_times)]) + dates = pd.date_range("2007-02-01", "2007-03-01", freq="D") + da = xr.DataArray(np.arange(len(dates)), [("time", dates)]) + result = da.resample(time="M").interpolate("linear") + expected_times = np.array( + [np.datetime64("2007-02-28"), np.datetime64("2007-03-31")] + ) + expected = xr.DataArray([27.0, np.nan], [("time", expected_times)]) assert_equal(result, expected) @requires_scipy def test_upsample_interpolate_regression_1605(self): - dates = pd.date_range('2016-01-01', '2016-03-31', freq='1D') - expected = xr.DataArray(np.random.random((len(dates), 2, 3)), - dims=('time', 'x', 'y'), - coords={'time': dates}) - actual = expected.resample(time='1D').interpolate('linear') + dates = pd.date_range("2016-01-01", "2016-03-31", freq="1D") + expected = xr.DataArray( + np.random.random((len(dates), 2, 3)), + dims=("time", "x", "y"), + coords={"time": dates}, + ) + actual = expected.resample(time="1D").interpolate("linear") assert_allclose(actual, expected, rtol=1e-16) @requires_dask @requires_scipy def test_upsample_interpolate_dask(self): from scipy.interpolate import interp1d + xs = np.arange(6) ys = np.arange(3) - times = pd.date_range('2000-01-01', freq='6H', periods=5) + times = pd.date_range("2000-01-01", freq="6H", periods=5) - z = np.arange(5)**2 + z = np.arange(5) ** 2 data = np.tile(z, (6, 3, 1)) - array = DataArray(data, - {'time': times, 'x': xs, 'y': ys}, - ('x', 'y', 'time')) - chunks = {'x': 2, 'y': 1} + array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) + chunks = {"x": 2, "y": 1} - expected_times = times.to_series().resample('1H').asfreq().index + expected_times = times.to_series().resample("1H").asfreq().index # Split the times into equal sub-intervals to simulate the 6 hour # to 1 hour up-sampling new_times_idx = np.linspace(0, len(times) - 1, len(times) * 5) - for kind in ['linear', 'nearest', 'zero', 'slinear', 'quadratic', - 'cubic']: - actual = array.chunk(chunks).resample(time='1H').interpolate(kind) + for kind in ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]: + actual = array.chunk(chunks).resample(time="1H").interpolate(kind) actual = actual.compute() - f = interp1d(np.arange(len(times)), data, kind=kind, axis=-1, - bounds_error=True, assume_sorted=True) + f = interp1d( + np.arange(len(times)), + data, + kind=kind, + axis=-1, + bounds_error=True, + assume_sorted=True, + ) expected_data = f(new_times_idx) - expected = DataArray(expected_data, - {'time': expected_times, 'x': xs, 'y': ys}, - ('x', 'y', 'time')) + expected = DataArray( + expected_data, + {"time": expected_times, "x": xs, "y": ys}, + ("x", "y", "time"), + ) # Use AllClose because there are some small differences in how # we upsample timeseries versus the integer indexing as I've # done here due to floating point arithmetic @@ -2801,14 +3064,16 @@ def test_upsample_interpolate_dask(self): # Check that an error is raised if an attempt is made to interpolate # over a chunked dimension - with 
raises_regex(NotImplementedError, - 'Chunking along the dimension to be interpolated'): - array.chunk({'time': 1}).resample(time='1H').interpolate('linear') + with raises_regex( + NotImplementedError, "Chunking along the dimension to be interpolated" + ): + array.chunk({"time": 1}).resample(time="1H").interpolate("linear") def test_align(self): - array = DataArray(np.random.random((6, 8)), - coords={'x': list('abcdef')}, dims=['x', 'y']) - array1, array2 = align(array, array[:5], join='inner') + array = DataArray( + np.random.random((6, 8)), coords={"x": list("abcdef")}, dims=["x", "y"] + ) + array1, array2 = align(array, array[:5], join="inner") assert_identical(array1, array[:5]) assert_identical(array2, array[:5]) @@ -2816,30 +3081,30 @@ def test_align_dtype(self): # regression test for #264 x1 = np.arange(30) x2 = np.arange(5, 35) - a = DataArray(np.random.random((30,)).astype(np.float32), [('x', x1)]) - b = DataArray(np.random.random((30,)).astype(np.float32), [('x', x2)]) - c, d = align(a, b, join='outer') + a = DataArray(np.random.random((30,)).astype(np.float32), [("x", x1)]) + b = DataArray(np.random.random((30,)).astype(np.float32), [("x", x2)]) + c, d = align(a, b, join="outer") assert c.dtype == np.float32 def test_align_copy(self): - x = DataArray([1, 2, 3], coords=[('a', [1, 2, 3])]) - y = DataArray([1, 2], coords=[('a', [3, 1])]) + x = DataArray([1, 2, 3], coords=[("a", [1, 2, 3])]) + y = DataArray([1, 2], coords=[("a", [3, 1])]) expected_x2 = x - expected_y2 = DataArray([2, np.nan, 1], coords=[('a', [1, 2, 3])]) + expected_y2 = DataArray([2, np.nan, 1], coords=[("a", [1, 2, 3])]) - x2, y2 = align(x, y, join='outer', copy=False) + x2, y2 = align(x, y, join="outer", copy=False) assert_identical(expected_x2, x2) assert_identical(expected_y2, y2) assert source_ndarray(x2.data) is source_ndarray(x.data) - x2, y2 = align(x, y, join='outer', copy=True) + x2, y2 = align(x, y, join="outer", copy=True) assert_identical(expected_x2, x2) assert_identical(expected_y2, y2) assert source_ndarray(x2.data) is not source_ndarray(x.data) # Trivial align - 1 element - x = DataArray([1, 2, 3], coords=[('a', [1, 2, 3])]) + x = DataArray([1, 2, 3], coords=[("a", [1, 2, 3])]) x2, = align(x, copy=False) assert_identical(x, x2) assert source_ndarray(x2.data) is source_ndarray(x.data) @@ -2849,78 +3114,79 @@ def test_align_copy(self): assert source_ndarray(x2.data) is not source_ndarray(x.data) def test_align_exclude(self): - x = DataArray([[1, 2], [3, 4]], - coords=[('a', [-1, -2]), ('b', [3, 4])]) - y = DataArray([[1, 2], [3, 4]], - coords=[('a', [-1, 20]), ('b', [5, 6])]) - z = DataArray([1], dims=['a'], coords={'a': [20], 'b': 7}) - - x2, y2, z2 = align(x, y, z, join='outer', exclude=['b']) - expected_x2 = DataArray([[3, 4], [1, 2], [np.nan, np.nan]], - coords=[('a', [-2, -1, 20]), ('b', [3, 4])]) - expected_y2 = DataArray([[np.nan, np.nan], [1, 2], [3, 4]], - coords=[('a', [-2, -1, 20]), ('b', [5, 6])]) - expected_z2 = DataArray([np.nan, np.nan, 1], dims=['a'], - coords={'a': [-2, -1, 20], 'b': 7}) + x = DataArray([[1, 2], [3, 4]], coords=[("a", [-1, -2]), ("b", [3, 4])]) + y = DataArray([[1, 2], [3, 4]], coords=[("a", [-1, 20]), ("b", [5, 6])]) + z = DataArray([1], dims=["a"], coords={"a": [20], "b": 7}) + + x2, y2, z2 = align(x, y, z, join="outer", exclude=["b"]) + expected_x2 = DataArray( + [[3, 4], [1, 2], [np.nan, np.nan]], + coords=[("a", [-2, -1, 20]), ("b", [3, 4])], + ) + expected_y2 = DataArray( + [[np.nan, np.nan], [1, 2], [3, 4]], + coords=[("a", [-2, -1, 20]), ("b", [5, 6])], + ) + 
expected_z2 = DataArray( + [np.nan, np.nan, 1], dims=["a"], coords={"a": [-2, -1, 20], "b": 7} + ) assert_identical(expected_x2, x2) assert_identical(expected_y2, y2) assert_identical(expected_z2, z2) def test_align_indexes(self): - x = DataArray([1, 2, 3], coords=[('a', [-1, 10, -2])]) - y = DataArray([1, 2], coords=[('a', [-2, -1])]) + x = DataArray([1, 2, 3], coords=[("a", [-1, 10, -2])]) + y = DataArray([1, 2], coords=[("a", [-2, -1])]) - x2, y2 = align(x, y, join='outer', indexes={'a': [10, -1, -2]}) - expected_x2 = DataArray([2, 1, 3], coords=[('a', [10, -1, -2])]) - expected_y2 = DataArray([np.nan, 2, 1], coords=[('a', [10, -1, -2])]) + x2, y2 = align(x, y, join="outer", indexes={"a": [10, -1, -2]}) + expected_x2 = DataArray([2, 1, 3], coords=[("a", [10, -1, -2])]) + expected_y2 = DataArray([np.nan, 2, 1], coords=[("a", [10, -1, -2])]) assert_identical(expected_x2, x2) assert_identical(expected_y2, y2) - x2, = align(x, join='outer', indexes={'a': [-2, 7, 10, -1]}) - expected_x2 = DataArray([3, np.nan, 2, 1], - coords=[('a', [-2, 7, 10, -1])]) + x2, = align(x, join="outer", indexes={"a": [-2, 7, 10, -1]}) + expected_x2 = DataArray([3, np.nan, 2, 1], coords=[("a", [-2, 7, 10, -1])]) assert_identical(expected_x2, x2) def test_align_without_indexes_exclude(self): - arrays = [DataArray([1, 2, 3], dims=['x']), - DataArray([1, 2], dims=['x'])] - result0, result1 = align(*arrays, exclude=['x']) + arrays = [DataArray([1, 2, 3], dims=["x"]), DataArray([1, 2], dims=["x"])] + result0, result1 = align(*arrays, exclude=["x"]) assert_identical(result0, arrays[0]) assert_identical(result1, arrays[1]) def test_align_mixed_indexes(self): - array_no_coord = DataArray([1, 2], dims=['x']) - array_with_coord = DataArray([1, 2], coords=[('x', ['a', 'b'])]) + array_no_coord = DataArray([1, 2], dims=["x"]) + array_with_coord = DataArray([1, 2], coords=[("x", ["a", "b"])]) result0, result1 = align(array_no_coord, array_with_coord) assert_identical(result0, array_with_coord) assert_identical(result1, array_with_coord) - result0, result1 = align(array_no_coord, array_with_coord, - exclude=['x']) + result0, result1 = align(array_no_coord, array_with_coord, exclude=["x"]) assert_identical(result0, array_no_coord) assert_identical(result1, array_with_coord) def test_align_without_indexes_errors(self): - with raises_regex(ValueError, 'cannot be aligned'): - align(DataArray([1, 2, 3], dims=['x']), - DataArray([1, 2], dims=['x'])) + with raises_regex(ValueError, "cannot be aligned"): + align(DataArray([1, 2, 3], dims=["x"]), DataArray([1, 2], dims=["x"])) - with raises_regex(ValueError, 'cannot be aligned'): - align(DataArray([1, 2, 3], dims=['x']), - DataArray([1, 2], coords=[('x', [0, 1])])) + with raises_regex(ValueError, "cannot be aligned"): + align( + DataArray([1, 2, 3], dims=["x"]), + DataArray([1, 2], coords=[("x", [0, 1])]), + ) def test_broadcast_arrays(self): - x = DataArray([1, 2], coords=[('a', [-1, -2])], name='x') - y = DataArray([1, 2], coords=[('b', [3, 4])], name='y') + x = DataArray([1, 2], coords=[("a", [-1, -2])], name="x") + y = DataArray([1, 2], coords=[("b", [3, 4])], name="y") x2, y2 = broadcast(x, y) - expected_coords = [('a', [-1, -2]), ('b', [3, 4])] - expected_x2 = DataArray([[1, 1], [2, 2]], expected_coords, name='x') - expected_y2 = DataArray([[1, 2], [1, 2]], expected_coords, name='y') + expected_coords = [("a", [-1, -2]), ("b", [3, 4])] + expected_x2 = DataArray([[1, 1], [2, 2]], expected_coords, name="x") + expected_y2 = DataArray([[1, 2], [1, 2]], expected_coords, name="y") 
assert_identical(expected_x2, x2) assert_identical(expected_y2, y2) - x = DataArray(np.random.randn(2, 3), dims=['a', 'b']) - y = DataArray(np.random.randn(3, 2), dims=['b', 'a']) + x = DataArray(np.random.randn(2, 3), dims=["a", "b"]) + y = DataArray(np.random.randn(3, 2), dims=["b", "a"]) x2, y2 = broadcast(x, y) expected_x2 = x expected_y2 = y.T @@ -2929,13 +3195,16 @@ def test_broadcast_arrays(self): def test_broadcast_arrays_misaligned(self): # broadcast on misaligned coords must auto-align - x = DataArray([[1, 2], [3, 4]], - coords=[('a', [-1, -2]), ('b', [3, 4])]) - y = DataArray([1, 2], coords=[('a', [-1, 20])]) - expected_x2 = DataArray([[3, 4], [1, 2], [np.nan, np.nan]], - coords=[('a', [-2, -1, 20]), ('b', [3, 4])]) - expected_y2 = DataArray([[np.nan, np.nan], [1, 1], [2, 2]], - coords=[('a', [-2, -1, 20]), ('b', [3, 4])]) + x = DataArray([[1, 2], [3, 4]], coords=[("a", [-1, -2]), ("b", [3, 4])]) + y = DataArray([1, 2], coords=[("a", [-1, 20])]) + expected_x2 = DataArray( + [[3, 4], [1, 2], [np.nan, np.nan]], + coords=[("a", [-2, -1, 20]), ("b", [3, 4])], + ) + expected_y2 = DataArray( + [[np.nan, np.nan], [1, 1], [2, 2]], + coords=[("a", [-2, -1, 20]), ("b", [3, 4])], + ) x2, y2 = broadcast(x, y) assert_identical(expected_x2, x2) assert_identical(expected_y2, y2) @@ -2943,10 +3212,10 @@ def test_broadcast_arrays_misaligned(self): def test_broadcast_arrays_nocopy(self): # Test that input data is not copied over in case # no alteration is needed - x = DataArray([1, 2], coords=[('a', [-1, -2])], name='x') - y = DataArray(3, name='y') - expected_x2 = DataArray([1, 2], coords=[('a', [-1, -2])], name='x') - expected_y2 = DataArray([3, 3], coords=[('a', [-1, -2])], name='y') + x = DataArray([1, 2], coords=[("a", [-1, -2])], name="x") + y = DataArray(3, name="y") + expected_x2 = DataArray([1, 2], coords=[("a", [-1, -2])], name="x") + expected_y2 = DataArray([3, 3], coords=[("a", [-1, -2])], name="y") x2, y2 = broadcast(x, y) assert_identical(expected_x2, x2) @@ -2959,30 +3228,32 @@ def test_broadcast_arrays_nocopy(self): assert source_ndarray(x2.data) is source_ndarray(x.data) def test_broadcast_arrays_exclude(self): - x = DataArray([[1, 2], [3, 4]], - coords=[('a', [-1, -2]), ('b', [3, 4])]) - y = DataArray([1, 2], coords=[('a', [-1, 20])]) - z = DataArray(5, coords={'b': 5}) - - x2, y2, z2 = broadcast(x, y, z, exclude=['b']) - expected_x2 = DataArray([[3, 4], [1, 2], [np.nan, np.nan]], - coords=[('a', [-2, -1, 20]), ('b', [3, 4])]) - expected_y2 = DataArray([np.nan, 1, 2], coords=[('a', [-2, -1, 20])]) - expected_z2 = DataArray([5, 5, 5], dims=['a'], - coords={'a': [-2, -1, 20], 'b': 5}) + x = DataArray([[1, 2], [3, 4]], coords=[("a", [-1, -2]), ("b", [3, 4])]) + y = DataArray([1, 2], coords=[("a", [-1, 20])]) + z = DataArray(5, coords={"b": 5}) + + x2, y2, z2 = broadcast(x, y, z, exclude=["b"]) + expected_x2 = DataArray( + [[3, 4], [1, 2], [np.nan, np.nan]], + coords=[("a", [-2, -1, 20]), ("b", [3, 4])], + ) + expected_y2 = DataArray([np.nan, 1, 2], coords=[("a", [-2, -1, 20])]) + expected_z2 = DataArray( + [5, 5, 5], dims=["a"], coords={"a": [-2, -1, 20], "b": 5} + ) assert_identical(expected_x2, x2) assert_identical(expected_y2, y2) assert_identical(expected_z2, z2) def test_broadcast_coordinates(self): # regression test for GH649 - ds = Dataset({'a': (['x', 'y'], np.ones((5, 6)))}) + ds = Dataset({"a": (["x", "y"], np.ones((5, 6)))}) x_bc, y_bc, a_bc = broadcast(ds.x, ds.y, ds.a) assert_identical(ds.a, a_bc) - X, Y = np.meshgrid(np.arange(5), np.arange(6), indexing='ij') - exp_x 
= DataArray(X, dims=['x', 'y'], name='x') - exp_y = DataArray(Y, dims=['x', 'y'], name='y') + X, Y = np.meshgrid(np.arange(5), np.arange(6), indexing="ij") + exp_x = DataArray(X, dims=["x", "y"], name="x") + exp_y = DataArray(Y, dims=["x", "y"], name="y") assert_identical(exp_x, x_bc) assert_identical(exp_y, y_bc) @@ -2994,89 +3265,87 @@ def test_to_pandas(self): # 1d values = np.random.randn(3) - index = pd.Index(['a', 'b', 'c'], name='x') + index = pd.Index(["a", "b", "c"], name="x") da = DataArray(values, coords=[index]) actual = da.to_pandas() assert_array_equal(actual.values, values) assert_array_equal(actual.index, index) - assert_array_equal(actual.index.name, 'x') + assert_array_equal(actual.index.name, "x") # 2d values = np.random.randn(3, 2) - da = DataArray(values, coords=[('x', ['a', 'b', 'c']), ('y', [0, 1])], - name='foo') + da = DataArray( + values, coords=[("x", ["a", "b", "c"]), ("y", [0, 1])], name="foo" + ) actual = da.to_pandas() assert_array_equal(actual.values, values) - assert_array_equal(actual.index, ['a', 'b', 'c']) + assert_array_equal(actual.index, ["a", "b", "c"]) assert_array_equal(actual.columns, [0, 1]) # roundtrips for shape in [(3,), (3, 4), (3, 4, 5)]: - if len(shape) > 2 and not LooseVersion(pd.__version__) < '0.25.0': + if len(shape) > 2 and not LooseVersion(pd.__version__) < "0.25.0": continue - dims = list('abc')[:len(shape)] + dims = list("abc")[: len(shape)] da = DataArray(np.random.randn(*shape), dims=dims) with warnings.catch_warnings(): - warnings.filterwarnings('ignore', r'\W*Panel is deprecated') + warnings.filterwarnings("ignore", r"\W*Panel is deprecated") roundtripped = DataArray(da.to_pandas()).drop(dims) assert_identical(da, roundtripped) - with raises_regex(ValueError, 'cannot convert'): + with raises_regex(ValueError, "cannot convert"): DataArray(np.random.randn(1, 2, 3, 4, 5)).to_pandas() def test_to_dataframe(self): # regression test for #260 - arr = DataArray(np.random.randn(3, 4), - [('B', [1, 2, 3]), ('A', list('cdef'))], name='foo') + arr = DataArray( + np.random.randn(3, 4), [("B", [1, 2, 3]), ("A", list("cdef"))], name="foo" + ) expected = arr.to_series() - actual = arr.to_dataframe()['foo'] + actual = arr.to_dataframe()["foo"] assert_array_equal(expected.values, actual.values) assert_array_equal(expected.name, actual.name) assert_array_equal(expected.index.values, actual.index.values) # regression test for coords with different dimensions - arr.coords['C'] = ('B', [-1, -2, -3]) + arr.coords["C"] = ("B", [-1, -2, -3]) expected = arr.to_series().to_frame() - expected['C'] = [-1] * 4 + [-2] * 4 + [-3] * 4 - expected = expected[['C', 'foo']] + expected["C"] = [-1] * 4 + [-2] * 4 + [-3] * 4 + expected = expected[["C", "foo"]] actual = arr.to_dataframe() assert_array_equal(expected.values, actual.values) assert_array_equal(expected.columns.values, actual.columns.values) assert_array_equal(expected.index.values, actual.index.values) arr.name = None # unnamed - with raises_regex(ValueError, 'unnamed'): + with raises_regex(ValueError, "unnamed"): arr.to_dataframe() def test_to_pandas_name_matches_coordinate(self): # coordinate with same name as array - arr = DataArray([1, 2, 3], dims='x', name='x') + arr = DataArray([1, 2, 3], dims="x", name="x") series = arr.to_series() assert_array_equal([1, 2, 3], series.values) assert_array_equal([0, 1, 2], series.index.values) - assert 'x' == series.name - assert 'x' == series.index.name + assert "x" == series.name + assert "x" == series.index.name frame = arr.to_dataframe() expected = 
series.to_frame() assert expected.equals(frame) def test_to_and_from_series(self): - expected = self.dv.to_dataframe()['foo'] + expected = self.dv.to_dataframe()["foo"] actual = self.dv.to_series() assert_array_equal(expected.values, actual.values) assert_array_equal(expected.index.values, actual.index.values) - assert 'foo' == actual.name + assert "foo" == actual.name # test roundtrip - assert_identical( - self.dv, - DataArray.from_series(actual).drop(['x', 'y'])) + assert_identical(self.dv, DataArray.from_series(actual).drop(["x", "y"])) # test name is None actual.name = None expected_da = self.dv.rename(None) - assert_identical( - expected_da, - DataArray.from_series(actual).drop(['x', 'y'])) + assert_identical(expected_da, DataArray.from_series(actual).drop(["x", "y"])) def test_to_and_from_empty_series(self): # GH697 @@ -3089,23 +3358,24 @@ def test_to_and_from_empty_series(self): def test_series_categorical_index(self): # regression test for GH700 - if not hasattr(pd, 'CategoricalIndex'): - pytest.skip('requires pandas with CategoricalIndex') + if not hasattr(pd, "CategoricalIndex"): + pytest.skip("requires pandas with CategoricalIndex") - s = pd.Series(np.arange(5), index=pd.CategoricalIndex(list('aabbc'))) + s = pd.Series(np.arange(5), index=pd.CategoricalIndex(list("aabbc"))) arr = DataArray(s) assert "'a'" in repr(arr) # should not error def test_to_and_from_dict(self): - array = DataArray(np.random.randn(2, 3), {'x': ['a', 'b']}, ['x', 'y'], - name='foo') - expected = {'name': 'foo', - 'dims': ('x', 'y'), - 'data': array.values.tolist(), - 'attrs': {}, - 'coords': {'x': {'dims': ('x',), - 'data': ['a', 'b'], - 'attrs': {}}}} + array = DataArray( + np.random.randn(2, 3), {"x": ["a", "b"]}, ["x", "y"], name="foo" + ) + expected = { + "name": "foo", + "dims": ("x", "y"), + "data": array.values.tolist(), + "attrs": {}, + "coords": {"x": {"dims": ("x",), "data": ["a", "b"], "attrs": {}}}, + } actual = array.to_dict() # check that they are identical @@ -3115,77 +3385,82 @@ def test_to_and_from_dict(self): assert_identical(array, DataArray.from_dict(actual)) # a more bare bones representation still roundtrips - d = {'name': 'foo', - 'dims': ('x', 'y'), - 'data': array.values.tolist(), - 'coords': {'x': {'dims': 'x', 'data': ['a', 'b']}}} + d = { + "name": "foo", + "dims": ("x", "y"), + "data": array.values.tolist(), + "coords": {"x": {"dims": "x", "data": ["a", "b"]}}, + } assert_identical(array, DataArray.from_dict(d)) # and the most bare bones representation still roundtrips - d = {'name': 'foo', 'dims': ('x', 'y'), 'data': array.values} - assert_identical(array.drop('x'), DataArray.from_dict(d)) + d = {"name": "foo", "dims": ("x", "y"), "data": array.values} + assert_identical(array.drop("x"), DataArray.from_dict(d)) # missing a dims in the coords - d = {'dims': ('x', 'y'), - 'data': array.values, - 'coords': {'x': {'data': ['a', 'b']}}} + d = { + "dims": ("x", "y"), + "data": array.values, + "coords": {"x": {"data": ["a", "b"]}}, + } with raises_regex( - ValueError, - "cannot convert dict when coords are missing the key 'dims'"): + ValueError, "cannot convert dict when coords are missing the key 'dims'" + ): DataArray.from_dict(d) # this one is missing some necessary information - d = {'dims': ('t')} - with raises_regex( - ValueError, "cannot convert dict without the key 'data'"): + d = {"dims": ("t")} + with raises_regex(ValueError, "cannot convert dict without the key 'data'"): DataArray.from_dict(d) # check the data=False option expected_no_data = expected.copy() - del 
expected_no_data['data'] - del expected_no_data['coords']['x']['data'] - endiantype = 'U1' - expected_no_data['coords']['x'].update({'dtype': endiantype, - 'shape': (2,)}) - expected_no_data.update({'dtype': 'float64', 'shape': (2, 3)}) + del expected_no_data["data"] + del expected_no_data["coords"]["x"]["data"] + endiantype = "U1" + expected_no_data["coords"]["x"].update({"dtype": endiantype, "shape": (2,)}) + expected_no_data.update({"dtype": "float64", "shape": (2, 3)}) actual_no_data = array.to_dict(data=False) assert expected_no_data == actual_no_data def test_to_and_from_dict_with_time_dim(self): x = np.random.randn(10, 3) - t = pd.date_range('20130101', periods=10) + t = pd.date_range("20130101", periods=10) lat = [77.7, 83.2, 76] - da = DataArray(x, {'t': t, 'lat': lat}, dims=['t', 'lat']) + da = DataArray(x, {"t": t, "lat": lat}, dims=["t", "lat"]) roundtripped = DataArray.from_dict(da.to_dict()) assert_identical(da, roundtripped) def test_to_and_from_dict_with_nan_nat(self): y = np.random.randn(10, 3) y[2] = np.nan - t = pd.Series(pd.date_range('20130101', periods=10)) + t = pd.Series(pd.date_range("20130101", periods=10)) t[2] = np.nan lat = [77.7, 83.2, 76] - da = DataArray(y, {'t': t, 'lat': lat}, dims=['t', 'lat']) + da = DataArray(y, {"t": t, "lat": lat}, dims=["t", "lat"]) roundtripped = DataArray.from_dict(da.to_dict()) assert_identical(da, roundtripped) def test_to_dict_with_numpy_attrs(self): # this doesn't need to roundtrip x = np.random.randn(10, 3) - t = list('abcdefghij') + t = list("abcdefghij") lat = [77.7, 83.2, 76] - attrs = {'created': np.float64(1998), - 'coords': np.array([37, -110.1, 100]), - 'maintainer': 'bar'} - da = DataArray(x, {'t': t, 'lat': lat}, dims=['t', 'lat'], - attrs=attrs) - expected_attrs = {'created': attrs['created'].item(), - 'coords': attrs['coords'].tolist(), - 'maintainer': 'bar'} + attrs = { + "created": np.float64(1998), + "coords": np.array([37, -110.1, 100]), + "maintainer": "bar", + } + da = DataArray(x, {"t": t, "lat": lat}, dims=["t", "lat"], attrs=attrs) + expected_attrs = { + "created": attrs["created"].item(), + "coords": attrs["coords"].tolist(), + "maintainer": "bar", + } actual = da.to_dict() # check that they are identical - assert expected_attrs == actual['attrs'] + assert expected_attrs == actual["attrs"] def test_to_masked_array(self): rs = np.random.RandomState(44) @@ -3210,14 +3485,14 @@ def test_to_masked_array(self): # Test that copy=False gives access to values masked_array = da.to_masked_array(copy=False) - masked_array[0, 0] = 10. - assert masked_array[0, 0] == 10. - assert da[0, 0].values == 10. 
+ masked_array[0, 0] = 10.0 + assert masked_array[0, 0] == 10.0 + assert da[0, 0].values == 10.0 assert masked_array.base is da.values assert isinstance(masked_array, np.ma.MaskedArray) # Test with some odd arrays - for v in [4, np.nan, True, '4', 'four']: + for v in [4, np.nan, True, "4", "four"]: da = DataArray(v) ma = da.to_masked_array() assert isinstance(ma, np.ma.MaskedArray) @@ -3231,15 +3506,21 @@ def test_to_masked_array(self): def test_to_and_from_cdms2_classic(self): """Classic with 1D axes""" - pytest.importorskip('cdms2') + pytest.importorskip("cdms2") original = DataArray( np.arange(6).reshape(2, 3), - [('distance', [-2, 2], {'units': 'meters'}), - ('time', pd.date_range('2000-01-01', periods=3))], - name='foo', attrs={'baz': 123}) - expected_coords = [IndexVariable('distance', [-2, 2]), - IndexVariable('time', [0, 1, 2])] + [ + ("distance", [-2, 2], {"units": "meters"}), + ("time", pd.date_range("2000-01-01", periods=3)), + ], + name="foo", + attrs={"baz": 123}, + ) + expected_coords = [ + IndexVariable("distance", [-2, 2]), + IndexVariable("time", [0, 1, 2]), + ] actual = original.to_cdms2() assert_array_equal(actual.asma(), original) assert actual.id == original.name @@ -3247,11 +3528,11 @@ def test_to_and_from_cdms2_classic(self): for axis, coord in zip(actual.getAxisList(), expected_coords): assert axis.id == coord.name assert_array_equal(axis, coord.values) - assert actual.baz == original.attrs['baz'] + assert actual.baz == original.attrs["baz"] component_times = actual.getAxis(1).asComponentTime() assert len(component_times) == 3 - assert str(component_times[0]) == '2000-1-1 0:0:0.0' + assert str(component_times[0]) == "2000-1-1 0:0:0.0" roundtripped = DataArray.from_cdms2(actual) assert_identical(original, roundtripped) @@ -3260,194 +3541,194 @@ def test_to_and_from_cdms2_classic(self): assert original.dims == back.dims assert original.coords.keys() == back.coords.keys() for coord_name in original.coords.keys(): - assert_array_equal(original.coords[coord_name], - back.coords[coord_name]) + assert_array_equal(original.coords[coord_name], back.coords[coord_name]) def test_to_and_from_cdms2_sgrid(self): """Curvilinear (structured) grid The rectangular grid case is covered by the classic case """ - pytest.importorskip('cdms2') + pytest.importorskip("cdms2") lonlat = np.mgrid[:3, :4] - lon = DataArray(lonlat[1], dims=['y', 'x'], name='lon') - lat = DataArray(lonlat[0], dims=['y', 'x'], name='lat') - x = DataArray(np.arange(lon.shape[1]), dims=['x'], name='x') - y = DataArray(np.arange(lon.shape[0]), dims=['y'], name='y') - original = DataArray(lonlat.sum(axis=0), dims=['y', 'x'], - coords=OrderedDict(x=x, y=y, lon=lon, lat=lat), - name='sst') + lon = DataArray(lonlat[1], dims=["y", "x"], name="lon") + lat = DataArray(lonlat[0], dims=["y", "x"], name="lat") + x = DataArray(np.arange(lon.shape[1]), dims=["x"], name="x") + y = DataArray(np.arange(lon.shape[0]), dims=["y"], name="y") + original = DataArray( + lonlat.sum(axis=0), + dims=["y", "x"], + coords=OrderedDict(x=x, y=y, lon=lon, lat=lat), + name="sst", + ) actual = original.to_cdms2() assert tuple(actual.getAxisIds()) == original.dims - assert_array_equal(original.coords['lon'], - actual.getLongitude().asma()) - assert_array_equal(original.coords['lat'], - actual.getLatitude().asma()) + assert_array_equal(original.coords["lon"], actual.getLongitude().asma()) + assert_array_equal(original.coords["lat"], actual.getLatitude().asma()) back = from_cdms2(actual) assert original.dims == back.dims assert 
set(original.coords.keys()) == set(back.coords.keys()) - assert_array_equal(original.coords['lat'], back.coords['lat']) - assert_array_equal(original.coords['lon'], back.coords['lon']) + assert_array_equal(original.coords["lat"], back.coords["lat"]) + assert_array_equal(original.coords["lon"], back.coords["lon"]) def test_to_and_from_cdms2_ugrid(self): """Unstructured grid""" - pytest.importorskip('cdms2') + pytest.importorskip("cdms2") - lon = DataArray(np.random.uniform(size=5), dims=['cell'], name='lon') - lat = DataArray(np.random.uniform(size=5), dims=['cell'], name='lat') - cell = DataArray(np.arange(5), dims=['cell'], name='cell') - original = DataArray(np.arange(5), dims=['cell'], - coords={'lon': lon, 'lat': lat, 'cell': cell}) + lon = DataArray(np.random.uniform(size=5), dims=["cell"], name="lon") + lat = DataArray(np.random.uniform(size=5), dims=["cell"], name="lat") + cell = DataArray(np.arange(5), dims=["cell"], name="cell") + original = DataArray( + np.arange(5), dims=["cell"], coords={"lon": lon, "lat": lat, "cell": cell} + ) actual = original.to_cdms2() assert tuple(actual.getAxisIds()) == original.dims - assert_array_equal(original.coords['lon'], - actual.getLongitude().getValue()) - assert_array_equal(original.coords['lat'], - actual.getLatitude().getValue()) + assert_array_equal(original.coords["lon"], actual.getLongitude().getValue()) + assert_array_equal(original.coords["lat"], actual.getLatitude().getValue()) back = from_cdms2(actual) assert set(original.dims) == set(back.dims) assert set(original.coords.keys()) == set(back.coords.keys()) - assert_array_equal(original.coords['lat'], back.coords['lat']) - assert_array_equal(original.coords['lon'], back.coords['lon']) + assert_array_equal(original.coords["lat"], back.coords["lat"]) + assert_array_equal(original.coords["lon"], back.coords["lon"]) def test_to_dataset_whole(self): - unnamed = DataArray([1, 2], dims='x') - with raises_regex(ValueError, 'unable to convert unnamed'): + unnamed = DataArray([1, 2], dims="x") + with raises_regex(ValueError, "unable to convert unnamed"): unnamed.to_dataset() - actual = unnamed.to_dataset(name='foo') - expected = Dataset({'foo': ('x', [1, 2])}) + actual = unnamed.to_dataset(name="foo") + expected = Dataset({"foo": ("x", [1, 2])}) assert_identical(expected, actual) - named = DataArray([1, 2], dims='x', name='foo') + named = DataArray([1, 2], dims="x", name="foo") actual = named.to_dataset() - expected = Dataset({'foo': ('x', [1, 2])}) + expected = Dataset({"foo": ("x", [1, 2])}) assert_identical(expected, actual) - expected = Dataset({'bar': ('x', [1, 2])}) + expected = Dataset({"bar": ("x", [1, 2])}) with pytest.warns(FutureWarning): - actual = named.to_dataset('bar') + actual = named.to_dataset("bar") assert_identical(expected, actual) def test_to_dataset_split(self): - array = DataArray([1, 2, 3], coords=[('x', list('abc'))], - attrs={'a': 1}) - expected = Dataset(OrderedDict([('a', 1), ('b', 2), ('c', 3)]), - attrs={'a': 1}) - actual = array.to_dataset('x') + array = DataArray([1, 2, 3], coords=[("x", list("abc"))], attrs={"a": 1}) + expected = Dataset(OrderedDict([("a", 1), ("b", 2), ("c", 3)]), attrs={"a": 1}) + actual = array.to_dataset("x") assert_identical(expected, actual) with pytest.raises(TypeError): - array.to_dataset('x', name='foo') + array.to_dataset("x", name="foo") - roundtripped = actual.to_array(dim='x') + roundtripped = actual.to_array(dim="x") assert_identical(array, roundtripped) - array = DataArray([1, 2, 3], dims='x') + array = DataArray([1, 2, 3], 
dims="x") expected = Dataset(OrderedDict([(0, 1), (1, 2), (2, 3)])) - actual = array.to_dataset('x') + actual = array.to_dataset("x") assert_identical(expected, actual) def test_to_dataset_retains_keys(self): # use dates as convenient non-str objects. Not a specific date test import datetime + dates = [datetime.date(2000, 1, d) for d in range(1, 4)] - array = DataArray([1, 2, 3], coords=[('x', dates)], - attrs={'a': 1}) + array = DataArray([1, 2, 3], coords=[("x", dates)], attrs={"a": 1}) # convert to dateset and back again - result = array.to_dataset('x').to_array(dim='x') + result = array.to_dataset("x").to_array(dim="x") assert_equal(array, result) def test__title_for_slice(self): - array = DataArray(np.ones((4, 3, 2)), dims=['a', 'b', 'c'], - coords={'a': range(4), 'b': range(3), 'c': range(2)}) - assert '' == array._title_for_slice() - assert 'c = 0' == array.isel(c=0)._title_for_slice() + array = DataArray( + np.ones((4, 3, 2)), + dims=["a", "b", "c"], + coords={"a": range(4), "b": range(3), "c": range(2)}, + ) + assert "" == array._title_for_slice() + assert "c = 0" == array.isel(c=0)._title_for_slice() title = array.isel(b=1, c=0)._title_for_slice() - assert 'b = 1, c = 0' == title or 'c = 0, b = 1' == title + assert "b = 1, c = 0" == title or "c = 0, b = 1" == title - a2 = DataArray(np.ones((4, 1)), dims=['a', 'b']) - assert '' == a2._title_for_slice() + a2 = DataArray(np.ones((4, 1)), dims=["a", "b"]) + assert "" == a2._title_for_slice() def test__title_for_slice_truncate(self): array = DataArray(np.ones(4)) - array.coords['a'] = 'a' * 100 - array.coords['b'] = 'b' * 100 + array.coords["a"] = "a" * 100 + array.coords["b"] = "b" * 100 nchar = 80 title = array._title_for_slice(truncate=nchar) assert nchar == len(title) - assert title.endswith('...') + assert title.endswith("...") def test_dataarray_diff_n1(self): - da = DataArray(np.random.randn(3, 4), dims=['x', 'y']) - actual = da.diff('y') - expected = DataArray(np.diff(da.values, axis=1), dims=['x', 'y']) + da = DataArray(np.random.randn(3, 4), dims=["x", "y"]) + actual = da.diff("y") + expected = DataArray(np.diff(da.values, axis=1), dims=["x", "y"]) assert_equal(expected, actual) def test_coordinate_diff(self): # regression test for GH634 - arr = DataArray(range(0, 20, 2), dims=['lon'], coords=[range(10)]) - lon = arr.coords['lon'] - expected = DataArray([1] * 9, dims=['lon'], coords=[range(1, 10)], - name='lon') - actual = lon.diff('lon') + arr = DataArray(range(0, 20, 2), dims=["lon"], coords=[range(10)]) + lon = arr.coords["lon"] + expected = DataArray([1] * 9, dims=["lon"], coords=[range(1, 10)], name="lon") + actual = lon.diff("lon") assert_equal(expected, actual) - @pytest.mark.parametrize('offset', [-5, 0, 1, 2]) - @pytest.mark.parametrize('fill_value, dtype', - [(2, int), (dtypes.NA, float)]) + @pytest.mark.parametrize("offset", [-5, 0, 1, 2]) + @pytest.mark.parametrize("fill_value, dtype", [(2, int), (dtypes.NA, float)]) def test_shift(self, offset, fill_value, dtype): - arr = DataArray([1, 2, 3], dims='x') + arr = DataArray([1, 2, 3], dims="x") actual = arr.shift(x=1, fill_value=fill_value) if fill_value == dtypes.NA: # if we supply the default, we expect the missing value for a # float array fill_value = np.nan - expected = DataArray([fill_value, 1, 2], dims='x') + expected = DataArray([fill_value, 1, 2], dims="x") assert_identical(expected, actual) assert actual.dtype == dtype - arr = DataArray([1, 2, 3], [('x', ['a', 'b', 'c'])]) + arr = DataArray([1, 2, 3], [("x", ["a", "b", "c"])]) expected = 
DataArray(arr.to_pandas().shift(offset)) actual = arr.shift(x=offset) assert_identical(expected, actual) def test_roll_coords(self): - arr = DataArray([1, 2, 3], coords={'x': range(3)}, dims='x') + arr = DataArray([1, 2, 3], coords={"x": range(3)}, dims="x") actual = arr.roll(x=1, roll_coords=True) - expected = DataArray([3, 1, 2], coords=[('x', [2, 0, 1])]) + expected = DataArray([3, 1, 2], coords=[("x", [2, 0, 1])]) assert_identical(expected, actual) def test_roll_no_coords(self): - arr = DataArray([1, 2, 3], coords={'x': range(3)}, dims='x') + arr = DataArray([1, 2, 3], coords={"x": range(3)}, dims="x") actual = arr.roll(x=1, roll_coords=False) - expected = DataArray([3, 1, 2], coords=[('x', [0, 1, 2])]) + expected = DataArray([3, 1, 2], coords=[("x", [0, 1, 2])]) assert_identical(expected, actual) def test_roll_coords_none(self): - arr = DataArray([1, 2, 3], coords={'x': range(3)}, dims='x') + arr = DataArray([1, 2, 3], coords={"x": range(3)}, dims="x") with pytest.warns(FutureWarning): actual = arr.roll(x=1, roll_coords=None) - expected = DataArray([3, 1, 2], coords=[('x', [2, 0, 1])]) + expected = DataArray([3, 1, 2], coords=[("x", [2, 0, 1])]) assert_identical(expected, actual) def test_copy_with_data(self): - orig = DataArray(np.random.random(size=(2, 2)), - dims=('x', 'y'), - attrs={'attr1': 'value1'}, - coords={'x': [4, 3]}, - name='helloworld') + orig = DataArray( + np.random.random(size=(2, 2)), + dims=("x", "y"), + attrs={"attr1": "value1"}, + coords={"x": [4, 3]}, + name="helloworld", + ) new_data = np.arange(4).reshape(2, 2) actual = orig.copy(data=new_data) expected = orig.copy() @@ -3455,30 +3736,47 @@ def test_copy_with_data(self): assert_identical(expected, actual) @pytest.mark.xfail(raises=AssertionError) - @pytest.mark.parametrize('deep, expected_orig', [ - [True, - xr.DataArray(xr.IndexVariable('a', np.array([1, 2])), - coords={'a': [1, 2]}, dims=['a'])], - [False, - xr.DataArray(xr.IndexVariable('a', np.array([999, 2])), - coords={'a': [999, 2]}, dims=['a'])]]) + @pytest.mark.parametrize( + "deep, expected_orig", + [ + [ + True, + xr.DataArray( + xr.IndexVariable("a", np.array([1, 2])), + coords={"a": [1, 2]}, + dims=["a"], + ), + ], + [ + False, + xr.DataArray( + xr.IndexVariable("a", np.array([999, 2])), + coords={"a": [999, 2]}, + dims=["a"], + ), + ], + ], + ) def test_copy_coords(self, deep, expected_orig): """The test fails for the shallow copy, and apparently only on Windows for some reason. 
In windows coords seem to be immutable unless it's one dataarray deep copied from another.""" da = xr.DataArray( np.ones([2, 2, 2]), - coords={'a': [1, 2], 'b': ['x', 'y'], 'c': [0, 1]}, - dims=['a', 'b', 'c']) + coords={"a": [1, 2], "b": ["x", "y"], "c": [0, 1]}, + dims=["a", "b", "c"], + ) da_cp = da.copy(deep) - da_cp['a'].data[0] = 999 + da_cp["a"].data[0] = 999 expected_cp = xr.DataArray( - xr.IndexVariable('a', np.array([999, 2])), - coords={'a': [999, 2]}, dims=['a']) - assert_identical(da_cp['a'], expected_cp) + xr.IndexVariable("a", np.array([999, 2])), + coords={"a": [999, 2]}, + dims=["a"], + ) + assert_identical(da_cp["a"], expected_cp) - assert_identical(da['a'], expected_orig) + assert_identical(da["a"], expected_orig) def test_real_and_imag(self): array = DataArray(1 + 2j) @@ -3486,21 +3784,23 @@ def test_real_and_imag(self): assert_identical(array.imag, DataArray(2)) def test_setattr_raises(self): - array = DataArray(0, coords={'scalar': 1}, attrs={'foo': 'bar'}) - with raises_regex(AttributeError, 'cannot set attr'): + array = DataArray(0, coords={"scalar": 1}, attrs={"foo": "bar"}) + with raises_regex(AttributeError, "cannot set attr"): array.scalar = 2 - with raises_regex(AttributeError, 'cannot set attr'): + with raises_regex(AttributeError, "cannot set attr"): array.foo = 2 - with raises_regex(AttributeError, 'cannot set attr'): + with raises_regex(AttributeError, "cannot set attr"): array.other = 2 def test_full_like(self): # For more thorough tests, see test_variable.py - da = DataArray(np.random.random(size=(2, 2)), - dims=('x', 'y'), - attrs={'attr1': 'value1'}, - coords={'x': [4, 3]}, - name='helloworld') + da = DataArray( + np.random.random(size=(2, 2)), + dims=("x", "y"), + attrs={"attr1": "value1"}, + coords={"x": [4, 3]}, + name="helloworld", + ) actual = full_like(da, 2) expect = da.copy(deep=True) @@ -3518,35 +3818,34 @@ def test_dot(self): y = np.linspace(-3, 3, 5) z = range(4) da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 4)) - da = DataArray(da_vals, coords=[x, y, z], dims=['x', 'y', 'z']) + da = DataArray(da_vals, coords=[x, y, z], dims=["x", "y", "z"]) dm_vals = range(4) - dm = DataArray(dm_vals, coords=[z], dims=['z']) + dm = DataArray(dm_vals, coords=[z], dims=["z"]) # nd dot 1d actual = da.dot(dm) expected_vals = np.tensordot(da_vals, dm_vals, [2, 0]) - expected = DataArray(expected_vals, coords=[x, y], dims=['x', 'y']) + expected = DataArray(expected_vals, coords=[x, y], dims=["x", "y"]) assert_equal(expected, actual) # all shared dims actual = da.dot(da) - expected_vals = np.tensordot(da_vals, da_vals, - axes=([0, 1, 2], [0, 1, 2])) + expected_vals = np.tensordot(da_vals, da_vals, axes=([0, 1, 2], [0, 1, 2])) expected = DataArray(expected_vals) assert_equal(expected, actual) # multiple shared dims dm_vals = np.arange(20 * 5 * 4).reshape((20, 5, 4)) j = np.linspace(-3, 3, 20) - dm = DataArray(dm_vals, coords=[j, y, z], dims=['j', 'y', 'z']) + dm = DataArray(dm_vals, coords=[j, y, z], dims=["j", "y", "z"]) actual = da.dot(dm) expected_vals = np.tensordot(da_vals, dm_vals, axes=([1, 2], [1, 2])) - expected = DataArray(expected_vals, coords=[x, j], dims=['x', 'j']) + expected = DataArray(expected_vals, coords=[x, j], dims=["x", "j"]) assert_equal(expected, actual) with pytest.raises(NotImplementedError): - da.dot(dm.to_dataset(name='dm')) + da.dot(dm.to_dataset(name="dm")) with pytest.raises(TypeError): da.dot(dm.values) @@ -3557,58 +3856,66 @@ def test_matmul(self): y = np.linspace(-3, 3, 5) z = range(4) da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 
4)) - da = DataArray(da_vals, coords=[x, y, z], dims=['x', 'y', 'z']) + da = DataArray(da_vals, coords=[x, y, z], dims=["x", "y", "z"]) result = da @ da expected = da.dot(da) assert_identical(result, expected) def test_binary_op_join_setting(self): - dim = 'x' + dim = "x" align_type = "outer" coords_l, coords_r = [0, 1, 2], [1, 2, 3] missing_3 = xr.DataArray(coords_l, [(dim, coords_l)]) missing_0 = xr.DataArray(coords_r, [(dim, coords_r)]) with xr.set_options(arithmetic_join=align_type): actual = missing_0 + missing_3 - missing_0_aligned, missing_3_aligned = xr.align(missing_0, - missing_3, - join=align_type) + missing_0_aligned, missing_3_aligned = xr.align( + missing_0, missing_3, join=align_type + ) expected = xr.DataArray([np.nan, 2, 4, np.nan], [(dim, [0, 1, 2, 3])]) assert_equal(actual, expected) def test_combine_first(self): - ar0 = DataArray([[0, 0], [0, 0]], [('x', ['a', 'b']), ('y', [-1, 0])]) - ar1 = DataArray([[1, 1], [1, 1]], [('x', ['b', 'c']), ('y', [0, 1])]) - ar2 = DataArray([2], [('x', ['d'])]) + ar0 = DataArray([[0, 0], [0, 0]], [("x", ["a", "b"]), ("y", [-1, 0])]) + ar1 = DataArray([[1, 1], [1, 1]], [("x", ["b", "c"]), ("y", [0, 1])]) + ar2 = DataArray([2], [("x", ["d"])]) actual = ar0.combine_first(ar1) - expected = DataArray([[0, 0, np.nan], [0, 0, 1], [np.nan, 1, 1]], - [('x', ['a', 'b', 'c']), ('y', [-1, 0, 1])]) + expected = DataArray( + [[0, 0, np.nan], [0, 0, 1], [np.nan, 1, 1]], + [("x", ["a", "b", "c"]), ("y", [-1, 0, 1])], + ) assert_equal(actual, expected) actual = ar1.combine_first(ar0) - expected = DataArray([[0, 0, np.nan], [0, 1, 1], [np.nan, 1, 1]], - [('x', ['a', 'b', 'c']), ('y', [-1, 0, 1])]) + expected = DataArray( + [[0, 0, np.nan], [0, 1, 1], [np.nan, 1, 1]], + [("x", ["a", "b", "c"]), ("y", [-1, 0, 1])], + ) assert_equal(actual, expected) actual = ar0.combine_first(ar2) - expected = DataArray([[0, 0], [0, 0], [2, 2]], - [('x', ['a', 'b', 'd']), ('y', [-1, 0])]) + expected = DataArray( + [[0, 0], [0, 0], [2, 2]], [("x", ["a", "b", "d"]), ("y", [-1, 0])] + ) assert_equal(actual, expected) def test_sortby(self): - da = DataArray([[1, 2], [3, 4], [5, 6]], - [('x', ['c', 'b', 'a']), ('y', [1, 0])]) + da = DataArray( + [[1, 2], [3, 4], [5, 6]], [("x", ["c", "b", "a"]), ("y", [1, 0])] + ) - sorted1d = DataArray([[5, 6], [3, 4], [1, 2]], - [('x', ['a', 'b', 'c']), ('y', [1, 0])]) + sorted1d = DataArray( + [[5, 6], [3, 4], [1, 2]], [("x", ["a", "b", "c"]), ("y", [1, 0])] + ) - sorted2d = DataArray([[6, 5], [4, 3], [2, 1]], - [('x', ['a', 'b', 'c']), ('y', [0, 1])]) + sorted2d = DataArray( + [[6, 5], [4, 3], [2, 1]], [("x", ["a", "b", "c"]), ("y", [0, 1])] + ) expected = sorted1d - dax = DataArray([100, 99, 98], [('x', ['c', 'b', 'a'])]) + dax = DataArray([100, 99, 98], [("x", ["c", "b", "a"])]) actual = da.sortby(dax) assert_equal(actual, expected) @@ -3617,23 +3924,23 @@ def test_sortby(self): assert_equal(actual, da) # test alignment (fills in nan for 'c') - dax_short = DataArray([98, 97], [('x', ['b', 'a'])]) + dax_short = DataArray([98, 97], [("x", ["b", "a"])]) actual = da.sortby(dax_short) assert_equal(actual, expected) # test multi-dim sort by 1D dataarray values expected = sorted2d - dax = DataArray([100, 99, 98], [('x', ['c', 'b', 'a'])]) - day = DataArray([90, 80], [('y', [1, 0])]) + dax = DataArray([100, 99, 98], [("x", ["c", "b", "a"])]) + day = DataArray([90, 80], [("y", [1, 0])]) actual = da.sortby([day, dax]) assert_equal(actual, expected) expected = sorted1d - actual = da.sortby('x') + actual = da.sortby("x") assert_equal(actual, expected) 
expected = sorted2d - actual = da.sortby(['x', 'y']) + actual = da.sortby(["x", "y"]) assert_equal(actual, expected) @requires_bottleneck @@ -3642,162 +3949,157 @@ def test_rank(self): ar = DataArray([[3, 4, np.nan, 1]]) expect_0 = DataArray([[1, 1, np.nan, 1]]) expect_1 = DataArray([[2, 3, np.nan, 1]]) - assert_equal(ar.rank('dim_0'), expect_0) - assert_equal(ar.rank('dim_1'), expect_1) + assert_equal(ar.rank("dim_0"), expect_0) + assert_equal(ar.rank("dim_1"), expect_1) # int x = DataArray([3, 2, 1]) - assert_equal(x.rank('dim_0'), x) + assert_equal(x.rank("dim_0"), x) # str - y = DataArray(['c', 'b', 'a']) - assert_equal(y.rank('dim_0'), x) + y = DataArray(["c", "b", "a"]) + assert_equal(y.rank("dim_0"), x) - x = DataArray([3.0, 1.0, np.nan, 2.0, 4.0], dims=('z',)) - y = DataArray([0.75, 0.25, np.nan, 0.5, 1.0], dims=('z',)) - assert_equal(y.rank('z', pct=True), y) + x = DataArray([3.0, 1.0, np.nan, 2.0, 4.0], dims=("z",)) + y = DataArray([0.75, 0.25, np.nan, 0.5, 1.0], dims=("z",)) + assert_equal(y.rank("z", pct=True), y) @pytest.fixture(params=[1]) def da(request): if request.param == 1: - times = pd.date_range('2000-01-01', freq='1D', periods=21) + times = pd.date_range("2000-01-01", freq="1D", periods=21) values = np.random.random((3, 21, 4)) - da = DataArray(values, dims=('a', 'time', 'x')) - da['time'] = times + da = DataArray(values, dims=("a", "time", "x")) + da["time"] = times return da if request.param == 2: - return DataArray( - [0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], - dims='time') + return DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time") - if request.param == 'repeating_ints': + if request.param == "repeating_ints": return DataArray( np.tile(np.arange(12), 5).reshape(5, 4, 3), - coords={'x': list('abc'), - 'y': list('defg')}, - dims=list('zyx') + coords={"x": list("abc"), "y": list("defg")}, + dims=list("zyx"), ) @pytest.fixture def da_dask(seed=123): - pytest.importorskip('dask.array') + pytest.importorskip("dask.array") rs = np.random.RandomState(seed) - times = pd.date_range('2000-01-01', freq='1D', periods=21) + times = pd.date_range("2000-01-01", freq="1D", periods=21) values = rs.normal(size=(1, 21, 1)) - da = DataArray(values, dims=('a', 'time', 'x')).chunk({'time': 7}) - da['time'] = times + da = DataArray(values, dims=("a", "time", "x")).chunk({"time": 7}) + da["time"] = times return da -@pytest.mark.parametrize('da', ('repeating_ints', ), indirect=True) +@pytest.mark.parametrize("da", ("repeating_ints",), indirect=True) def test_isin(da): expected = DataArray( np.asarray([[0, 0, 0], [1, 0, 0]]), - dims=list('yx'), - coords={'x': list('abc'), - 'y': list('de')}, - ).astype('bool') + dims=list("yx"), + coords={"x": list("abc"), "y": list("de")}, + ).astype("bool") - result = da.isin([3]).sel(y=list('de'), z=0) + result = da.isin([3]).sel(y=list("de"), z=0) assert_equal(result, expected) expected = DataArray( np.asarray([[0, 0, 1], [1, 0, 0]]), - dims=list('yx'), - coords={'x': list('abc'), - 'y': list('de')}, - ).astype('bool') - result = da.isin([2, 3]).sel(y=list('de'), z=0) + dims=list("yx"), + coords={"x": list("abc"), "y": list("de")}, + ).astype("bool") + result = da.isin([2, 3]).sel(y=list("de"), z=0) assert_equal(result, expected) -@pytest.mark.parametrize('da', (1, 2), indirect=True) +@pytest.mark.parametrize("da", (1, 2), indirect=True) def test_rolling_iter(da): rolling_obj = da.rolling(time=7) with warnings.catch_warnings(): - warnings.filterwarnings('ignore', 'Mean of empty slice') + warnings.filterwarnings("ignore", 
"Mean of empty slice") rolling_obj_mean = rolling_obj.mean() - assert len(rolling_obj.window_labels) == len(da['time']) - assert_identical(rolling_obj.window_labels, da['time']) + assert len(rolling_obj.window_labels) == len(da["time"]) + assert_identical(rolling_obj.window_labels, da["time"]) for i, (label, window_da) in enumerate(rolling_obj): - assert label == da['time'].isel(time=i) + assert label == da["time"].isel(time=i) with warnings.catch_warnings(): - warnings.filterwarnings('ignore', 'Mean of empty slice') + warnings.filterwarnings("ignore", "Mean of empty slice") actual = rolling_obj_mean.isel(time=i) - expected = window_da.mean('time') + expected = window_da.mean("time") # TODO add assert_allclose_with_nan, which compares nan position # as well as the closeness of the values. assert_array_equal(actual.isnull(), expected.isnull()) if (~actual.isnull()).sum() > 0: - np.allclose(actual.values[actual.values.nonzero()], - expected.values[expected.values.nonzero()]) + np.allclose( + actual.values[actual.values.nonzero()], + expected.values[expected.values.nonzero()], + ) def test_rolling_doc(da): rolling_obj = da.rolling(time=7) # argument substitution worked - assert '`mean`' in rolling_obj.mean.__doc__ + assert "`mean`" in rolling_obj.mean.__doc__ def test_rolling_properties(da): rolling_obj = da.rolling(time=4) - assert rolling_obj.obj.get_axis_num('time') == 1 + assert rolling_obj.obj.get_axis_num("time") == 1 # catching invalid args - with pytest.raises(ValueError, match='exactly one dim/window should'): + with pytest.raises(ValueError, match="exactly one dim/window should"): da.rolling(time=7, x=2) - with pytest.raises(ValueError, match='window must be > 0'): + with pytest.raises(ValueError, match="window must be > 0"): da.rolling(time=-2) - with pytest.raises( - ValueError, match='min_periods must be greater than zero' - ): + with pytest.raises(ValueError, match="min_periods must be greater than zero"): da.rolling(time=2, min_periods=0) -@pytest.mark.parametrize('name', ('sum', 'mean', 'std', 'min', 'max', - 'median')) -@pytest.mark.parametrize('center', (True, False, None)) -@pytest.mark.parametrize('min_periods', (1, None)) +@pytest.mark.parametrize("name", ("sum", "mean", "std", "min", "max", "median")) +@pytest.mark.parametrize("center", (True, False, None)) +@pytest.mark.parametrize("min_periods", (1, None)) def test_rolling_wrapped_bottleneck(da, name, center, min_periods): - bn = pytest.importorskip('bottleneck', minversion="1.1") + bn = pytest.importorskip("bottleneck", minversion="1.1") # Test all bottleneck functions rolling_obj = da.rolling(time=7, min_periods=min_periods) - func_name = 'move_{}'.format(name) + func_name = "move_{}".format(name) actual = getattr(rolling_obj, name)() - expected = getattr(bn, func_name)(da.values, window=7, axis=1, - min_count=min_periods) + expected = getattr(bn, func_name)( + da.values, window=7, axis=1, min_count=min_periods + ) assert_array_equal(actual.values, expected) # Test center rolling_obj = da.rolling(time=7, center=center) - actual = getattr(rolling_obj, name)()['time'] - assert_equal(actual, da['time']) + actual = getattr(rolling_obj, name)()["time"] + assert_equal(actual, da["time"]) -@pytest.mark.parametrize('name', ('mean', 'count')) -@pytest.mark.parametrize('center', (True, False, None)) -@pytest.mark.parametrize('min_periods', (1, None)) -@pytest.mark.parametrize('window', (7, 8)) +@pytest.mark.parametrize("name", ("mean", "count")) +@pytest.mark.parametrize("center", (True, False, None)) 
+@pytest.mark.parametrize("min_periods", (1, None)) +@pytest.mark.parametrize("window", (7, 8)) def test_rolling_wrapped_dask(da_dask, name, center, min_periods, window): - pytest.importorskip('dask.array') + pytest.importorskip("dask.array") # dask version - rolling_obj = da_dask.rolling(time=window, min_periods=min_periods, - center=center) + rolling_obj = da_dask.rolling(time=window, min_periods=min_periods, center=center) actual = getattr(rolling_obj, name)().load() # numpy version - rolling_obj = da_dask.load().rolling(time=window, min_periods=min_periods, - center=center) + rolling_obj = da_dask.load().rolling( + time=window, min_periods=min_periods, center=center + ) expected = getattr(rolling_obj, name)() # using all-close because rolling over ghost cells introduces some @@ -3805,27 +4107,29 @@ def test_rolling_wrapped_dask(da_dask, name, center, min_periods, window): assert_allclose(actual, expected) # with zero chunked array GH:2113 - rolling_obj = da_dask.chunk().rolling(time=window, min_periods=min_periods, - center=center) + rolling_obj = da_dask.chunk().rolling( + time=window, min_periods=min_periods, center=center + ) actual = getattr(rolling_obj, name)().load() assert_allclose(actual, expected) -@pytest.mark.parametrize('center', (True, None)) +@pytest.mark.parametrize("center", (True, None)) def test_rolling_wrapped_dask_nochunk(center): # GH:2113 - pytest.importorskip('dask.array') + pytest.importorskip("dask.array") - da_day_clim = xr.DataArray(np.arange(1, 367), - coords=[np.arange(1, 367)], dims='dayofyear') + da_day_clim = xr.DataArray( + np.arange(1, 367), coords=[np.arange(1, 367)], dims="dayofyear" + ) expected = da_day_clim.rolling(dayofyear=31, center=center).mean() actual = da_day_clim.chunk().rolling(dayofyear=31, center=center).mean() assert_allclose(actual, expected) -@pytest.mark.parametrize('center', (True, False)) -@pytest.mark.parametrize('min_periods', (None, 1, 2, 3)) -@pytest.mark.parametrize('window', (1, 2, 3, 4)) +@pytest.mark.parametrize("center", (True, False)) +@pytest.mark.parametrize("min_periods", (None, 1, 2, 3)) +@pytest.mark.parametrize("window", (1, 2, 3, 4)) def test_rolling_pandas_compat(center, window, min_periods): s = pd.Series(np.arange(10)) da = DataArray.from_series(s) @@ -3833,21 +4137,20 @@ def test_rolling_pandas_compat(center, window, min_periods): if min_periods is not None and window < min_periods: min_periods = window - s_rolling = s.rolling(window, center=center, - min_periods=min_periods).mean() - da_rolling = da.rolling(index=window, center=center, - min_periods=min_periods).mean() - da_rolling_np = da.rolling(index=window, center=center, - min_periods=min_periods).reduce(np.nanmean) + s_rolling = s.rolling(window, center=center, min_periods=min_periods).mean() + da_rolling = da.rolling(index=window, center=center, min_periods=min_periods).mean() + da_rolling_np = da.rolling( + index=window, center=center, min_periods=min_periods + ).reduce(np.nanmean) np.testing.assert_allclose(s_rolling.values, da_rolling.values) - np.testing.assert_allclose(s_rolling.index, da_rolling['index']) + np.testing.assert_allclose(s_rolling.index, da_rolling["index"]) np.testing.assert_allclose(s_rolling.values, da_rolling_np.values) - np.testing.assert_allclose(s_rolling.index, da_rolling_np['index']) + np.testing.assert_allclose(s_rolling.index, da_rolling_np["index"]) -@pytest.mark.parametrize('center', (True, False)) -@pytest.mark.parametrize('window', (1, 2, 3, 4)) +@pytest.mark.parametrize("center", (True, False)) 
+@pytest.mark.parametrize("window", (1, 2, 3, 4)) def test_rolling_construct(center, window): s = pd.Series(np.arange(10)) da = DataArray.from_series(s) @@ -3855,28 +4158,28 @@ def test_rolling_construct(center, window): s_rolling = s.rolling(window, center=center, min_periods=1).mean() da_rolling = da.rolling(index=window, center=center, min_periods=1) - da_rolling_mean = da_rolling.construct('window').mean('window') + da_rolling_mean = da_rolling.construct("window").mean("window") np.testing.assert_allclose(s_rolling.values, da_rolling_mean.values) - np.testing.assert_allclose(s_rolling.index, da_rolling_mean['index']) + np.testing.assert_allclose(s_rolling.index, da_rolling_mean["index"]) # with stride - da_rolling_mean = da_rolling.construct('window', - stride=2).mean('window') + da_rolling_mean = da_rolling.construct("window", stride=2).mean("window") np.testing.assert_allclose(s_rolling.values[::2], da_rolling_mean.values) - np.testing.assert_allclose(s_rolling.index[::2], da_rolling_mean['index']) + np.testing.assert_allclose(s_rolling.index[::2], da_rolling_mean["index"]) # with fill_value - da_rolling_mean = da_rolling.construct( - 'window', stride=2, fill_value=0.0).mean('window') + da_rolling_mean = da_rolling.construct("window", stride=2, fill_value=0.0).mean( + "window" + ) assert da_rolling_mean.isnull().sum() == 0 assert (da_rolling_mean == 0.0).sum() >= 0 -@pytest.mark.parametrize('da', (1, 2), indirect=True) -@pytest.mark.parametrize('center', (True, False)) -@pytest.mark.parametrize('min_periods', (None, 1, 2, 3)) -@pytest.mark.parametrize('window', (1, 2, 3, 4)) -@pytest.mark.parametrize('name', ('sum', 'mean', 'std', 'max')) +@pytest.mark.parametrize("da", (1, 2), indirect=True) +@pytest.mark.parametrize("center", (True, False)) +@pytest.mark.parametrize("min_periods", (None, 1, 2, 3)) +@pytest.mark.parametrize("window", (1, 2, 3, 4)) +@pytest.mark.parametrize("name", ("sum", "mean", "std", "max")) def test_rolling_reduce(da, center, min_periods, window, name): if min_periods is not None and window < min_periods: @@ -3886,33 +4189,32 @@ def test_rolling_reduce(da, center, min_periods, window, name): # this causes all nan slices window = 2 - rolling_obj = da.rolling(time=window, center=center, - min_periods=min_periods) + rolling_obj = da.rolling(time=window, center=center, min_periods=min_periods) # add nan prefix to numpy methods to get similar # behavior as bottleneck - actual = rolling_obj.reduce(getattr(np, 'nan%s' % name)) + actual = rolling_obj.reduce(getattr(np, "nan%s" % name)) expected = getattr(rolling_obj, name)() assert_allclose(actual, expected) assert actual.dims == expected.dims @requires_np113 -@pytest.mark.parametrize('center', (True, False)) -@pytest.mark.parametrize('min_periods', (None, 1, 2, 3)) -@pytest.mark.parametrize('window', (1, 2, 3, 4)) -@pytest.mark.parametrize('name', ('sum', 'max')) +@pytest.mark.parametrize("center", (True, False)) +@pytest.mark.parametrize("min_periods", (None, 1, 2, 3)) +@pytest.mark.parametrize("window", (1, 2, 3, 4)) +@pytest.mark.parametrize("name", ("sum", "max")) def test_rolling_reduce_nonnumeric(center, min_periods, window, name): - da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], - dims='time').isnull() + da = DataArray( + [0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time" + ).isnull() if min_periods is not None and window < min_periods: min_periods = window - rolling_obj = da.rolling(time=window, center=center, - min_periods=min_periods) + rolling_obj = da.rolling(time=window, 
center=center, min_periods=min_periods) # add nan prefix to numpy methods to get similar behavior as bottleneck - actual = rolling_obj.reduce(getattr(np, 'nan%s' % name)) + actual = rolling_obj.reduce(getattr(np, "nan%s" % name)) expected = getattr(rolling_obj, name)() assert_allclose(actual, expected) assert actual.dims == expected.dims @@ -3920,25 +4222,39 @@ def test_rolling_reduce_nonnumeric(center, min_periods, window, name): def test_rolling_count_correct(): - da = DataArray( - [0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims='time') + da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time") - kwargs = [{'time': 11, 'min_periods': 1}, - {'time': 11, 'min_periods': None}, - {'time': 7, 'min_periods': 2}] - expecteds = [DataArray( - [1, 1, 2, 3, 3, 4, 5, 6, 6, 7, 8], dims='time'), - DataArray( - [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan], dims='time'), + kwargs = [ + {"time": 11, "min_periods": 1}, + {"time": 11, "min_periods": None}, + {"time": 7, "min_periods": 2}, + ] + expecteds = [ + DataArray([1, 1, 2, 3, 3, 4, 5, 6, 6, 7, 8], dims="time"), DataArray( - [np.nan, np.nan, 2, 3, 3, 4, 5, 5, 5, 5, 5], dims='time')] + [ + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + ], + dims="time", + ), + DataArray([np.nan, np.nan, 2, 3, 3, 4, 5, 5, 5, 5, 5], dims="time"), + ] for kwarg, expected in zip(kwargs, expecteds): result = da.rolling(**kwarg).count() assert_equal(result, expected) - result = da.to_dataset(name='var1').rolling(**kwarg).count()['var1'] + result = da.to_dataset(name="var1").rolling(**kwarg).count()["var1"] assert_equal(result, expected) @@ -3949,12 +4265,12 @@ def test_raise_no_warning_for_nan_in_binary_ops(): def test_name_in_masking(): - name = 'RingoStarr' - da = xr.DataArray(range(10), coords=[('x', range(10))], name=name) + name = "RingoStarr" + da = xr.DataArray(range(10), coords=[("x", range(10))], name=name) assert da.where(da > 5).name == name - assert da.where((da > 5).rename('YokoOno')).name == name + assert da.where((da > 5).rename("YokoOno")).name == name assert da.where(da > 5, drop=True).name == name - assert da.where((da > 5).rename('YokoOno'), drop=True).name == name + assert da.where((da > 5).rename("YokoOno"), drop=True).name == name class TestIrisConversion: @@ -3965,56 +4281,68 @@ def test_to_and_from_iris(self): # to iris coord_dict = OrderedDict() - coord_dict['distance'] = ('distance', [-2, 2], {'units': 'meters'}) - coord_dict['time'] = ('time', pd.date_range('2000-01-01', periods=3)) - coord_dict['height'] = 10 - coord_dict['distance2'] = ('distance', [0, 1], {'foo': 'bar'}) - coord_dict['time2'] = (('distance', 'time'), [[0, 1, 2], [2, 3, 4]]) - - original = DataArray(np.arange(6, dtype='float').reshape(2, 3), - coord_dict, name='Temperature', - attrs={'baz': 123, 'units': 'Kelvin', - 'standard_name': 'fire_temperature', - 'long_name': 'Fire Temperature'}, - dims=('distance', 'time')) + coord_dict["distance"] = ("distance", [-2, 2], {"units": "meters"}) + coord_dict["time"] = ("time", pd.date_range("2000-01-01", periods=3)) + coord_dict["height"] = 10 + coord_dict["distance2"] = ("distance", [0, 1], {"foo": "bar"}) + coord_dict["time2"] = (("distance", "time"), [[0, 1, 2], [2, 3, 4]]) + + original = DataArray( + np.arange(6, dtype="float").reshape(2, 3), + coord_dict, + name="Temperature", + attrs={ + "baz": 123, + "units": "Kelvin", + "standard_name": "fire_temperature", + "long_name": "Fire Temperature", + }, + 
dims=("distance", "time"), + ) # Set a bad value to test the masking logic original.data[0, 2] = np.NaN - original.attrs['cell_methods'] = \ - 'height: mean (comment: A cell method)' + original.attrs["cell_methods"] = "height: mean (comment: A cell method)" actual = original.to_iris() assert_array_equal(actual.data, original.data) assert actual.var_name == original.name assert tuple(d.var_name for d in actual.dim_coords) == original.dims - assert (actual.cell_methods == (iris.coords.CellMethod( - method='mean', - coords=('height', ), - intervals=(), - comments=('A cell method', )), )) + assert actual.cell_methods == ( + iris.coords.CellMethod( + method="mean", + coords=("height",), + intervals=(), + comments=("A cell method",), + ), + ) for coord, orginal_key in zip((actual.coords()), original.coords): original_coord = original.coords[orginal_key] assert coord.var_name == original_coord.name assert_array_equal( - coord.points, CFDatetimeCoder().encode(original_coord).values) - assert (actual.coord_dims(coord) - == original.get_axis_num( - original.coords[coord.var_name].dims)) - - assert (actual.coord('distance2').attributes['foo'] - == original.coords['distance2'].attrs['foo']) - assert (actual.coord('distance').units - == cf_units.Unit(original.coords['distance'].units)) - assert actual.attributes['baz'] == original.attrs['baz'] - assert actual.standard_name == original.attrs['standard_name'] + coord.points, CFDatetimeCoder().encode(original_coord).values + ) + assert actual.coord_dims(coord) == original.get_axis_num( + original.coords[coord.var_name].dims + ) + + assert ( + actual.coord("distance2").attributes["foo"] + == original.coords["distance2"].attrs["foo"] + ) + assert actual.coord("distance").units == cf_units.Unit( + original.coords["distance"].units + ) + assert actual.attributes["baz"] == original.attrs["baz"] + assert actual.standard_name == original.attrs["standard_name"] roundtripped = DataArray.from_iris(actual) assert_identical(original, roundtripped) - actual.remove_coord('time') + actual.remove_coord("time") auto_time_dimension = DataArray.from_iris(actual) - assert auto_time_dimension.dims == ('distance', 'dim_1') + assert auto_time_dimension.dims == ("distance", "dim_1") @requires_iris @requires_dask @@ -4024,109 +4352,142 @@ def test_to_and_from_iris_dask(self): import cf_units # iris requirement coord_dict = OrderedDict() - coord_dict['distance'] = ('distance', [-2, 2], {'units': 'meters'}) - coord_dict['time'] = ('time', pd.date_range('2000-01-01', periods=3)) - coord_dict['height'] = 10 - coord_dict['distance2'] = ('distance', [0, 1], {'foo': 'bar'}) - coord_dict['time2'] = (('distance', 'time'), [[0, 1, 2], [2, 3, 4]]) + coord_dict["distance"] = ("distance", [-2, 2], {"units": "meters"}) + coord_dict["time"] = ("time", pd.date_range("2000-01-01", periods=3)) + coord_dict["height"] = 10 + coord_dict["distance2"] = ("distance", [0, 1], {"foo": "bar"}) + coord_dict["time2"] = (("distance", "time"), [[0, 1, 2], [2, 3, 4]]) original = DataArray( - da.from_array(np.arange(-1, 5, dtype='float').reshape(2, 3), 3), + da.from_array(np.arange(-1, 5, dtype="float").reshape(2, 3), 3), coord_dict, - name='Temperature', - attrs=dict(baz=123, units='Kelvin', - standard_name='fire_temperature', - long_name='Fire Temperature'), - dims=('distance', 'time')) + name="Temperature", + attrs=dict( + baz=123, + units="Kelvin", + standard_name="fire_temperature", + long_name="Fire Temperature", + ), + dims=("distance", "time"), + ) # Set a bad value to test the masking logic 
original.data = da.ma.masked_less(original.data, 0) - original.attrs['cell_methods'] = \ - 'height: mean (comment: A cell method)' + original.attrs["cell_methods"] = "height: mean (comment: A cell method)" actual = original.to_iris() # Be careful not to trigger the loading of the iris data - actual_data = actual.core_data() if \ - hasattr(actual, 'core_data') else actual.data + actual_data = ( + actual.core_data() if hasattr(actual, "core_data") else actual.data + ) assert_array_equal(actual_data, original.data) assert actual.var_name == original.name assert tuple(d.var_name for d in actual.dim_coords) == original.dims - assert (actual.cell_methods == (iris.coords.CellMethod( - method='mean', - coords=('height', ), - intervals=(), - comments=('A cell method', )), )) + assert actual.cell_methods == ( + iris.coords.CellMethod( + method="mean", + coords=("height",), + intervals=(), + comments=("A cell method",), + ), + ) for coord, orginal_key in zip((actual.coords()), original.coords): original_coord = original.coords[orginal_key] assert coord.var_name == original_coord.name assert_array_equal( - coord.points, CFDatetimeCoder().encode(original_coord).values) - assert (actual.coord_dims(coord) - == original.get_axis_num( - original.coords[coord.var_name].dims)) - - assert (actual.coord('distance2').attributes['foo'] == original.coords[ - 'distance2'].attrs['foo']) - assert (actual.coord('distance').units - == cf_units.Unit(original.coords['distance'].units)) - assert actual.attributes['baz'] == original.attrs['baz'] - assert actual.standard_name == original.attrs['standard_name'] + coord.points, CFDatetimeCoder().encode(original_coord).values + ) + assert actual.coord_dims(coord) == original.get_axis_num( + original.coords[coord.var_name].dims + ) + + assert ( + actual.coord("distance2").attributes["foo"] + == original.coords["distance2"].attrs["foo"] + ) + assert actual.coord("distance").units == cf_units.Unit( + original.coords["distance"].units + ) + assert actual.attributes["baz"] == original.attrs["baz"] + assert actual.standard_name == original.attrs["standard_name"] roundtripped = DataArray.from_iris(actual) assert_identical(original, roundtripped) # If the Iris version supports it then we should have a dask array # at each stage of the conversion - if hasattr(actual, 'core_data'): + if hasattr(actual, "core_data"): assert isinstance(original.data, type(actual.core_data())) assert isinstance(original.data, type(roundtripped.data)) - actual.remove_coord('time') + actual.remove_coord("time") auto_time_dimension = DataArray.from_iris(actual) - assert auto_time_dimension.dims == ('distance', 'dim_1') + assert auto_time_dimension.dims == ("distance", "dim_1") @requires_iris - @pytest.mark.parametrize('var_name, std_name, long_name, name, attrs', [ - ('var_name', 'height', 'Height', - 'var_name', {'standard_name': 'height', 'long_name': 'Height'}), - (None, 'height', 'Height', - 'height', {'standard_name': 'height', 'long_name': 'Height'}), - (None, None, 'Height', - 'Height', {'long_name': 'Height'}), - (None, None, None, - None, {}), - ]) - def test_da_name_from_cube(self, std_name, long_name, var_name, name, - attrs): + @pytest.mark.parametrize( + "var_name, std_name, long_name, name, attrs", + [ + ( + "var_name", + "height", + "Height", + "var_name", + {"standard_name": "height", "long_name": "Height"}, + ), + ( + None, + "height", + "Height", + "height", + {"standard_name": "height", "long_name": "Height"}, + ), + (None, None, "Height", "Height", {"long_name": "Height"}), + (None, None, 
None, None, {}), + ], + ) + def test_da_name_from_cube(self, std_name, long_name, var_name, name, attrs): from iris.cube import Cube data = [] - cube = Cube(data, var_name=var_name, standard_name=std_name, - long_name=long_name) + cube = Cube( + data, var_name=var_name, standard_name=std_name, long_name=long_name + ) result = xr.DataArray.from_iris(cube) expected = xr.DataArray(data, name=name, attrs=attrs) xr.testing.assert_identical(result, expected) @requires_iris - @pytest.mark.parametrize('var_name, std_name, long_name, name, attrs', [ - ('var_name', 'height', 'Height', - 'var_name', {'standard_name': 'height', 'long_name': 'Height'}), - (None, 'height', 'Height', - 'height', {'standard_name': 'height', 'long_name': 'Height'}), - (None, None, 'Height', - 'Height', {'long_name': 'Height'}), - (None, None, None, - 'unknown', {}), - ]) - def test_da_coord_name_from_cube(self, std_name, long_name, var_name, - name, attrs): + @pytest.mark.parametrize( + "var_name, std_name, long_name, name, attrs", + [ + ( + "var_name", + "height", + "Height", + "var_name", + {"standard_name": "height", "long_name": "Height"}, + ), + ( + None, + "height", + "Height", + "height", + {"standard_name": "height", "long_name": "Height"}, + ), + (None, None, "Height", "Height", {"long_name": "Height"}), + (None, None, None, "unknown", {}), + ], + ) + def test_da_coord_name_from_cube(self, std_name, long_name, var_name, name, attrs): from iris.cube import Cube from iris.coords import DimCoord - latitude = DimCoord([-90, 0, 90], standard_name=std_name, - var_name=var_name, long_name=long_name) + latitude = DimCoord( + [-90, 0, 90], standard_name=std_name, var_name=var_name, long_name=long_name + ) data = [0, 0, 0] cube = Cube(data, dim_coords_and_dims=[(latitude, 0)]) result = xr.DataArray.from_iris(cube) @@ -4142,40 +4503,38 @@ def test_prevent_duplicate_coord_names(self): # name resolution order a valid iris Cube with coords that have the # same var_name would lead to duplicate dimension names in the # DataArray - longitude = DimCoord([0, 360], standard_name='longitude', - var_name='duplicate') - latitude = DimCoord([-90, 0, 90], standard_name='latitude', - var_name='duplicate') + longitude = DimCoord([0, 360], standard_name="longitude", var_name="duplicate") + latitude = DimCoord( + [-90, 0, 90], standard_name="latitude", var_name="duplicate" + ) data = [[0, 0, 0], [0, 0, 0]] cube = Cube(data, dim_coords_and_dims=[(longitude, 0), (latitude, 1)]) with pytest.raises(ValueError): xr.DataArray.from_iris(cube) @requires_iris - @pytest.mark.parametrize('coord_values', [ - ['IA', 'IL', 'IN'], # non-numeric values - [0, 2, 1], # non-monotonic values - ]) + @pytest.mark.parametrize( + "coord_values", + [["IA", "IL", "IN"], [0, 2, 1]], # non-numeric values # non-monotonic values + ) def test_fallback_to_iris_AuxCoord(self, coord_values): from iris.cube import Cube from iris.coords import AuxCoord data = [0, 0, 0] - da = xr.DataArray(data, coords=[coord_values], dims=['space']) + da = xr.DataArray(data, coords=[coord_values], dims=["space"]) result = xr.DataArray.to_iris(da) - expected = Cube(data, aux_coords_and_dims=[ - (AuxCoord(coord_values, var_name='space'), 0)]) + expected = Cube( + data, aux_coords_and_dims=[(AuxCoord(coord_values, var_name="space"), 0)] + ) assert result == expected @requires_numbagg -@pytest.mark.parametrize('dim', ['time', 'x']) -@pytest.mark.parametrize('window_type, window', [ - ['span', 5], - ['alpha', 0.5], - ['com', 0.5], - ['halflife', 5], -]) +@pytest.mark.parametrize("dim", ["time", 
"x"]) +@pytest.mark.parametrize( + "window_type, window", [["span", 5], ["alpha", 0.5], ["com", 0.5], ["halflife", 5]] +) def test_rolling_exp(da, dim, window_type, window): da = da.isel(a=0) da = da.where(da > 0.2) @@ -4184,12 +4543,11 @@ def test_rolling_exp(da, dim, window_type, window): assert isinstance(result, DataArray) pandas_array = da.to_pandas() - assert pandas_array.index.name == 'time' - if dim == 'x': + assert pandas_array.index.name == "time" + if dim == "x": pandas_array = pandas_array.T - expected = ( - xr.DataArray(pandas_array.ewm(**{window_type: window}).mean()) - .transpose(*da.dims) + expected = xr.DataArray(pandas_array.ewm(**{window_type: window}).mean()).transpose( + *da.dims ) assert_allclose(expected.variable, result.variable) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 78891045bae..75325a77b36 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -13,17 +13,40 @@ import xarray as xr from xarray import ( - ALL_DIMS, DataArray, Dataset, IndexVariable, MergeError, Variable, align, - backends, broadcast, open_dataset, set_options) + ALL_DIMS, + DataArray, + Dataset, + IndexVariable, + MergeError, + Variable, + align, + backends, + broadcast, + open_dataset, + set_options, +) from xarray.core import dtypes, indexing, npcompat, utils from xarray.core.common import duck_array_ops, full_like from xarray.core.pycompat import integer_types from . import ( - LooseVersion, InaccessibleArray, UnexpectedDataAccess, assert_allclose, - assert_array_equal, assert_equal, assert_identical, has_cftime, has_dask, - raises_regex, requires_bottleneck, requires_cftime, requires_dask, - requires_numbagg, requires_scipy, source_ndarray) + LooseVersion, + InaccessibleArray, + UnexpectedDataAccess, + assert_allclose, + assert_array_equal, + assert_equal, + assert_identical, + has_cftime, + has_dask, + raises_regex, + requires_bottleneck, + requires_cftime, + requires_dask, + requires_numbagg, + requires_scipy, + source_ndarray, +) try: import dask.array as da @@ -33,21 +56,25 @@ def create_test_data(seed=None): rs = np.random.RandomState(seed) - _vars = {'var1': ['dim1', 'dim2'], - 'var2': ['dim1', 'dim2'], - 'var3': ['dim3', 'dim1']} - _dims = {'dim1': 8, 'dim2': 9, 'dim3': 10} + _vars = { + "var1": ["dim1", "dim2"], + "var2": ["dim1", "dim2"], + "var3": ["dim3", "dim1"], + } + _dims = {"dim1": 8, "dim2": 9, "dim3": 10} obj = Dataset() - obj['time'] = ('time', pd.date_range('2000-01-01', periods=20)) - obj['dim2'] = ('dim2', 0.5 * np.arange(_dims['dim2'])) - obj['dim3'] = ('dim3', list('abcdefghij')) + obj["time"] = ("time", pd.date_range("2000-01-01", periods=20)) + obj["dim2"] = ("dim2", 0.5 * np.arange(_dims["dim2"])) + obj["dim3"] = ("dim3", list("abcdefghij")) for v, dims in sorted(_vars.items()): data = rs.normal(size=tuple(_dims[d] for d in dims)) - obj[v] = (dims, data, {'foo': 'variable'}) - obj.coords['numbers'] = ('dim3', np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], - dtype='int64')) - obj.encoding = {'foo': 'bar'} + obj[v] = (dims, data, {"foo": "variable"}) + obj.coords["numbers"] = ( + "dim3", + np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64"), + ) + obj.encoding = {"foo": "bar"} assert all(obj.data.flags.writeable for obj in obj.variables.values()) return obj @@ -59,60 +86,67 @@ def create_append_test_data(seed=None): lon = [0, 1, 2] nt1 = 3 nt2 = 2 - time1 = pd.date_range('2000-01-01', periods=nt1) - time2 = pd.date_range('2000-02-01', periods=nt2) + time1 = pd.date_range("2000-01-01", periods=nt1) + time2 = 
pd.date_range("2000-02-01", periods=nt2) string_var = np.array(["ae", "bc", "df"], dtype=object) - string_var_to_append = np.array(['asdf', 'asdfg'], dtype=object) + string_var_to_append = np.array(["asdf", "asdfg"], dtype=object) unicode_var = ["áó", "áó", "áó"] ds = xr.Dataset( data_vars={ - 'da': xr.DataArray(rs.rand(3, 3, nt1), coords=[lat, lon, time1], - dims=['lat', 'lon', 'time']), - 'string_var': xr.DataArray(string_var, coords=[time1], - dims=['time']), - 'unicode_var': xr.DataArray(unicode_var, coords=[time1], - dims=['time']).astype(np.unicode_) + "da": xr.DataArray( + rs.rand(3, 3, nt1), + coords=[lat, lon, time1], + dims=["lat", "lon", "time"], + ), + "string_var": xr.DataArray(string_var, coords=[time1], dims=["time"]), + "unicode_var": xr.DataArray( + unicode_var, coords=[time1], dims=["time"] + ).astype(np.unicode_), } ) ds_to_append = xr.Dataset( data_vars={ - 'da': xr.DataArray(rs.rand(3, 3, nt2), coords=[lat, lon, time2], - dims=['lat', 'lon', 'time']), - 'string_var': xr.DataArray(string_var_to_append, coords=[time2], - dims=['time']), - 'unicode_var': xr.DataArray(unicode_var[:nt2], coords=[time2], - dims=['time']).astype(np.unicode_) + "da": xr.DataArray( + rs.rand(3, 3, nt2), + coords=[lat, lon, time2], + dims=["lat", "lon", "time"], + ), + "string_var": xr.DataArray( + string_var_to_append, coords=[time2], dims=["time"] + ), + "unicode_var": xr.DataArray( + unicode_var[:nt2], coords=[time2], dims=["time"] + ).astype(np.unicode_), } ) ds_with_new_var = xr.Dataset( data_vars={ - 'new_var': xr.DataArray( + "new_var": xr.DataArray( rs.rand(3, 3, nt1 + nt2), coords=[lat, lon, time1.append(time2)], - dims=['lat', 'lon', 'time'] - ), + dims=["lat", "lon", "time"], + ) } ) assert all(objp.data.flags.writeable for objp in ds.variables.values()) - assert all( - objp.data.flags.writeable for objp in ds_to_append.variables.values() - ) + assert all(objp.data.flags.writeable for objp in ds_to_append.variables.values()) return ds, ds_to_append, ds_with_new_var def create_test_multiindex(): - mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2]], - names=('level_1', 'level_2')) - return Dataset({}, {'x': mindex}) + mindex = pd.MultiIndex.from_product( + [["a", "b"], [1, 2]], names=("level_1", "level_2") + ) + return Dataset({}, {"x": mindex}) def create_test_stacked_array(): - x = DataArray(pd.Index(np.r_[:10], name='x')) - y = DataArray(pd.Index(np.r_[:20], name='y')) + x = DataArray(pd.Index(np.r_[:10], name="x")) + y = DataArray(pd.Index(np.r_[:20], name="y")) a = x * y b = x * y * y return a, b @@ -133,21 +167,20 @@ def get_variables(self): def lazy_inaccessible(k, v): if k in self._indexvars: return v - data = indexing.LazilyOuterIndexedArray( - InaccessibleArray(v.values)) + data = indexing.LazilyOuterIndexedArray(InaccessibleArray(v.values)) return Variable(v.dims, data, v.attrs) - return { - k: lazy_inaccessible(k, v) - for k, v in self._variables.items() - } + + return {k: lazy_inaccessible(k, v) for k, v in self._variables.items()} class TestDataset: def test_repr(self): data = create_test_data(seed=123) - data.attrs['foo'] = 'bar' + data.attrs["foo"] = "bar" # need to insert str dtype at runtime to handle both Python 2 & 3 - expected = dedent("""\ + expected = ( + dedent( + """\ Dimensions: (dim1: 8, dim2: 9, dim3: 10, time: 20) Coordinates: @@ -161,42 +194,50 @@ def test_repr(self): var2 (dim1, dim2) float64 1.162 -1.097 -2.123 ... 0.1302 1.267 0.3328 var3 (dim3, dim1) float64 0.5565 -0.2121 0.4563 ... 
-0.2452 -0.3616 Attributes: - foo: bar""") % data['dim3'].dtype # noqa: E501 - actual = '\n'.join(x.rstrip() for x in repr(data).split('\n')) + foo: bar""" + ) + % data["dim3"].dtype + ) # noqa: E501 + actual = "\n".join(x.rstrip() for x in repr(data).split("\n")) print(actual) assert expected == actual with set_options(display_width=100): - max_len = max(map(len, repr(data).split('\n'))) + max_len = max(map(len, repr(data).split("\n"))) assert 90 < max_len < 100 - expected = dedent("""\ + expected = dedent( + """\ Dimensions: () Data variables: - *empty*""") - actual = '\n'.join(x.rstrip() for x in repr(Dataset()).split('\n')) + *empty*""" + ) + actual = "\n".join(x.rstrip() for x in repr(Dataset()).split("\n")) print(actual) assert expected == actual # verify that ... doesn't appear for scalar coordinates - data = Dataset({'foo': ('x', np.ones(10))}).mean() - expected = dedent("""\ + data = Dataset({"foo": ("x", np.ones(10))}).mean() + expected = dedent( + """\ Dimensions: () Data variables: - foo float64 1.0""") - actual = '\n'.join(x.rstrip() for x in repr(data).split('\n')) + foo float64 1.0""" + ) + actual = "\n".join(x.rstrip() for x in repr(data).split("\n")) print(actual) assert expected == actual # verify long attributes are truncated - data = Dataset(attrs={'foo': 'bar' * 1000}) + data = Dataset(attrs={"foo": "bar" * 1000}) assert len(repr(data)) < 1000 def test_repr_multiindex(self): data = create_test_multiindex() - expected = dedent("""\ + expected = dedent( + """\ Dimensions: (x: 4) Coordinates: @@ -204,17 +245,19 @@ def test_repr_multiindex(self): - level_1 (x) object 'a' 'a' 'b' 'b' - level_2 (x) int64 1 2 1 2 Data variables: - *empty*""") - actual = '\n'.join(x.rstrip() for x in repr(data).split('\n')) + *empty*""" + ) + actual = "\n".join(x.rstrip() for x in repr(data).split("\n")) print(actual) assert expected == actual # verify that long level names are not truncated mindex = pd.MultiIndex.from_product( - [['a', 'b'], [1, 2]], - names=('a_quite_long_level_name', 'level_2')) - data = Dataset({}, {'x': mindex}) - expected = dedent("""\ + [["a", "b"], [1, 2]], names=("a_quite_long_level_name", "level_2") + ) + data = Dataset({}, {"x": mindex}) + expected = dedent( + """\ Dimensions: (x: 4) Coordinates: @@ -222,26 +265,27 @@ def test_repr_multiindex(self): - a_quite_long_level_name (x) object 'a' 'a' 'b' 'b' - level_2 (x) int64 1 2 1 2 Data variables: - *empty*""") - actual = '\n'.join(x.rstrip() for x in repr(data).split('\n')) + *empty*""" + ) + actual = "\n".join(x.rstrip() for x in repr(data).split("\n")) print(actual) assert expected == actual def test_repr_period_index(self): data = create_test_data(seed=456) - data.coords['time'] = pd.period_range( - '2000-01-01', periods=20, freq='B') + data.coords["time"] = pd.period_range("2000-01-01", periods=20, freq="B") # check that creating the repr doesn't raise an error #GH645 repr(data) def test_unicode_data(self): # regression test for GH834 - data = Dataset({'foø': ['ba®']}, attrs={'å': '∑'}) + data = Dataset({"foø": ["ba®"]}, attrs={"å": "∑"}) repr(data) # should not raise - byteorder = '<' if sys.byteorder == 'little' else '>' - expected = dedent("""\ + byteorder = "<" if sys.byteorder == "little" else ">" + expected = dedent( + """\ Dimensions: (foø: 1) Coordinates: @@ -249,20 +293,23 @@ def test_unicode_data(self): Data variables: *empty* Attributes: - å: ∑""" % (byteorder, 'ba®')) + å: ∑""" + % (byteorder, "ba®") + ) actual = str(data) assert expected == actual def test_info(self): ds = create_test_data(seed=123) - 
ds = ds.drop('dim3') # string type prints differently in PY2 vs PY3 - ds.attrs['unicode_attr'] = 'ba®' - ds.attrs['string_attr'] = 'bar' + ds = ds.drop("dim3") # string type prints differently in PY2 vs PY3 + ds.attrs["unicode_attr"] = "ba®" + ds.attrs["string_attr"] = "bar" buf = StringIO() ds.info(buf=buf) - expected = dedent('''\ + expected = dedent( + """\ xarray.Dataset { dimensions: \tdim1 = 8 ; @@ -284,185 +331,201 @@ def test_info(self): // global attributes: \t:unicode_attr = ba® ; \t:string_attr = bar ; - }''') + }""" + ) actual = buf.getvalue() assert expected == actual buf.close() def test_constructor(self): - x1 = ('x', 2 * np.arange(100)) - x2 = ('x', np.arange(1000)) - z = (['x', 'y'], np.arange(1000).reshape(100, 10)) + x1 = ("x", 2 * np.arange(100)) + x2 = ("x", np.arange(1000)) + z = (["x", "y"], np.arange(1000).reshape(100, 10)) - with raises_regex(ValueError, 'conflicting sizes'): - Dataset({'a': x1, 'b': x2}) + with raises_regex(ValueError, "conflicting sizes"): + Dataset({"a": x1, "b": x2}) with raises_regex(ValueError, "disallows such variables"): - Dataset({'a': x1, 'x': z}) - with raises_regex(TypeError, 'tuple of form'): - Dataset({'x': (1, 2, 3, 4, 5, 6, 7)}) - with raises_regex(ValueError, 'already exists as a scalar'): - Dataset({'x': 0, 'y': ('x', [1, 2, 3])}) + Dataset({"a": x1, "x": z}) + with raises_regex(TypeError, "tuple of form"): + Dataset({"x": (1, 2, 3, 4, 5, 6, 7)}) + with raises_regex(ValueError, "already exists as a scalar"): + Dataset({"x": 0, "y": ("x", [1, 2, 3])}) # verify handling of DataArrays - expected = Dataset({'x': x1, 'z': z}) - actual = Dataset({'z': expected['z']}) + expected = Dataset({"x": x1, "z": z}) + actual = Dataset({"z": expected["z"]}) assert_identical(expected, actual) def test_constructor_invalid_dims(self): # regression for GH1120 with pytest.raises(MergeError): - Dataset(data_vars=dict(v=('y', [1, 2, 3, 4])), - coords=dict(y=DataArray([.1, .2, .3, .4], dims='x'))) + Dataset( + data_vars=dict(v=("y", [1, 2, 3, 4])), + coords=dict(y=DataArray([0.1, 0.2, 0.3, 0.4], dims="x")), + ) def test_constructor_1d(self): - expected = Dataset({'x': (['x'], 5.0 + np.arange(5))}) - actual = Dataset({'x': 5.0 + np.arange(5)}) + expected = Dataset({"x": (["x"], 5.0 + np.arange(5))}) + actual = Dataset({"x": 5.0 + np.arange(5)}) assert_identical(expected, actual) - actual = Dataset({'x': [5, 6, 7, 8, 9]}) + actual = Dataset({"x": [5, 6, 7, 8, 9]}) assert_identical(expected, actual) def test_constructor_0d(self): - expected = Dataset({'x': ([], 1)}) - for arg in [1, np.array(1), expected['x']]: - actual = Dataset({'x': arg}) + expected = Dataset({"x": ([], 1)}) + for arg in [1, np.array(1), expected["x"]]: + actual = Dataset({"x": arg}) assert_identical(expected, actual) class Arbitrary: pass - d = pd.Timestamp('2000-01-01T12') - args = [True, None, 3.4, np.nan, 'hello', b'raw', - np.datetime64('2000-01-01'), d, d.to_pydatetime(), - Arbitrary()] + d = pd.Timestamp("2000-01-01T12") + args = [ + True, + None, + 3.4, + np.nan, + "hello", + b"raw", + np.datetime64("2000-01-01"), + d, + d.to_pydatetime(), + Arbitrary(), + ] for arg in args: print(arg) - expected = Dataset({'x': ([], arg)}) - actual = Dataset({'x': arg}) + expected = Dataset({"x": ([], arg)}) + actual = Dataset({"x": arg}) assert_identical(expected, actual) def test_constructor_deprecated(self): - with raises_regex(ValueError, 'DataArray dimensions'): - DataArray([1, 2, 3], coords={'x': [0, 1, 2]}) + with raises_regex(ValueError, "DataArray dimensions"): + DataArray([1, 2, 3], 
coords={"x": [0, 1, 2]}) def test_constructor_auto_align(self): - a = DataArray([1, 2], [('x', [0, 1])]) - b = DataArray([3, 4], [('x', [1, 2])]) + a = DataArray([1, 2], [("x", [0, 1])]) + b = DataArray([3, 4], [("x", [1, 2])]) # verify align uses outer join - expected = Dataset({'a': ('x', [1, 2, np.nan]), - 'b': ('x', [np.nan, 3, 4])}, - {'x': [0, 1, 2]}) - actual = Dataset({'a': a, 'b': b}) + expected = Dataset( + {"a": ("x", [1, 2, np.nan]), "b": ("x", [np.nan, 3, 4])}, {"x": [0, 1, 2]} + ) + actual = Dataset({"a": a, "b": b}) assert_identical(expected, actual) # regression test for GH346 - assert isinstance(actual.variables['x'], IndexVariable) + assert isinstance(actual.variables["x"], IndexVariable) # variable with different dimensions - c = ('y', [3, 4]) - expected2 = expected.merge({'c': c}) - actual = Dataset({'a': a, 'b': b, 'c': c}) + c = ("y", [3, 4]) + expected2 = expected.merge({"c": c}) + actual = Dataset({"a": a, "b": b, "c": c}) assert_identical(expected2, actual) # variable that is only aligned against the aligned variables - d = ('x', [3, 2, 1]) - expected3 = expected.merge({'d': d}) - actual = Dataset({'a': a, 'b': b, 'd': d}) + d = ("x", [3, 2, 1]) + expected3 = expected.merge({"d": d}) + actual = Dataset({"a": a, "b": b, "d": d}) assert_identical(expected3, actual) - e = ('x', [0, 0]) - with raises_regex(ValueError, 'conflicting sizes'): - Dataset({'a': a, 'b': b, 'e': e}) + e = ("x", [0, 0]) + with raises_regex(ValueError, "conflicting sizes"): + Dataset({"a": a, "b": b, "e": e}) def test_constructor_pandas_sequence(self): ds = self.make_example_math_dataset() pandas_objs = OrderedDict( - (var_name, ds[var_name].to_pandas()) for var_name in ['foo', 'bar'] + (var_name, ds[var_name].to_pandas()) for var_name in ["foo", "bar"] ) ds_based_on_pandas = Dataset(pandas_objs, ds.coords, attrs=ds.attrs) - del ds_based_on_pandas['x'] + del ds_based_on_pandas["x"] assert_equal(ds, ds_based_on_pandas) # reindex pandas obj, check align works - rearranged_index = reversed(pandas_objs['foo'].index) - pandas_objs['foo'] = pandas_objs['foo'].reindex(rearranged_index) + rearranged_index = reversed(pandas_objs["foo"].index) + pandas_objs["foo"] = pandas_objs["foo"].reindex(rearranged_index) ds_based_on_pandas = Dataset(pandas_objs, ds.coords, attrs=ds.attrs) - del ds_based_on_pandas['x'] + del ds_based_on_pandas["x"] assert_equal(ds, ds_based_on_pandas) def test_constructor_pandas_single(self): das = [ - DataArray(np.random.rand(4), dims=['a']), # series - DataArray(np.random.rand(4, 3), dims=['a', 'b']), # df + DataArray(np.random.rand(4), dims=["a"]), # series + DataArray(np.random.rand(4, 3), dims=["a", "b"]), # df ] - if LooseVersion(pd.__version__) < '0.25.0': - das.append( - DataArray(np.random.rand(4, 3, 2), dims=['a', 'b', 'c'])) + if LooseVersion(pd.__version__) < "0.25.0": + das.append(DataArray(np.random.rand(4, 3, 2), dims=["a", "b", "c"])) with warnings.catch_warnings(): - warnings.filterwarnings('ignore', r'\W*Panel is deprecated') + warnings.filterwarnings("ignore", r"\W*Panel is deprecated") for a in das: pandas_obj = a.to_pandas() ds_based_on_pandas = Dataset(pandas_obj) for dim in ds_based_on_pandas.data_vars: - assert_array_equal( - ds_based_on_pandas[dim], pandas_obj[dim]) + assert_array_equal(ds_based_on_pandas[dim], pandas_obj[dim]) def test_constructor_compat(self): - data = OrderedDict([('x', DataArray(0, coords={'y': 1})), - ('y', ('z', [1, 1, 1]))]) - expected = Dataset({'x': 0}, {'y': ('z', [1, 1, 1])}) + data = OrderedDict( + [("x", DataArray(0, coords={"y": 
1})), ("y", ("z", [1, 1, 1]))] + ) + expected = Dataset({"x": 0}, {"y": ("z", [1, 1, 1])}) actual = Dataset(data) assert_identical(expected, actual) - data = OrderedDict([('y', ('z', [1, 1, 1])), - ('x', DataArray(0, coords={'y': 1}))]) + data = OrderedDict( + [("y", ("z", [1, 1, 1])), ("x", DataArray(0, coords={"y": 1}))] + ) actual = Dataset(data) assert_identical(expected, actual) - original = Dataset({'a': (('x', 'y'), np.ones((2, 3)))}, - {'c': (('x', 'y'), np.zeros((2, 3))), 'x': [0, 1]}) - expected = Dataset({'a': ('x', np.ones(2)), - 'b': ('y', np.ones(3))}, - {'c': (('x', 'y'), np.zeros((2, 3))), 'x': [0, 1]}) + original = Dataset( + {"a": (("x", "y"), np.ones((2, 3)))}, + {"c": (("x", "y"), np.zeros((2, 3))), "x": [0, 1]}, + ) + expected = Dataset( + {"a": ("x", np.ones(2)), "b": ("y", np.ones(3))}, + {"c": (("x", "y"), np.zeros((2, 3))), "x": [0, 1]}, + ) # use an OrderedDict to ensure test results are reproducible; otherwise # the order of appearance of x and y matters for the order of # dimensions in 'c' - actual = Dataset(OrderedDict([('a', original['a'][:, 0]), - ('b', original['a'][0].drop('x'))])) + actual = Dataset( + OrderedDict([("a", original["a"][:, 0]), ("b", original["a"][0].drop("x"))]) + ) assert_identical(expected, actual) - data = {'x': DataArray(0, coords={'y': 3}), 'y': ('z', [1, 1, 1])} + data = {"x": DataArray(0, coords={"y": 3}), "y": ("z", [1, 1, 1])} with pytest.raises(MergeError): Dataset(data) - data = {'x': DataArray(0, coords={'y': 1}), 'y': [1, 1]} + data = {"x": DataArray(0, coords={"y": 1}), "y": [1, 1]} actual = Dataset(data) - expected = Dataset({'x': 0}, {'y': [1, 1]}) + expected = Dataset({"x": 0}, {"y": [1, 1]}) assert_identical(expected, actual) def test_constructor_with_coords(self): - with raises_regex(ValueError, 'found in both data_vars and'): - Dataset({'a': ('x', [1])}, {'a': ('x', [1])}) + with raises_regex(ValueError, "found in both data_vars and"): + Dataset({"a": ("x", [1])}, {"a": ("x", [1])}) - ds = Dataset({}, {'a': ('x', [1])}) + ds = Dataset({}, {"a": ("x", [1])}) assert not ds.data_vars - assert list(ds.coords.keys()) == ['a'] + assert list(ds.coords.keys()) == ["a"] - mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2]], - names=('level_1', 'level_2')) - with raises_regex(ValueError, 'conflicting MultiIndex'): - Dataset({}, {'x': mindex, 'y': mindex}) - Dataset({}, {'x': mindex, 'level_1': range(4)}) + mindex = pd.MultiIndex.from_product( + [["a", "b"], [1, 2]], names=("level_1", "level_2") + ) + with raises_regex(ValueError, "conflicting MultiIndex"): + Dataset({}, {"x": mindex, "y": mindex}) + Dataset({}, {"x": mindex, "level_1": range(4)}) def test_properties(self): ds = create_test_data() - assert ds.dims == \ - {'dim1': 8, 'dim2': 9, 'dim3': 10, 'time': 20} + assert ds.dims == {"dim1": 8, "dim2": 9, "dim3": 10, "time": 20} assert list(ds.dims) == sorted(ds.dims) assert ds.sizes == ds.dims @@ -474,212 +537,219 @@ def test_properties(self): assert list(ds) == list(ds.data_vars) assert list(ds.keys()) == list(ds.data_vars) - assert 'aasldfjalskdfj' not in ds.variables - assert 'dim1' in repr(ds.variables) + assert "aasldfjalskdfj" not in ds.variables + assert "dim1" in repr(ds.variables) assert len(ds) == 3 assert bool(ds) - assert list(ds.data_vars) == ['var1', 'var2', 'var3'] - assert list(ds.data_vars.keys()) == ['var1', 'var2', 'var3'] - assert 'var1' in ds.data_vars - assert 'dim1' not in ds.data_vars - assert 'numbers' not in ds.data_vars + assert list(ds.data_vars) == ["var1", "var2", "var3"] + assert 
list(ds.data_vars.keys()) == ["var1", "var2", "var3"] + assert "var1" in ds.data_vars + assert "dim1" not in ds.data_vars + assert "numbers" not in ds.data_vars assert len(ds.data_vars) == 3 - assert set(ds.indexes) == {'dim2', 'dim3', 'time'} + assert set(ds.indexes) == {"dim2", "dim3", "time"} assert len(ds.indexes) == 3 - assert 'dim2' in repr(ds.indexes) + assert "dim2" in repr(ds.indexes) - assert list(ds.coords) == ['time', 'dim2', 'dim3', 'numbers'] - assert 'dim2' in ds.coords - assert 'numbers' in ds.coords - assert 'var1' not in ds.coords - assert 'dim1' not in ds.coords + assert list(ds.coords) == ["time", "dim2", "dim3", "numbers"] + assert "dim2" in ds.coords + assert "numbers" in ds.coords + assert "var1" not in ds.coords + assert "dim1" not in ds.coords assert len(ds.coords) == 4 - assert Dataset({'x': np.int64(1), - 'y': np.float32([1, 2])}).nbytes == 16 + assert Dataset({"x": np.int64(1), "y": np.float32([1, 2])}).nbytes == 16 def test_asarray(self): - ds = Dataset({'x': 0}) - with raises_regex(TypeError, 'cannot directly convert'): + ds = Dataset({"x": 0}) + with raises_regex(TypeError, "cannot directly convert"): np.asarray(ds) def test_get_index(self): - ds = Dataset({'foo': (('x', 'y'), np.zeros((2, 3)))}, - coords={'x': ['a', 'b']}) - assert ds.get_index('x').equals(pd.Index(['a', 'b'])) - assert ds.get_index('y').equals(pd.Index([0, 1, 2])) + ds = Dataset({"foo": (("x", "y"), np.zeros((2, 3)))}, coords={"x": ["a", "b"]}) + assert ds.get_index("x").equals(pd.Index(["a", "b"])) + assert ds.get_index("y").equals(pd.Index([0, 1, 2])) with pytest.raises(KeyError): - ds.get_index('z') + ds.get_index("z") def test_attr_access(self): - ds = Dataset({'tmin': ('x', [42], {'units': 'Celcius'})}, - attrs={'title': 'My test data'}) - assert_identical(ds.tmin, ds['tmin']) + ds = Dataset( + {"tmin": ("x", [42], {"units": "Celcius"})}, attrs={"title": "My test data"} + ) + assert_identical(ds.tmin, ds["tmin"]) assert_identical(ds.tmin.x, ds.x) - assert ds.title == ds.attrs['title'] - assert ds.tmin.units == ds['tmin'].attrs['units'] + assert ds.title == ds.attrs["title"] + assert ds.tmin.units == ds["tmin"].attrs["units"] - assert {'tmin', 'title'} <= set(dir(ds)) - assert 'units' in set(dir(ds.tmin)) + assert {"tmin", "title"} <= set(dir(ds)) + assert "units" in set(dir(ds.tmin)) # should defer to variable of same name - ds.attrs['tmin'] = -999 - assert ds.attrs['tmin'] == -999 - assert_identical(ds.tmin, ds['tmin']) + ds.attrs["tmin"] = -999 + assert ds.attrs["tmin"] == -999 + assert_identical(ds.tmin, ds["tmin"]) def test_variable(self): a = Dataset() d = np.random.random((10, 3)) - a['foo'] = (('time', 'x',), d) - assert 'foo' in a.variables - assert 'foo' in a - a['bar'] = (('time', 'x',), d) + a["foo"] = (("time", "x"), d) + assert "foo" in a.variables + assert "foo" in a + a["bar"] = (("time", "x"), d) # order of creation is preserved - assert list(a.variables) == ['foo', 'bar'] - assert_array_equal(a['foo'].values, d) + assert list(a.variables) == ["foo", "bar"] + assert_array_equal(a["foo"].values, d) # try to add variable with dim (10,3) with data that's (3,10) with pytest.raises(ValueError): - a['qux'] = (('time', 'x'), d.T) + a["qux"] = (("time", "x"), d.T) def test_modify_inplace(self): a = Dataset() vec = np.random.random((10,)) - attributes = {'foo': 'bar'} - a['x'] = ('x', vec, attributes) - assert 'x' in a.coords - assert isinstance(a.coords['x'].to_index(), pd.Index) - assert_identical(a.coords['x'].variable, a.variables['x']) + attributes = {"foo": "bar"} + a["x"] 
= ("x", vec, attributes) + assert "x" in a.coords + assert isinstance(a.coords["x"].to_index(), pd.Index) + assert_identical(a.coords["x"].variable, a.variables["x"]) b = Dataset() - b['x'] = ('x', vec, attributes) - assert_identical(a['x'], b['x']) + b["x"] = ("x", vec, attributes) + assert_identical(a["x"], b["x"]) assert a.dims == b.dims # this should work - a['x'] = ('x', vec[:5]) - a['z'] = ('x', np.arange(5)) + a["x"] = ("x", vec[:5]) + a["z"] = ("x", np.arange(5)) with pytest.raises(ValueError): # now it shouldn't, since there is a conflicting length - a['x'] = ('x', vec[:4]) - arr = np.random.random((10, 1,)) + a["x"] = ("x", vec[:4]) + arr = np.random.random((10, 1)) scal = np.array(0) with pytest.raises(ValueError): - a['y'] = ('y', arr) + a["y"] = ("y", arr) with pytest.raises(ValueError): - a['y'] = ('y', scal) - assert 'y' not in a.dims + a["y"] = ("y", scal) + assert "y" not in a.dims def test_coords_properties(self): # use an OrderedDict for coordinates to ensure order across python # versions # use int64 for repr consistency on windows - data = Dataset(OrderedDict([('x', ('x', np.array([-1, -2], 'int64'))), - ('y', ('y', np.array([0, 1, 2], 'int64'))), - ('foo', (['x', 'y'], - np.random.randn(2, 3)))]), - OrderedDict([('a', ('x', np.array([4, 5], 'int64'))), - ('b', np.int64(-10))])) + data = Dataset( + OrderedDict( + [ + ("x", ("x", np.array([-1, -2], "int64"))), + ("y", ("y", np.array([0, 1, 2], "int64"))), + ("foo", (["x", "y"], np.random.randn(2, 3))), + ] + ), + OrderedDict( + [("a", ("x", np.array([4, 5], "int64"))), ("b", np.int64(-10))] + ), + ) assert 4 == len(data.coords) - assert ['x', 'y', 'a', 'b'] == list(data.coords) + assert ["x", "y", "a", "b"] == list(data.coords) - assert_identical(data.coords['x'].variable, data['x'].variable) - assert_identical(data.coords['y'].variable, data['y'].variable) + assert_identical(data.coords["x"].variable, data["x"].variable) + assert_identical(data.coords["y"].variable, data["y"].variable) - assert 'x' in data.coords - assert 'a' in data.coords + assert "x" in data.coords + assert "a" in data.coords assert 0 not in data.coords - assert 'foo' not in data.coords + assert "foo" not in data.coords with pytest.raises(KeyError): - data.coords['foo'] + data.coords["foo"] with pytest.raises(KeyError): data.coords[0] - expected = dedent("""\ + expected = dedent( + """\ Coordinates: * x (x) int64 -1 -2 * y (y) int64 0 1 2 a (x) int64 4 5 - b int64 -10""") + b int64 -10""" + ) actual = repr(data.coords) assert expected == actual - assert {'x': 2, 'y': 3} == data.coords.dims + assert {"x": 2, "y": 3} == data.coords.dims def test_coords_modify(self): - data = Dataset({'x': ('x', [-1, -2]), - 'y': ('y', [0, 1, 2]), - 'foo': (['x', 'y'], np.random.randn(2, 3))}, - {'a': ('x', [4, 5]), 'b': -10}) + data = Dataset( + { + "x": ("x", [-1, -2]), + "y": ("y", [0, 1, 2]), + "foo": (["x", "y"], np.random.randn(2, 3)), + }, + {"a": ("x", [4, 5]), "b": -10}, + ) actual = data.copy(deep=True) - actual.coords['x'] = ('x', ['a', 'b']) - assert_array_equal(actual['x'], ['a', 'b']) + actual.coords["x"] = ("x", ["a", "b"]) + assert_array_equal(actual["x"], ["a", "b"]) actual = data.copy(deep=True) - actual.coords['z'] = ('z', ['a', 'b']) - assert_array_equal(actual['z'], ['a', 'b']) + actual.coords["z"] = ("z", ["a", "b"]) + assert_array_equal(actual["z"], ["a", "b"]) actual = data.copy(deep=True) - with raises_regex(ValueError, 'conflicting sizes'): - actual.coords['x'] = ('x', [-1]) + with raises_regex(ValueError, "conflicting sizes"): + 
actual.coords["x"] = ("x", [-1]) assert_identical(actual, data) # should not be modified actual = data.copy() - del actual.coords['b'] - expected = data.reset_coords('b', drop=True) + del actual.coords["b"] + expected = data.reset_coords("b", drop=True) assert_identical(expected, actual) with pytest.raises(KeyError): - del data.coords['not_found'] + del data.coords["not_found"] with pytest.raises(KeyError): - del data.coords['foo'] + del data.coords["foo"] actual = data.copy(deep=True) - actual.coords.update({'c': 11}) - expected = data.merge({'c': 11}).set_coords('c') + actual.coords.update({"c": 11}) + expected = data.merge({"c": 11}).set_coords("c") assert_identical(expected, actual) def test_update_index(self): - actual = Dataset(coords={'x': [1, 2, 3]}) - actual['x'] = ['a', 'b', 'c'] - assert actual.indexes['x'].equals(pd.Index(['a', 'b', 'c'])) + actual = Dataset(coords={"x": [1, 2, 3]}) + actual["x"] = ["a", "b", "c"] + assert actual.indexes["x"].equals(pd.Index(["a", "b", "c"])) def test_coords_setitem_with_new_dimension(self): actual = Dataset() - actual.coords['foo'] = ('x', [1, 2, 3]) - expected = Dataset(coords={'foo': ('x', [1, 2, 3])}) + actual.coords["foo"] = ("x", [1, 2, 3]) + expected = Dataset(coords={"foo": ("x", [1, 2, 3])}) assert_identical(expected, actual) def test_coords_setitem_multiindex(self): data = create_test_multiindex() - with raises_regex(ValueError, 'conflicting MultiIndex'): - data.coords['level_1'] = range(4) + with raises_regex(ValueError, "conflicting MultiIndex"): + data.coords["level_1"] = range(4) def test_coords_set(self): - one_coord = Dataset({'x': ('x', [0]), - 'yy': ('x', [1]), - 'zzz': ('x', [2])}) - two_coords = Dataset({'zzz': ('x', [2])}, - {'x': ('x', [0]), - 'yy': ('x', [1])}) - all_coords = Dataset(coords={'x': ('x', [0]), - 'yy': ('x', [1]), - 'zzz': ('x', [2])}) - - actual = one_coord.set_coords('x') + one_coord = Dataset({"x": ("x", [0]), "yy": ("x", [1]), "zzz": ("x", [2])}) + two_coords = Dataset({"zzz": ("x", [2])}, {"x": ("x", [0]), "yy": ("x", [1])}) + all_coords = Dataset( + coords={"x": ("x", [0]), "yy": ("x", [1]), "zzz": ("x", [2])} + ) + + actual = one_coord.set_coords("x") assert_identical(one_coord, actual) - actual = one_coord.set_coords(['x']) + actual = one_coord.set_coords(["x"]) assert_identical(one_coord, actual) - actual = one_coord.set_coords('yy') + actual = one_coord.set_coords("yy") assert_identical(two_coords, actual) - actual = one_coord.set_coords(['yy', 'zzz']) + actual = one_coord.set_coords(["yy", "zzz"]) assert_identical(all_coords, actual) actual = one_coord.reset_coords() @@ -689,96 +759,97 @@ def test_coords_set(self): actual = all_coords.reset_coords() assert_identical(one_coord, actual) - actual = all_coords.reset_coords(['yy', 'zzz']) + actual = all_coords.reset_coords(["yy", "zzz"]) assert_identical(one_coord, actual) - actual = all_coords.reset_coords('zzz') + actual = all_coords.reset_coords("zzz") assert_identical(two_coords, actual) - with raises_regex(ValueError, 'cannot remove index'): - one_coord.reset_coords('x') + with raises_regex(ValueError, "cannot remove index"): + one_coord.reset_coords("x") - actual = all_coords.reset_coords('zzz', drop=True) - expected = all_coords.drop('zzz') + actual = all_coords.reset_coords("zzz", drop=True) + expected = all_coords.drop("zzz") assert_identical(expected, actual) - expected = two_coords.drop('zzz') + expected = two_coords.drop("zzz") assert_identical(expected, actual) def test_coords_to_dataset(self): - orig = Dataset({'foo': ('y', [-1, 0, 1])}, 
{'x': 10, 'y': [2, 3, 4]}) - expected = Dataset(coords={'x': 10, 'y': [2, 3, 4]}) + orig = Dataset({"foo": ("y", [-1, 0, 1])}, {"x": 10, "y": [2, 3, 4]}) + expected = Dataset(coords={"x": 10, "y": [2, 3, 4]}) actual = orig.coords.to_dataset() assert_identical(expected, actual) def test_coords_merge(self): - orig_coords = Dataset(coords={'a': ('x', [1, 2]), 'x': [0, 1]}).coords - other_coords = Dataset(coords={'b': ('x', ['a', 'b']), - 'x': [0, 1]}).coords - expected = Dataset(coords={'a': ('x', [1, 2]), - 'b': ('x', ['a', 'b']), - 'x': [0, 1]}) + orig_coords = Dataset(coords={"a": ("x", [1, 2]), "x": [0, 1]}).coords + other_coords = Dataset(coords={"b": ("x", ["a", "b"]), "x": [0, 1]}).coords + expected = Dataset( + coords={"a": ("x", [1, 2]), "b": ("x", ["a", "b"]), "x": [0, 1]} + ) actual = orig_coords.merge(other_coords) assert_identical(expected, actual) actual = other_coords.merge(orig_coords) assert_identical(expected, actual) - other_coords = Dataset(coords={'x': ('x', ['a'])}).coords + other_coords = Dataset(coords={"x": ("x", ["a"])}).coords with pytest.raises(MergeError): orig_coords.merge(other_coords) - other_coords = Dataset(coords={'x': ('x', ['a', 'b'])}).coords + other_coords = Dataset(coords={"x": ("x", ["a", "b"])}).coords with pytest.raises(MergeError): orig_coords.merge(other_coords) - other_coords = Dataset(coords={'x': ('x', ['a', 'b', 'c'])}).coords + other_coords = Dataset(coords={"x": ("x", ["a", "b", "c"])}).coords with pytest.raises(MergeError): orig_coords.merge(other_coords) - other_coords = Dataset(coords={'a': ('x', [8, 9])}).coords - expected = Dataset(coords={'x': range(2)}) + other_coords = Dataset(coords={"a": ("x", [8, 9])}).coords + expected = Dataset(coords={"x": range(2)}) actual = orig_coords.merge(other_coords) assert_identical(expected, actual) actual = other_coords.merge(orig_coords) assert_identical(expected, actual) - other_coords = Dataset(coords={'x': np.nan}).coords + other_coords = Dataset(coords={"x": np.nan}).coords actual = orig_coords.merge(other_coords) assert_identical(orig_coords.to_dataset(), actual) actual = other_coords.merge(orig_coords) assert_identical(orig_coords.to_dataset(), actual) def test_coords_merge_mismatched_shape(self): - orig_coords = Dataset(coords={'a': ('x', [1, 1])}).coords - other_coords = Dataset(coords={'a': 1}).coords + orig_coords = Dataset(coords={"a": ("x", [1, 1])}).coords + other_coords = Dataset(coords={"a": 1}).coords expected = orig_coords.to_dataset() actual = orig_coords.merge(other_coords) assert_identical(expected, actual) - other_coords = Dataset(coords={'a': ('y', [1])}).coords - expected = Dataset(coords={'a': (['x', 'y'], [[1], [1]])}) + other_coords = Dataset(coords={"a": ("y", [1])}).coords + expected = Dataset(coords={"a": (["x", "y"], [[1], [1]])}) actual = orig_coords.merge(other_coords) assert_identical(expected, actual) actual = other_coords.merge(orig_coords) assert_identical(expected.transpose(), actual) - orig_coords = Dataset(coords={'a': ('x', [np.nan])}).coords - other_coords = Dataset(coords={'a': np.nan}).coords + orig_coords = Dataset(coords={"a": ("x", [np.nan])}).coords + other_coords = Dataset(coords={"a": np.nan}).coords expected = orig_coords.to_dataset() actual = orig_coords.merge(other_coords) assert_identical(expected, actual) def test_data_vars_properties(self): ds = Dataset() - ds['foo'] = (('x',), [1.0]) - ds['bar'] = 2.0 + ds["foo"] = (("x",), [1.0]) + ds["bar"] = 2.0 - assert set(ds.data_vars) == {'foo', 'bar'} - assert 'foo' in ds.data_vars - assert 'x' not in 
ds.data_vars - assert_identical(ds['foo'], ds.data_vars['foo']) + assert set(ds.data_vars) == {"foo", "bar"} + assert "foo" in ds.data_vars + assert "x" not in ds.data_vars + assert_identical(ds["foo"], ds.data_vars["foo"]) - expected = dedent("""\ + expected = dedent( + """\ Data variables: foo (x) float64 1.0 - bar float64 2.0""") + bar float64 2.0""" + ) actual = repr(ds.data_vars) assert expected == actual @@ -788,14 +859,14 @@ def test_equals_and_identical(self): assert data.identical(data) data2 = create_test_data(seed=42) - data2.attrs['foobar'] = 'baz' + data2.attrs["foobar"] = "baz" assert data.equals(data2) assert not data.identical(data2) - del data2['time'] + del data2["time"] assert not data.equals(data2) - data = create_test_data(seed=42).rename({'var1': None}) + data = create_test_data(seed=42).rename({"var1": None}) assert data.equals(data) assert data.identical(data) @@ -805,21 +876,21 @@ def test_equals_and_identical(self): def test_equals_failures(self): data = create_test_data() - assert not data.equals('foo') + assert not data.equals("foo") assert not data.identical(123) assert not data.broadcast_equals({1: 2}) def test_broadcast_equals(self): - data1 = Dataset(coords={'x': 0}) - data2 = Dataset(coords={'x': [0]}) + data1 = Dataset(coords={"x": 0}) + data2 = Dataset(coords={"x": [0]}) assert data1.broadcast_equals(data2) assert not data1.equals(data2) assert not data1.identical(data2) def test_attrs(self): data = create_test_data(seed=42) - data.attrs = {'foobar': 'baz'} - assert data.attrs['foobar'], 'baz' + data.attrs = {"foobar": "baz"} + assert data.attrs["foobar"], "baz" assert isinstance(data.attrs, OrderedDict) @requires_dask @@ -836,13 +907,13 @@ def test_chunk(self): else: assert isinstance(v.data, da.Array) - expected_chunks = {'dim1': (8,), 'dim2': (9,), 'dim3': (10,)} + expected_chunks = {"dim1": (8,), "dim2": (9,), "dim3": (10,)} assert reblocked.chunks == expected_chunks - reblocked = data.chunk({'time': 5, 'dim1': 5, 'dim2': 5, 'dim3': 5}) + reblocked = data.chunk({"time": 5, "dim1": 5, "dim2": 5, "dim3": 5}) # time is not a dim in any of the data_vars, so it # doesn't get chunked - expected_chunks = {'dim1': (5, 3), 'dim2': (5, 4), 'dim3': (5, 5)} + expected_chunks = {"dim1": (5, 3), "dim2": (5, 4), "dim3": (5, 5)} assert reblocked.chunks == expected_chunks reblocked = data.chunk(expected_chunks) @@ -853,8 +924,8 @@ def test_chunk(self): assert reblocked.chunks == expected_chunks assert_identical(reblocked, data) - with raises_regex(ValueError, 'some chunks'): - data.chunk({'foo': 10}) + with raises_regex(ValueError, "some chunks"): + data.chunk({"foo": 10}) @requires_dask def test_dask_is_lazy(self): @@ -865,7 +936,7 @@ def test_dask_is_lazy(self): with pytest.raises(UnexpectedDataAccess): ds.load() with pytest.raises(UnexpectedDataAccess): - ds['var1'].values + ds["var1"].values # these should not raise UnexpectedDataAccess: ds.var1.data @@ -874,21 +945,20 @@ def test_dask_is_lazy(self): ds.transpose() ds.mean() ds.fillna(0) - ds.rename({'dim1': 'foobar'}) - ds.set_coords('var1') - ds.drop('var1') + ds.rename({"dim1": "foobar"}) + ds.set_coords("var1") + ds.drop("var1") def test_isel(self): data = create_test_data() - slicers = {'dim1': slice(None, None, 2), 'dim2': slice(0, 2)} + slicers = {"dim1": slice(None, None, 2), "dim2": slice(0, 2)} ret = data.isel(**slicers) # Verify that only the specified dimension was altered assert list(data.dims) == list(ret.dims) for d in data.dims: if d in slicers: - assert ret.dims[d] == \ - 
np.arange(data.dims[d])[slicers[d]].size + assert ret.dims[d] == np.arange(data.dims[d])[slicers[d]].size else: assert data.dims[d] == ret.dims[d] # Verify that the data is what we expect @@ -909,22 +979,22 @@ def test_isel(self): data.isel(not_a_dim=slice(0, 2)) ret = data.isel(dim1=0) - assert {'time': 20, 'dim2': 9, 'dim3': 10} == ret.dims + assert {"time": 20, "dim2": 9, "dim3": 10} == ret.dims assert set(data.data_vars) == set(ret.data_vars) assert set(data.coords) == set(ret.coords) assert set(data.indexes) == set(ret.indexes) ret = data.isel(time=slice(2), dim1=0, dim2=slice(5)) - assert {'time': 2, 'dim2': 5, 'dim3': 10} == ret.dims + assert {"time": 2, "dim2": 5, "dim3": 10} == ret.dims assert set(data.data_vars) == set(ret.data_vars) assert set(data.coords) == set(ret.coords) assert set(data.indexes) == set(ret.indexes) ret = data.isel(time=0, dim1=0, dim2=slice(5)) - assert {'dim2': 5, 'dim3': 10} == ret.dims + assert {"dim2": 5, "dim3": 10} == ret.dims assert set(data.data_vars) == set(ret.data_vars) assert set(data.coords) == set(ret.coords) - assert set(data.indexes) == set(list(ret.indexes) + ['time']) + assert set(data.indexes) == set(list(ret.indexes) + ["time"]) def test_isel_fancy(self): # isel with fancy indexing. @@ -933,364 +1003,387 @@ def test_isel_fancy(self): pdim1 = [1, 2, 3] pdim2 = [4, 5, 1] pdim3 = [1, 2, 3] - actual = data.isel(dim1=(('test_coord', ), pdim1), - dim2=(('test_coord', ), pdim2), - dim3=(('test_coord', ), pdim3)) - assert 'test_coord' in actual.dims - assert actual.coords['test_coord'].shape == (len(pdim1), ) + actual = data.isel( + dim1=(("test_coord",), pdim1), + dim2=(("test_coord",), pdim2), + dim3=(("test_coord",), pdim3), + ) + assert "test_coord" in actual.dims + assert actual.coords["test_coord"].shape == (len(pdim1),) # Should work with DataArray - actual = data.isel(dim1=DataArray(pdim1, dims='test_coord'), - dim2=(('test_coord', ), pdim2), - dim3=(('test_coord', ), pdim3)) - assert 'test_coord' in actual.dims - assert actual.coords['test_coord'].shape == (len(pdim1), ) - expected = data.isel(dim1=(('test_coord', ), pdim1), - dim2=(('test_coord', ), pdim2), - dim3=(('test_coord', ), pdim3)) + actual = data.isel( + dim1=DataArray(pdim1, dims="test_coord"), + dim2=(("test_coord",), pdim2), + dim3=(("test_coord",), pdim3), + ) + assert "test_coord" in actual.dims + assert actual.coords["test_coord"].shape == (len(pdim1),) + expected = data.isel( + dim1=(("test_coord",), pdim1), + dim2=(("test_coord",), pdim2), + dim3=(("test_coord",), pdim3), + ) assert_identical(actual, expected) # DataArray with coordinate - idx1 = DataArray(pdim1, dims=['a'], coords={'a': np.random.randn(3)}) - idx2 = DataArray(pdim2, dims=['b'], coords={'b': np.random.randn(3)}) - idx3 = DataArray(pdim3, dims=['c'], coords={'c': np.random.randn(3)}) + idx1 = DataArray(pdim1, dims=["a"], coords={"a": np.random.randn(3)}) + idx2 = DataArray(pdim2, dims=["b"], coords={"b": np.random.randn(3)}) + idx3 = DataArray(pdim3, dims=["c"], coords={"c": np.random.randn(3)}) # Should work with DataArray actual = data.isel(dim1=idx1, dim2=idx2, dim3=idx3) - assert 'a' in actual.dims - assert 'b' in actual.dims - assert 'c' in actual.dims - assert 'time' in actual.coords - assert 'dim2' in actual.coords - assert 'dim3' in actual.coords - expected = data.isel(dim1=(('a', ), pdim1), - dim2=(('b', ), pdim2), - dim3=(('c', ), pdim3)) - expected = expected.assign_coords(a=idx1['a'], b=idx2['b'], - c=idx3['c']) + assert "a" in actual.dims + assert "b" in actual.dims + assert "c" in 
actual.dims + assert "time" in actual.coords + assert "dim2" in actual.coords + assert "dim3" in actual.coords + expected = data.isel( + dim1=(("a",), pdim1), dim2=(("b",), pdim2), dim3=(("c",), pdim3) + ) + expected = expected.assign_coords(a=idx1["a"], b=idx2["b"], c=idx3["c"]) assert_identical(actual, expected) - idx1 = DataArray(pdim1, dims=['a'], coords={'a': np.random.randn(3)}) - idx2 = DataArray(pdim2, dims=['a']) - idx3 = DataArray(pdim3, dims=['a']) + idx1 = DataArray(pdim1, dims=["a"], coords={"a": np.random.randn(3)}) + idx2 = DataArray(pdim2, dims=["a"]) + idx3 = DataArray(pdim3, dims=["a"]) # Should work with DataArray actual = data.isel(dim1=idx1, dim2=idx2, dim3=idx3) - assert 'a' in actual.dims - assert 'time' in actual.coords - assert 'dim2' in actual.coords - assert 'dim3' in actual.coords - expected = data.isel(dim1=(('a', ), pdim1), - dim2=(('a', ), pdim2), - dim3=(('a', ), pdim3)) - expected = expected.assign_coords(a=idx1['a']) + assert "a" in actual.dims + assert "time" in actual.coords + assert "dim2" in actual.coords + assert "dim3" in actual.coords + expected = data.isel( + dim1=(("a",), pdim1), dim2=(("a",), pdim2), dim3=(("a",), pdim3) + ) + expected = expected.assign_coords(a=idx1["a"]) assert_identical(actual, expected) - actual = data.isel(dim1=(('points', ), pdim1), - dim2=(('points', ), pdim2)) - assert 'points' in actual.dims - assert 'dim3' in actual.dims - assert 'dim3' not in actual.data_vars - np.testing.assert_array_equal(data['dim2'][pdim2], actual['dim2']) + actual = data.isel(dim1=(("points",), pdim1), dim2=(("points",), pdim2)) + assert "points" in actual.dims + assert "dim3" in actual.dims + assert "dim3" not in actual.data_vars + np.testing.assert_array_equal(data["dim2"][pdim2], actual["dim2"]) # test that the order of the indexers doesn't matter - assert_identical(data.isel(dim1=(('points', ), pdim1), - dim2=(('points', ), pdim2)), - data.isel(dim2=(('points', ), pdim2), - dim1=(('points', ), pdim1))) + assert_identical( + data.isel(dim1=(("points",), pdim1), dim2=(("points",), pdim2)), + data.isel(dim2=(("points",), pdim2), dim1=(("points",), pdim1)), + ) # make sure we're raising errors in the right places - with raises_regex(IndexError, - 'Dimensions of indexers mismatch'): - data.isel(dim1=(('points', ), [1, 2]), - dim2=(('points', ), [1, 2, 3])) - with raises_regex(TypeError, 'cannot use a Dataset'): - data.isel(dim1=Dataset({'points': [1, 2]})) + with raises_regex(IndexError, "Dimensions of indexers mismatch"): + data.isel(dim1=(("points",), [1, 2]), dim2=(("points",), [1, 2, 3])) + with raises_regex(TypeError, "cannot use a Dataset"): + data.isel(dim1=Dataset({"points": [1, 2]})) # test to be sure we keep around variables that were not indexed - ds = Dataset({'x': [1, 2, 3, 4], 'y': 0}) - actual = ds.isel(x=(('points', ), [0, 1, 2])) - assert_identical(ds['y'], actual['y']) + ds = Dataset({"x": [1, 2, 3, 4], "y": 0}) + actual = ds.isel(x=(("points",), [0, 1, 2])) + assert_identical(ds["y"], actual["y"]) # tests using index or DataArray as indexers stations = Dataset() - stations['station'] = (('station', ), ['A', 'B', 'C']) - stations['dim1s'] = (('station', ), [1, 2, 3]) - stations['dim2s'] = (('station', ), [4, 5, 1]) - - actual = data.isel(dim1=stations['dim1s'], - dim2=stations['dim2s']) - assert 'station' in actual.coords - assert 'station' in actual.dims - assert_identical(actual['station'].drop(['dim2']), - stations['station']) - - with raises_regex(ValueError, 'conflicting values for '): - data.isel(dim1=DataArray([0, 1, 2], 
dims='station', - coords={'station': [0, 1, 2]}), - dim2=DataArray([0, 1, 2], dims='station', - coords={'station': [0, 1, 3]})) + stations["station"] = (("station",), ["A", "B", "C"]) + stations["dim1s"] = (("station",), [1, 2, 3]) + stations["dim2s"] = (("station",), [4, 5, 1]) + + actual = data.isel(dim1=stations["dim1s"], dim2=stations["dim2s"]) + assert "station" in actual.coords + assert "station" in actual.dims + assert_identical(actual["station"].drop(["dim2"]), stations["station"]) + + with raises_regex(ValueError, "conflicting values for "): + data.isel( + dim1=DataArray( + [0, 1, 2], dims="station", coords={"station": [0, 1, 2]} + ), + dim2=DataArray( + [0, 1, 2], dims="station", coords={"station": [0, 1, 3]} + ), + ) # multi-dimensional selection stations = Dataset() - stations['a'] = (('a', ), ['A', 'B', 'C']) - stations['b'] = (('b', ), [0, 1]) - stations['dim1s'] = (('a', 'b'), [[1, 2], [2, 3], [3, 4]]) - stations['dim2s'] = (('a', ), [4, 5, 1]) - actual = data.isel(dim1=stations['dim1s'], dim2=stations['dim2s']) - assert 'a' in actual.coords - assert 'a' in actual.dims - assert 'b' in actual.coords - assert 'b' in actual.dims - assert 'dim2' in actual.coords - assert 'a' in actual['dim2'].dims - - assert_identical(actual['a'].drop(['dim2']), - stations['a']) - assert_identical(actual['b'], stations['b']) - expected_var1 = data['var1'].variable[stations['dim1s'].variable, - stations['dim2s'].variable] - expected_var2 = data['var2'].variable[stations['dim1s'].variable, - stations['dim2s'].variable] - expected_var3 = data['var3'].variable[slice(None), - stations['dim1s'].variable] - assert_equal(actual['a'].drop('dim2'), stations['a']) - assert_array_equal(actual['var1'], expected_var1) - assert_array_equal(actual['var2'], expected_var2) - assert_array_equal(actual['var3'], expected_var3) + stations["a"] = (("a",), ["A", "B", "C"]) + stations["b"] = (("b",), [0, 1]) + stations["dim1s"] = (("a", "b"), [[1, 2], [2, 3], [3, 4]]) + stations["dim2s"] = (("a",), [4, 5, 1]) + actual = data.isel(dim1=stations["dim1s"], dim2=stations["dim2s"]) + assert "a" in actual.coords + assert "a" in actual.dims + assert "b" in actual.coords + assert "b" in actual.dims + assert "dim2" in actual.coords + assert "a" in actual["dim2"].dims + + assert_identical(actual["a"].drop(["dim2"]), stations["a"]) + assert_identical(actual["b"], stations["b"]) + expected_var1 = data["var1"].variable[ + stations["dim1s"].variable, stations["dim2s"].variable + ] + expected_var2 = data["var2"].variable[ + stations["dim1s"].variable, stations["dim2s"].variable + ] + expected_var3 = data["var3"].variable[slice(None), stations["dim1s"].variable] + assert_equal(actual["a"].drop("dim2"), stations["a"]) + assert_array_equal(actual["var1"], expected_var1) + assert_array_equal(actual["var2"], expected_var2) + assert_array_equal(actual["var3"], expected_var3) def test_isel_dataarray(self): """ Test for indexing by DataArray """ data = create_test_data() # indexing with DataArray with same-name coordinates. 
- indexing_da = DataArray(np.arange(1, 4), dims=['dim1'], - coords={'dim1': np.random.randn(3)}) + indexing_da = DataArray( + np.arange(1, 4), dims=["dim1"], coords={"dim1": np.random.randn(3)} + ) actual = data.isel(dim1=indexing_da) - assert_identical(indexing_da['dim1'], actual['dim1']) - assert_identical(data['dim2'], actual['dim2']) + assert_identical(indexing_da["dim1"], actual["dim1"]) + assert_identical(data["dim2"], actual["dim2"]) # Conflict in the dimension coordinate - indexing_da = DataArray(np.arange(1, 4), dims=['dim2'], - coords={'dim2': np.random.randn(3)}) + indexing_da = DataArray( + np.arange(1, 4), dims=["dim2"], coords={"dim2": np.random.randn(3)} + ) with raises_regex(IndexError, "dimension coordinate 'dim2'"): actual = data.isel(dim2=indexing_da) # Also the case for DataArray with raises_regex(IndexError, "dimension coordinate 'dim2'"): - actual = data['var2'].isel(dim2=indexing_da) + actual = data["var2"].isel(dim2=indexing_da) with raises_regex(IndexError, "dimension coordinate 'dim2'"): - data['dim2'].isel(dim2=indexing_da) + data["dim2"].isel(dim2=indexing_da) # same name coordinate which does not conflict - indexing_da = DataArray(np.arange(1, 4), dims=['dim2'], - coords={'dim2': data['dim2'].values[1:4]}) + indexing_da = DataArray( + np.arange(1, 4), dims=["dim2"], coords={"dim2": data["dim2"].values[1:4]} + ) actual = data.isel(dim2=indexing_da) - assert_identical(actual['dim2'], indexing_da['dim2']) + assert_identical(actual["dim2"], indexing_da["dim2"]) # Silently drop conflicted (non-dimensional) coordinate of indexer - indexing_da = DataArray(np.arange(1, 4), dims=['dim2'], - coords={'dim2': data['dim2'].values[1:4], - 'numbers': ('dim2', np.arange(2, 5))}) + indexing_da = DataArray( + np.arange(1, 4), + dims=["dim2"], + coords={ + "dim2": data["dim2"].values[1:4], + "numbers": ("dim2", np.arange(2, 5)), + }, + ) actual = data.isel(dim2=indexing_da) - assert_identical(actual['numbers'], data['numbers']) + assert_identical(actual["numbers"], data["numbers"]) # boolean data array with coordinate with the same name - indexing_da = DataArray(np.arange(1, 10), dims=['dim2'], - coords={'dim2': data['dim2'].values}) - indexing_da = (indexing_da < 3) + indexing_da = DataArray( + np.arange(1, 10), dims=["dim2"], coords={"dim2": data["dim2"].values} + ) + indexing_da = indexing_da < 3 actual = data.isel(dim2=indexing_da) - assert_identical(actual['dim2'], data['dim2'][:2]) + assert_identical(actual["dim2"], data["dim2"][:2]) # boolean data array with non-dimensioncoordinate - indexing_da = DataArray(np.arange(1, 10), dims=['dim2'], - coords={'dim2': data['dim2'].values, - 'non_dim': (('dim2', ), - np.random.randn(9)), - 'non_dim2': 0}) - indexing_da = (indexing_da < 3) + indexing_da = DataArray( + np.arange(1, 10), + dims=["dim2"], + coords={ + "dim2": data["dim2"].values, + "non_dim": (("dim2",), np.random.randn(9)), + "non_dim2": 0, + }, + ) + indexing_da = indexing_da < 3 actual = data.isel(dim2=indexing_da) assert_identical( - actual['dim2'].drop('non_dim').drop('non_dim2'), data['dim2'][:2]) - assert_identical( - actual['non_dim'], indexing_da['non_dim'][:2]) - assert_identical( - actual['non_dim2'], indexing_da['non_dim2']) + actual["dim2"].drop("non_dim").drop("non_dim2"), data["dim2"][:2] + ) + assert_identical(actual["non_dim"], indexing_da["non_dim"][:2]) + assert_identical(actual["non_dim2"], indexing_da["non_dim2"]) # non-dimension coordinate will be also attached - indexing_da = DataArray(np.arange(1, 4), dims=['dim2'], - coords={'non_dim': (('dim2', 
), - np.random.randn(3))}) + indexing_da = DataArray( + np.arange(1, 4), + dims=["dim2"], + coords={"non_dim": (("dim2",), np.random.randn(3))}, + ) actual = data.isel(dim2=indexing_da) - assert 'non_dim' in actual - assert 'non_dim' in actual.coords + assert "non_dim" in actual + assert "non_dim" in actual.coords # Index by a scalar DataArray - indexing_da = DataArray(3, dims=[], coords={'station': 2}) + indexing_da = DataArray(3, dims=[], coords={"station": 2}) actual = data.isel(dim2=indexing_da) - assert 'station' in actual - actual = data.isel(dim2=indexing_da['station']) - assert 'station' in actual + assert "station" in actual + actual = data.isel(dim2=indexing_da["station"]) + assert "station" in actual # indexer generated from coordinates - indexing_ds = Dataset({}, coords={'dim2': [0, 1, 2]}) - with raises_regex( - IndexError, "dimension coordinate 'dim2'"): - actual = data.isel(dim2=indexing_ds['dim2']) + indexing_ds = Dataset({}, coords={"dim2": [0, 1, 2]}) + with raises_regex(IndexError, "dimension coordinate 'dim2'"): + actual = data.isel(dim2=indexing_ds["dim2"]) def test_sel(self): data = create_test_data() - int_slicers = {'dim1': slice(None, None, 2), - 'dim2': slice(2), - 'dim3': slice(3)} - loc_slicers = {'dim1': slice(None, None, 2), - 'dim2': slice(0, 0.5), - 'dim3': slice('a', 'c')} - assert_equal(data.isel(**int_slicers), - data.sel(**loc_slicers)) - data['time'] = ('time', pd.date_range('2000-01-01', periods=20)) - assert_equal(data.isel(time=0), - data.sel(time='2000-01-01')) - assert_equal(data.isel(time=slice(10)), - data.sel(time=slice('2000-01-01', - '2000-01-10'))) - assert_equal(data, data.sel(time=slice('1999', '2005'))) - times = pd.date_range('2000-01-01', periods=3) - assert_equal(data.isel(time=slice(3)), - data.sel(time=times)) - assert_equal(data.isel(time=slice(3)), - data.sel(time=(data['time.dayofyear'] <= 3))) - - td = pd.to_timedelta(np.arange(3), unit='days') - data = Dataset({'x': ('td', np.arange(3)), 'td': td}) + int_slicers = {"dim1": slice(None, None, 2), "dim2": slice(2), "dim3": slice(3)} + loc_slicers = { + "dim1": slice(None, None, 2), + "dim2": slice(0, 0.5), + "dim3": slice("a", "c"), + } + assert_equal(data.isel(**int_slicers), data.sel(**loc_slicers)) + data["time"] = ("time", pd.date_range("2000-01-01", periods=20)) + assert_equal(data.isel(time=0), data.sel(time="2000-01-01")) + assert_equal( + data.isel(time=slice(10)), data.sel(time=slice("2000-01-01", "2000-01-10")) + ) + assert_equal(data, data.sel(time=slice("1999", "2005"))) + times = pd.date_range("2000-01-01", periods=3) + assert_equal(data.isel(time=slice(3)), data.sel(time=times)) + assert_equal( + data.isel(time=slice(3)), data.sel(time=(data["time.dayofyear"] <= 3)) + ) + + td = pd.to_timedelta(np.arange(3), unit="days") + data = Dataset({"x": ("td", np.arange(3)), "td": td}) assert_equal(data, data.sel(td=td)) - assert_equal(data, data.sel(td=slice('3 days'))) - assert_equal(data.isel(td=0), - data.sel(td=pd.Timedelta('0 days'))) - assert_equal(data.isel(td=0), - data.sel(td=pd.Timedelta('0h'))) - assert_equal(data.isel(td=slice(1, 3)), - data.sel(td=slice('1 days', '2 days'))) + assert_equal(data, data.sel(td=slice("3 days"))) + assert_equal(data.isel(td=0), data.sel(td=pd.Timedelta("0 days"))) + assert_equal(data.isel(td=0), data.sel(td=pd.Timedelta("0h"))) + assert_equal(data.isel(td=slice(1, 3)), data.sel(td=slice("1 days", "2 days"))) def test_sel_dataarray(self): data = create_test_data() - ind = DataArray([0.0, 0.5, 1.0], dims=['dim2']) + ind = DataArray([0.0, 
0.5, 1.0], dims=["dim2"]) actual = data.sel(dim2=ind) assert_equal(actual, data.isel(dim2=[0, 1, 2])) # with different dimension - ind = DataArray([0.0, 0.5, 1.0], dims=['new_dim']) + ind = DataArray([0.0, 0.5, 1.0], dims=["new_dim"]) actual = data.sel(dim2=ind) - expected = data.isel(dim2=Variable('new_dim', [0, 1, 2])) - assert 'new_dim' in actual.dims + expected = data.isel(dim2=Variable("new_dim", [0, 1, 2])) + assert "new_dim" in actual.dims assert_equal(actual, expected) # Multi-dimensional - ind = DataArray([[0.0], [0.5], [1.0]], dims=['new_dim', 'new_dim2']) + ind = DataArray([[0.0], [0.5], [1.0]], dims=["new_dim", "new_dim2"]) actual = data.sel(dim2=ind) - expected = data.isel(dim2=Variable(('new_dim', 'new_dim2'), - [[0], [1], [2]])) - assert 'new_dim' in actual.dims - assert 'new_dim2' in actual.dims + expected = data.isel(dim2=Variable(("new_dim", "new_dim2"), [[0], [1], [2]])) + assert "new_dim" in actual.dims + assert "new_dim2" in actual.dims assert_equal(actual, expected) # with coordinate - ind = DataArray([0.0, 0.5, 1.0], dims=['new_dim'], - coords={'new_dim': ['a', 'b', 'c']}) + ind = DataArray( + [0.0, 0.5, 1.0], dims=["new_dim"], coords={"new_dim": ["a", "b", "c"]} + ) actual = data.sel(dim2=ind) - expected = data.isel(dim2=[0, 1, 2]).rename({'dim2': 'new_dim'}) - assert 'new_dim' in actual.dims - assert 'new_dim' in actual.coords - assert_equal(actual.drop('new_dim').drop('dim2'), - expected.drop('new_dim')) - assert_equal(actual['new_dim'].drop('dim2'), - ind['new_dim']) + expected = data.isel(dim2=[0, 1, 2]).rename({"dim2": "new_dim"}) + assert "new_dim" in actual.dims + assert "new_dim" in actual.coords + assert_equal(actual.drop("new_dim").drop("dim2"), expected.drop("new_dim")) + assert_equal(actual["new_dim"].drop("dim2"), ind["new_dim"]) # with conflicted coordinate (silently ignored) - ind = DataArray([0.0, 0.5, 1.0], dims=['dim2'], - coords={'dim2': ['a', 'b', 'c']}) + ind = DataArray( + [0.0, 0.5, 1.0], dims=["dim2"], coords={"dim2": ["a", "b", "c"]} + ) actual = data.sel(dim2=ind) expected = data.isel(dim2=[0, 1, 2]) assert_equal(actual, expected) # with conflicted coordinate (silently ignored) - ind = DataArray([0.0, 0.5, 1.0], dims=['new_dim'], - coords={'new_dim': ['a', 'b', 'c'], - 'dim2': 3}) + ind = DataArray( + [0.0, 0.5, 1.0], + dims=["new_dim"], + coords={"new_dim": ["a", "b", "c"], "dim2": 3}, + ) actual = data.sel(dim2=ind) - assert_equal(actual['new_dim'].drop('dim2'), - ind['new_dim'].drop('dim2')) + assert_equal(actual["new_dim"].drop("dim2"), ind["new_dim"].drop("dim2")) expected = data.isel(dim2=[0, 1, 2]) - expected['dim2'] = (('new_dim'), expected['dim2'].values) - assert_equal(actual['dim2'].drop('new_dim'), - expected['dim2']) - assert actual['var1'].dims == ('dim1', 'new_dim') + expected["dim2"] = (("new_dim"), expected["dim2"].values) + assert_equal(actual["dim2"].drop("new_dim"), expected["dim2"]) + assert actual["var1"].dims == ("dim1", "new_dim") # with non-dimensional coordinate - ind = DataArray([0.0, 0.5, 1.0], dims=['dim2'], - coords={'dim2': ['a', 'b', 'c'], - 'numbers': ('dim2', [0, 1, 2]), - 'new_dim': ('dim2', [1.1, 1.2, 1.3])}) + ind = DataArray( + [0.0, 0.5, 1.0], + dims=["dim2"], + coords={ + "dim2": ["a", "b", "c"], + "numbers": ("dim2", [0, 1, 2]), + "new_dim": ("dim2", [1.1, 1.2, 1.3]), + }, + ) actual = data.sel(dim2=ind) expected = data.isel(dim2=[0, 1, 2]) - assert_equal(actual.drop('new_dim'), expected) - assert np.allclose(actual['new_dim'].values, ind['new_dim'].values) + assert_equal(actual.drop("new_dim"), 
expected) + assert np.allclose(actual["new_dim"].values, ind["new_dim"].values) def test_sel_dataarray_mindex(self): - midx = pd.MultiIndex.from_product([list('abc'), [0, 1]], - names=('one', 'two')) - mds = xr.Dataset({'var': (('x', 'y'), np.random.rand(6, 3))}, - coords={'x': midx, 'y': range(3)}) - - actual_isel = mds.isel(x=xr.DataArray(np.arange(3), dims='x')) - actual_sel = mds.sel(x=DataArray(mds.indexes['x'][:3], dims='x')) - assert actual_isel['x'].dims == ('x', ) - assert actual_sel['x'].dims == ('x', ) + midx = pd.MultiIndex.from_product([list("abc"), [0, 1]], names=("one", "two")) + mds = xr.Dataset( + {"var": (("x", "y"), np.random.rand(6, 3))}, + coords={"x": midx, "y": range(3)}, + ) + + actual_isel = mds.isel(x=xr.DataArray(np.arange(3), dims="x")) + actual_sel = mds.sel(x=DataArray(mds.indexes["x"][:3], dims="x")) + assert actual_isel["x"].dims == ("x",) + assert actual_sel["x"].dims == ("x",) assert_identical(actual_isel, actual_sel) - actual_isel = mds.isel(x=xr.DataArray(np.arange(3), dims='z')) - actual_sel = mds.sel(x=Variable('z', mds.indexes['x'][:3])) - assert actual_isel['x'].dims == ('z', ) - assert actual_sel['x'].dims == ('z', ) + actual_isel = mds.isel(x=xr.DataArray(np.arange(3), dims="z")) + actual_sel = mds.sel(x=Variable("z", mds.indexes["x"][:3])) + assert actual_isel["x"].dims == ("z",) + assert actual_sel["x"].dims == ("z",) assert_identical(actual_isel, actual_sel) # with coordinate - actual_isel = mds.isel(x=xr.DataArray(np.arange(3), dims='z', - coords={'z': [0, 1, 2]})) - actual_sel = mds.sel(x=xr.DataArray(mds.indexes['x'][:3], dims='z', - coords={'z': [0, 1, 2]})) - assert actual_isel['x'].dims == ('z', ) - assert actual_sel['x'].dims == ('z', ) + actual_isel = mds.isel( + x=xr.DataArray(np.arange(3), dims="z", coords={"z": [0, 1, 2]}) + ) + actual_sel = mds.sel( + x=xr.DataArray(mds.indexes["x"][:3], dims="z", coords={"z": [0, 1, 2]}) + ) + assert actual_isel["x"].dims == ("z",) + assert actual_sel["x"].dims == ("z",) assert_identical(actual_isel, actual_sel) # Vectorized indexing with level-variables raises an error - with raises_regex(ValueError, 'Vectorized selection is '): - mds.sel(one=['a', 'b']) + with raises_regex(ValueError, "Vectorized selection is "): + mds.sel(one=["a", "b"]) - with raises_regex(ValueError, 'Vectorized selection is ' - 'not available along MultiIndex variable:' - ' x'): - mds.sel(x=xr.DataArray([np.array(midx[:2]), np.array(midx[-2:])], - dims=['a', 'b'])) + with raises_regex( + ValueError, + "Vectorized selection is " "not available along MultiIndex variable:" " x", + ): + mds.sel( + x=xr.DataArray( + [np.array(midx[:2]), np.array(midx[-2:])], dims=["a", "b"] + ) + ) def test_sel_drop(self): - data = Dataset({'foo': ('x', [1, 2, 3])}, {'x': [0, 1, 2]}) - expected = Dataset({'foo': 1}) + data = Dataset({"foo": ("x", [1, 2, 3])}, {"x": [0, 1, 2]}) + expected = Dataset({"foo": 1}) selected = data.sel(x=0, drop=True) assert_identical(expected, selected) - expected = Dataset({'foo': 1}, {'x': 0}) + expected = Dataset({"foo": 1}, {"x": 0}) selected = data.sel(x=0, drop=False) assert_identical(expected, selected) - data = Dataset({'foo': ('x', [1, 2, 3])}) - expected = Dataset({'foo': 1}) + data = Dataset({"foo": ("x", [1, 2, 3])}) + expected = Dataset({"foo": 1}) selected = data.sel(x=0, drop=True) assert_identical(expected, selected) def test_isel_drop(self): - data = Dataset({'foo': ('x', [1, 2, 3])}, {'x': [0, 1, 2]}) - expected = Dataset({'foo': 1}) + data = Dataset({"foo": ("x", [1, 2, 3])}, {"x": [0, 1, 2]}) + 
expected = Dataset({"foo": 1}) selected = data.isel(x=0, drop=True) assert_identical(expected, selected) - expected = Dataset({'foo': 1}, {'x': 0}) + expected = Dataset({"foo": 1}, {"x": 0}) selected = data.isel(x=0, drop=False) assert_identical(expected, selected) @@ -1301,76 +1394,73 @@ def test_isel_points(self): pdim1 = [1, 2, 3] pdim2 = [4, 5, 1] pdim3 = [1, 2, 3] - actual = data.isel_points(dim1=pdim1, dim2=pdim2, dim3=pdim3, - dim='test_coord') - assert 'test_coord' in actual.dims - assert actual.coords['test_coord'].shape == (len(pdim1), ) + actual = data.isel_points(dim1=pdim1, dim2=pdim2, dim3=pdim3, dim="test_coord") + assert "test_coord" in actual.dims + assert actual.coords["test_coord"].shape == (len(pdim1),) actual = data.isel_points(dim1=pdim1, dim2=pdim2) - assert 'points' in actual.dims - assert 'dim3' in actual.dims - assert 'dim3' not in actual.data_vars - np.testing.assert_array_equal(data['dim2'][pdim2], actual['dim2']) + assert "points" in actual.dims + assert "dim3" in actual.dims + assert "dim3" not in actual.data_vars + np.testing.assert_array_equal(data["dim2"][pdim2], actual["dim2"]) # test that the order of the indexers doesn't matter - assert_identical(data.isel_points(dim1=pdim1, dim2=pdim2), - data.isel_points(dim2=pdim2, dim1=pdim1)) + assert_identical( + data.isel_points(dim1=pdim1, dim2=pdim2), + data.isel_points(dim2=pdim2, dim1=pdim1), + ) # make sure we're raising errors in the right places - with raises_regex(ValueError, - 'All indexers must be the same length'): + with raises_regex(ValueError, "All indexers must be the same length"): data.isel_points(dim1=[1, 2], dim2=[1, 2, 3]) - with raises_regex(ValueError, - 'dimension bad_key does not exist'): + with raises_regex(ValueError, "dimension bad_key does not exist"): data.isel_points(bad_key=[1, 2]) - with raises_regex(TypeError, 'Indexers must be integers'): + with raises_regex(TypeError, "Indexers must be integers"): data.isel_points(dim1=[1.5, 2.2]) - with raises_regex(TypeError, 'Indexers must be integers'): + with raises_regex(TypeError, "Indexers must be integers"): data.isel_points(dim1=[1, 2, 3], dim2=slice(3)) - with raises_regex(ValueError, - 'Indexers must be 1 dimensional'): + with raises_regex(ValueError, "Indexers must be 1 dimensional"): data.isel_points(dim1=1, dim2=2) - with raises_regex(ValueError, - 'Existing dimension names are not valid'): - data.isel_points(dim1=[1, 2], dim2=[1, 2], dim='dim2') + with raises_regex(ValueError, "Existing dimension names are not valid"): + data.isel_points(dim1=[1, 2], dim2=[1, 2], dim="dim2") # test to be sure we keep around variables that were not indexed - ds = Dataset({'x': [1, 2, 3, 4], 'y': 0}) + ds = Dataset({"x": [1, 2, 3, 4], "y": 0}) actual = ds.isel_points(x=[0, 1, 2]) - assert_identical(ds['y'], actual['y']) + assert_identical(ds["y"], actual["y"]) # tests using index or DataArray as a dim stations = Dataset() - stations['station'] = ('station', ['A', 'B', 'C']) - stations['dim1s'] = ('station', [1, 2, 3]) - stations['dim2s'] = ('station', [4, 5, 1]) - - actual = data.isel_points(dim1=stations['dim1s'], - dim2=stations['dim2s'], - dim=stations['station']) - assert 'station' in actual.coords - assert 'station' in actual.dims - assert_identical(actual['station'].drop(['dim2']), - stations['station']) + stations["station"] = ("station", ["A", "B", "C"]) + stations["dim1s"] = ("station", [1, 2, 3]) + stations["dim2s"] = ("station", [4, 5, 1]) + + actual = data.isel_points( + dim1=stations["dim1s"], dim2=stations["dim2s"], 
dim=stations["station"] + ) + assert "station" in actual.coords + assert "station" in actual.dims + assert_identical(actual["station"].drop(["dim2"]), stations["station"]) # make sure we get the default 'points' coordinate when passed a list - actual = data.isel_points(dim1=stations['dim1s'], - dim2=stations['dim2s'], - dim=['A', 'B', 'C']) - assert 'points' in actual.coords - assert actual.coords['points'].values.tolist() == ['A', 'B', 'C'] + actual = data.isel_points( + dim1=stations["dim1s"], dim2=stations["dim2s"], dim=["A", "B", "C"] + ) + assert "points" in actual.coords + assert actual.coords["points"].values.tolist() == ["A", "B", "C"] # test index - actual = data.isel_points(dim1=stations['dim1s'].values, - dim2=stations['dim2s'].values, - dim=pd.Index(['A', 'B', 'C'], - name='letters')) - assert 'letters' in actual.coords + actual = data.isel_points( + dim1=stations["dim1s"].values, + dim2=stations["dim2s"].values, + dim=pd.Index(["A", "B", "C"], name="letters"), + ) + assert "letters" in actual.coords # can pass a numpy array - data.isel_points(dim1=stations['dim1s'], - dim2=stations['dim2s'], - dim=np.array([4, 5, 6])) + data.isel_points( + dim1=stations["dim1s"], dim2=stations["dim2s"], dim=np.array([4, 5, 6]) + ) @pytest.mark.filterwarnings("ignore:Dataset.sel_points") @pytest.mark.filterwarnings("ignore:Dataset.isel_points") @@ -1378,222 +1468,236 @@ def test_sel_points(self): data = create_test_data() # add in a range() index - data['dim1'] = data.dim1 + data["dim1"] = data.dim1 pdim1 = [1, 2, 3] pdim2 = [4, 5, 1] pdim3 = [1, 2, 3] - expected = data.isel_points(dim1=pdim1, dim2=pdim2, dim3=pdim3, - dim='test_coord') - actual = data.sel_points(dim1=data.dim1[pdim1], dim2=data.dim2[pdim2], - dim3=data.dim3[pdim3], dim='test_coord') + expected = data.isel_points( + dim1=pdim1, dim2=pdim2, dim3=pdim3, dim="test_coord" + ) + actual = data.sel_points( + dim1=data.dim1[pdim1], + dim2=data.dim2[pdim2], + dim3=data.dim3[pdim3], + dim="test_coord", + ) assert_identical(expected, actual) - data = Dataset({'foo': (('x', 'y'), np.arange(9).reshape(3, 3))}) - expected = Dataset({'foo': ('points', [0, 4, 8])}) + data = Dataset({"foo": (("x", "y"), np.arange(9).reshape(3, 3))}) + expected = Dataset({"foo": ("points", [0, 4, 8])}) actual = data.sel_points(x=[0, 1, 2], y=[0, 1, 2]) assert_identical(expected, actual) - data.coords.update({'x': [0, 1, 2], 'y': [0, 1, 2]}) - expected.coords.update({'x': ('points', [0, 1, 2]), - 'y': ('points', [0, 1, 2])}) - actual = data.sel_points(x=[0.1, 1.1, 2.5], y=[0, 1.2, 2.0], - method='pad') + data.coords.update({"x": [0, 1, 2], "y": [0, 1, 2]}) + expected.coords.update({"x": ("points", [0, 1, 2]), "y": ("points", [0, 1, 2])}) + actual = data.sel_points(x=[0.1, 1.1, 2.5], y=[0, 1.2, 2.0], method="pad") assert_identical(expected, actual) with pytest.raises(KeyError): - data.sel_points(x=[2.5], y=[2.0], method='pad', tolerance=1e-3) + data.sel_points(x=[2.5], y=[2.0], method="pad", tolerance=1e-3) - @pytest.mark.filterwarnings('ignore::DeprecationWarning') + @pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_sel_fancy(self): data = create_test_data() # add in a range() index - data['dim1'] = data.dim1 + data["dim1"] = data.dim1 pdim1 = [1, 2, 3] pdim2 = [4, 5, 1] pdim3 = [1, 2, 3] - expected = data.isel(dim1=Variable(('test_coord', ), pdim1), - dim2=Variable(('test_coord', ), pdim2), - dim3=Variable(('test_coord'), pdim3)) - actual = data.sel(dim1=Variable(('test_coord', ), data.dim1[pdim1]), - dim2=Variable(('test_coord', ), 
data.dim2[pdim2]), - dim3=Variable(('test_coord', ), data.dim3[pdim3])) + expected = data.isel( + dim1=Variable(("test_coord",), pdim1), + dim2=Variable(("test_coord",), pdim2), + dim3=Variable(("test_coord"), pdim3), + ) + actual = data.sel( + dim1=Variable(("test_coord",), data.dim1[pdim1]), + dim2=Variable(("test_coord",), data.dim2[pdim2]), + dim3=Variable(("test_coord",), data.dim3[pdim3]), + ) assert_identical(expected, actual) # DataArray Indexer - idx_t = DataArray(data['time'][[3, 2, 1]].values, dims=['a'], - coords={'a': ['a', 'b', 'c']}) - idx_2 = DataArray(data['dim2'][[3, 2, 1]].values, dims=['a'], - coords={'a': ['a', 'b', 'c']}) - idx_3 = DataArray(data['dim3'][[3, 2, 1]].values, dims=['a'], - coords={'a': ['a', 'b', 'c']}) + idx_t = DataArray( + data["time"][[3, 2, 1]].values, dims=["a"], coords={"a": ["a", "b", "c"]} + ) + idx_2 = DataArray( + data["dim2"][[3, 2, 1]].values, dims=["a"], coords={"a": ["a", "b", "c"]} + ) + idx_3 = DataArray( + data["dim3"][[3, 2, 1]].values, dims=["a"], coords={"a": ["a", "b", "c"]} + ) actual = data.sel(time=idx_t, dim2=idx_2, dim3=idx_3) - expected = data.isel(time=Variable(('a', ), [3, 2, 1]), - dim2=Variable(('a', ), [3, 2, 1]), - dim3=Variable(('a', ), [3, 2, 1])) - expected = expected.assign_coords(a=idx_t['a']) - assert_identical(expected, actual) - - idx_t = DataArray(data['time'][[3, 2, 1]].values, dims=['a'], - coords={'a': ['a', 'b', 'c']}) - idx_2 = DataArray(data['dim2'][[2, 1, 3]].values, dims=['b'], - coords={'b': [0, 1, 2]}) - idx_3 = DataArray(data['dim3'][[1, 2, 1]].values, dims=['c'], - coords={'c': [0.0, 1.1, 2.2]}) + expected = data.isel( + time=Variable(("a",), [3, 2, 1]), + dim2=Variable(("a",), [3, 2, 1]), + dim3=Variable(("a",), [3, 2, 1]), + ) + expected = expected.assign_coords(a=idx_t["a"]) + assert_identical(expected, actual) + + idx_t = DataArray( + data["time"][[3, 2, 1]].values, dims=["a"], coords={"a": ["a", "b", "c"]} + ) + idx_2 = DataArray( + data["dim2"][[2, 1, 3]].values, dims=["b"], coords={"b": [0, 1, 2]} + ) + idx_3 = DataArray( + data["dim3"][[1, 2, 1]].values, dims=["c"], coords={"c": [0.0, 1.1, 2.2]} + ) actual = data.sel(time=idx_t, dim2=idx_2, dim3=idx_3) - expected = data.isel(time=Variable(('a', ), [3, 2, 1]), - dim2=Variable(('b', ), [2, 1, 3]), - dim3=Variable(('c', ), [1, 2, 1])) - expected = expected.assign_coords(a=idx_t['a'], b=idx_2['b'], - c=idx_3['c']) + expected = data.isel( + time=Variable(("a",), [3, 2, 1]), + dim2=Variable(("b",), [2, 1, 3]), + dim3=Variable(("c",), [1, 2, 1]), + ) + expected = expected.assign_coords(a=idx_t["a"], b=idx_2["b"], c=idx_3["c"]) assert_identical(expected, actual) # test from sel_points - data = Dataset({'foo': (('x', 'y'), np.arange(9).reshape(3, 3))}) - data.coords.update({'x': [0, 1, 2], 'y': [0, 1, 2]}) + data = Dataset({"foo": (("x", "y"), np.arange(9).reshape(3, 3))}) + data.coords.update({"x": [0, 1, 2], "y": [0, 1, 2]}) - expected = Dataset({'foo': ('points', [0, 4, 8])}, - coords={'x': Variable(('points', ), [0, 1, 2]), - 'y': Variable(('points', ), [0, 1, 2])}) - actual = data.sel(x=Variable(('points', ), [0, 1, 2]), - y=Variable(('points', ), [0, 1, 2])) + expected = Dataset( + {"foo": ("points", [0, 4, 8])}, + coords={ + "x": Variable(("points",), [0, 1, 2]), + "y": Variable(("points",), [0, 1, 2]), + }, + ) + actual = data.sel( + x=Variable(("points",), [0, 1, 2]), y=Variable(("points",), [0, 1, 2]) + ) assert_identical(expected, actual) - expected.coords.update({'x': ('points', [0, 1, 2]), - 'y': ('points', [0, 1, 2])}) - actual = 
data.sel(x=Variable(('points', ), [0.1, 1.1, 2.5]), - y=Variable(('points', ), [0, 1.2, 2.0]), - method='pad') + expected.coords.update({"x": ("points", [0, 1, 2]), "y": ("points", [0, 1, 2])}) + actual = data.sel( + x=Variable(("points",), [0.1, 1.1, 2.5]), + y=Variable(("points",), [0, 1.2, 2.0]), + method="pad", + ) assert_identical(expected, actual) - idx_x = DataArray([0, 1, 2], dims=['a'], coords={'a': ['a', 'b', 'c']}) - idx_y = DataArray([0, 2, 1], dims=['b'], coords={'b': [0, 3, 6]}) - expected_ary = data['foo'][[0, 1, 2], [0, 2, 1]] + idx_x = DataArray([0, 1, 2], dims=["a"], coords={"a": ["a", "b", "c"]}) + idx_y = DataArray([0, 2, 1], dims=["b"], coords={"b": [0, 3, 6]}) + expected_ary = data["foo"][[0, 1, 2], [0, 2, 1]] actual = data.sel(x=idx_x, y=idx_y) - assert_array_equal(expected_ary, actual['foo']) - assert_identical(actual['a'].drop('x'), idx_x['a']) - assert_identical(actual['b'].drop('y'), idx_y['b']) + assert_array_equal(expected_ary, actual["foo"]) + assert_identical(actual["a"].drop("x"), idx_x["a"]) + assert_identical(actual["b"].drop("y"), idx_y["b"]) with pytest.raises(KeyError): - data.sel(x=[2.5], y=[2.0], method='pad', tolerance=1e-3) + data.sel(x=[2.5], y=[2.0], method="pad", tolerance=1e-3) def test_sel_method(self): data = create_test_data() expected = data.sel(dim2=1) - actual = data.sel(dim2=0.95, method='nearest') + actual = data.sel(dim2=0.95, method="nearest") assert_identical(expected, actual) - actual = data.sel(dim2=0.95, method='nearest', tolerance=1) + actual = data.sel(dim2=0.95, method="nearest", tolerance=1) assert_identical(expected, actual) with pytest.raises(KeyError): - actual = data.sel(dim2=np.pi, method='nearest', tolerance=0) + actual = data.sel(dim2=np.pi, method="nearest", tolerance=0) expected = data.sel(dim2=[1.5]) - actual = data.sel(dim2=[1.45], method='backfill') + actual = data.sel(dim2=[1.45], method="backfill") assert_identical(expected, actual) - with raises_regex(NotImplementedError, 'slice objects'): - data.sel(dim2=slice(1, 3), method='ffill') + with raises_regex(NotImplementedError, "slice objects"): + data.sel(dim2=slice(1, 3), method="ffill") - with raises_regex(TypeError, '``method``'): + with raises_regex(TypeError, "``method``"): # this should not pass silently data.sel(method=data) # cannot pass method if there is no associated coordinate - with raises_regex(ValueError, 'cannot supply'): - data.sel(dim1=0, method='nearest') + with raises_regex(ValueError, "cannot supply"): + data.sel(dim1=0, method="nearest") def test_loc(self): data = create_test_data() - expected = data.sel(dim3='a') - actual = data.loc[dict(dim3='a')] + expected = data.sel(dim3="a") + actual = data.loc[dict(dim3="a")] assert_identical(expected, actual) - with raises_regex(TypeError, 'can only lookup dict'): - data.loc['a'] + with raises_regex(TypeError, "can only lookup dict"): + data.loc["a"] with pytest.raises(TypeError): - data.loc[dict(dim3='a')] = 0 + data.loc[dict(dim3="a")] = 0 def test_selection_multiindex(self): - mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2], [-1, -2]], - names=('one', 'two', 'three')) - mdata = Dataset(data_vars={'var': ('x', range(8))}, - coords={'x': mindex}) + mindex = pd.MultiIndex.from_product( + [["a", "b"], [1, 2], [-1, -2]], names=("one", "two", "three") + ) + mdata = Dataset(data_vars={"var": ("x", range(8))}, coords={"x": mindex}) - def test_sel(lab_indexer, pos_indexer, replaced_idx=False, - renamed_dim=None): + def test_sel(lab_indexer, pos_indexer, replaced_idx=False, renamed_dim=None): ds = 
mdata.sel(x=lab_indexer) expected_ds = mdata.isel(x=pos_indexer) if not replaced_idx: assert_identical(ds, expected_ds) else: if renamed_dim: - assert ds['var'].dims[0] == renamed_dim - ds = ds.rename({renamed_dim: 'x'}) - assert_identical(ds['var'].variable, - expected_ds['var'].variable) - assert not ds['x'].equals(expected_ds['x']) - - test_sel(('a', 1, -1), 0) - test_sel(('b', 2, -2), -1) - test_sel(('a', 1), [0, 1], replaced_idx=True, renamed_dim='three') - test_sel(('a',), range(4), replaced_idx=True) - test_sel('a', range(4), replaced_idx=True) - test_sel([('a', 1, -1), ('b', 2, -2)], [0, 7]) - test_sel(slice('a', 'b'), range(8)) - test_sel(slice(('a', 1), ('b', 1)), range(6)) - test_sel({'one': 'a', 'two': 1, 'three': -1}, 0) - test_sel({'one': 'a', 'two': 1}, [0, 1], replaced_idx=True, - renamed_dim='three') - test_sel({'one': 'a'}, range(4), replaced_idx=True) - - assert_identical(mdata.loc[{'x': {'one': 'a'}}], - mdata.sel(x={'one': 'a'})) - assert_identical(mdata.loc[{'x': 'a'}], - mdata.sel(x='a')) - assert_identical(mdata.loc[{'x': ('a', 1)}], - mdata.sel(x=('a', 1))) - assert_identical(mdata.loc[{'x': ('a', 1, -1)}], - mdata.sel(x=('a', 1, -1))) - - assert_identical(mdata.sel(x={'one': 'a', 'two': 1}), - mdata.sel(one='a', two=1)) + assert ds["var"].dims[0] == renamed_dim + ds = ds.rename({renamed_dim: "x"}) + assert_identical(ds["var"].variable, expected_ds["var"].variable) + assert not ds["x"].equals(expected_ds["x"]) + + test_sel(("a", 1, -1), 0) + test_sel(("b", 2, -2), -1) + test_sel(("a", 1), [0, 1], replaced_idx=True, renamed_dim="three") + test_sel(("a",), range(4), replaced_idx=True) + test_sel("a", range(4), replaced_idx=True) + test_sel([("a", 1, -1), ("b", 2, -2)], [0, 7]) + test_sel(slice("a", "b"), range(8)) + test_sel(slice(("a", 1), ("b", 1)), range(6)) + test_sel({"one": "a", "two": 1, "three": -1}, 0) + test_sel({"one": "a", "two": 1}, [0, 1], replaced_idx=True, renamed_dim="three") + test_sel({"one": "a"}, range(4), replaced_idx=True) + + assert_identical(mdata.loc[{"x": {"one": "a"}}], mdata.sel(x={"one": "a"})) + assert_identical(mdata.loc[{"x": "a"}], mdata.sel(x="a")) + assert_identical(mdata.loc[{"x": ("a", 1)}], mdata.sel(x=("a", 1))) + assert_identical(mdata.loc[{"x": ("a", 1, -1)}], mdata.sel(x=("a", 1, -1))) + + assert_identical(mdata.sel(x={"one": "a", "two": 1}), mdata.sel(one="a", two=1)) def test_broadcast_like(self): - original1 = DataArray(np.random.randn(5), - [('x', range(5))], name='a').to_dataset() + original1 = DataArray( + np.random.randn(5), [("x", range(5))], name="a" + ).to_dataset() - original2 = DataArray(np.random.randn(6), - [('y', range(6))], name='b') + original2 = DataArray(np.random.randn(6), [("y", range(6))], name="b") expected1, expected2 = broadcast(original1, original2) - assert_identical(original1.broadcast_like(original2), - expected1.transpose('y', 'x')) + assert_identical( + original1.broadcast_like(original2), expected1.transpose("y", "x") + ) - assert_identical(original2.broadcast_like(original1), - expected2) + assert_identical(original2.broadcast_like(original1), expected2) def test_reindex_like(self): data = create_test_data() - data['letters'] = ('dim3', 10 * ['a']) + data["letters"] = ("dim3", 10 * ["a"]) expected = data.isel(dim1=slice(10), time=slice(13)) actual = data.reindex_like(expected) assert_identical(actual, expected) expected = data.copy(deep=True) - expected['dim3'] = ('dim3', list('cdefghijkl')) - expected['var3'][:-2] = expected['var3'][2:].values - expected['var3'][-2:] = np.nan - 
expected['letters'] = expected['letters'].astype(object) - expected['letters'][-2:] = np.nan - expected['numbers'] = expected['numbers'].astype(float) - expected['numbers'][:-2] = expected['numbers'][2:].values - expected['numbers'][-2:] = np.nan + expected["dim3"] = ("dim3", list("cdefghijkl")) + expected["var3"][:-2] = expected["var3"][2:].values + expected["var3"][-2:] = np.nan + expected["letters"] = expected["letters"].astype(object) + expected["letters"][-2:] = np.nan + expected["numbers"] = expected["numbers"].astype(float) + expected["numbers"][:-2] = expected["numbers"][2:].values + expected["numbers"][-2:] = np.nan actual = data.reindex_like(expected) assert_identical(actual, expected) @@ -1601,56 +1705,56 @@ def test_reindex(self): data = create_test_data() assert_identical(data, data.reindex()) - expected = data.assign_coords(dim1=data['dim1']) - actual = data.reindex(dim1=data['dim1']) + expected = data.assign_coords(dim1=data["dim1"]) + actual = data.reindex(dim1=data["dim1"]) assert_identical(actual, expected) - actual = data.reindex(dim1=data['dim1'].values) + actual = data.reindex(dim1=data["dim1"].values) assert_identical(actual, expected) - actual = data.reindex(dim1=data['dim1'].to_index()) + actual = data.reindex(dim1=data["dim1"].to_index()) assert_identical(actual, expected) - with raises_regex( - ValueError, 'cannot reindex or align along dimension'): - data.reindex(dim1=data['dim1'][:5]) + with raises_regex(ValueError, "cannot reindex or align along dimension"): + data.reindex(dim1=data["dim1"][:5]) expected = data.isel(dim2=slice(5)) - actual = data.reindex(dim2=data['dim2'][:5]) + actual = data.reindex(dim2=data["dim2"][:5]) assert_identical(actual, expected) # test dict-like argument - actual = data.reindex({'dim2': data['dim2']}) + actual = data.reindex({"dim2": data["dim2"]}) expected = data assert_identical(actual, expected) - with raises_regex(ValueError, 'cannot specify both'): - data.reindex({'x': 0}, x=0) - with raises_regex(ValueError, 'dictionary'): - data.reindex('foo') + with raises_regex(ValueError, "cannot specify both"): + data.reindex({"x": 0}, x=0) + with raises_regex(ValueError, "dictionary"): + data.reindex("foo") # invalid dimension - with raises_regex(ValueError, 'invalid reindex dim'): + with raises_regex(ValueError, "invalid reindex dim"): data.reindex(invalid=0) # out of order - expected = data.sel(dim2=data['dim2'][:5:-1]) - actual = data.reindex(dim2=data['dim2'][:5:-1]) + expected = data.sel(dim2=data["dim2"][:5:-1]) + actual = data.reindex(dim2=data["dim2"][:5:-1]) assert_identical(actual, expected) # regression test for #279 - expected = Dataset({'x': ('time', np.random.randn(5))}, - {'time': range(5)}) + expected = Dataset({"x": ("time", np.random.randn(5))}, {"time": range(5)}) time2 = DataArray(np.arange(5), dims="time2") with pytest.warns(FutureWarning): actual = expected.reindex(time=time2) assert_identical(actual, expected) # another regression test - ds = Dataset({'foo': (['x', 'y'], np.zeros((3, 4)))}, - {'x': range(3), 'y': range(4)}) - expected = Dataset({'foo': (['x', 'y'], np.zeros((3, 2)))}, - {'x': [0, 1, 3], 'y': [0, 1]}) - expected['foo'][-1] = np.nan + ds = Dataset( + {"foo": (["x", "y"], np.zeros((3, 4)))}, {"x": range(3), "y": range(4)} + ) + expected = Dataset( + {"foo": (["x", "y"], np.zeros((3, 2)))}, {"x": [0, 1, 3], "y": [0, 1]} + ) + expected["foo"][-1] = np.nan actual = ds.reindex(x=[0, 1, 3], y=[0, 1]) assert_identical(expected, actual) @@ -1659,13 +1763,12 @@ def test_reindex_warning(self): with 
pytest.warns(FutureWarning) as ws: # DataArray with different dimension raises Future warning - ind = xr.DataArray([0.0, 1.0], dims=['new_dim'], name='ind') + ind = xr.DataArray([0.0, 1.0], dims=["new_dim"], name="ind") data.reindex(dim2=ind) - assert any(["Indexer has dimensions " in - str(w.message) for w in ws]) + assert any(["Indexer has dimensions " in str(w.message) for w in ws]) # Should not warn - ind = xr.DataArray([0.0, 1.0], dims=['dim2'], name='ind') + ind = xr.DataArray([0.0, 1.0], dims=["dim2"], name="ind") with pytest.warns(None) as ws: data.reindex(dim2=ind) assert len(ws) == 0 @@ -1677,271 +1780,320 @@ def test_reindex_variables_copied(self): assert reindexed_data.variables[k] is not data.variables[k] def test_reindex_method(self): - ds = Dataset({'x': ('y', [10, 20]), 'y': [0, 1]}) + ds = Dataset({"x": ("y", [10, 20]), "y": [0, 1]}) y = [-0.5, 0.5, 1.5] - actual = ds.reindex(y=y, method='backfill') - expected = Dataset({'x': ('y', [10, 20, np.nan]), 'y': y}) + actual = ds.reindex(y=y, method="backfill") + expected = Dataset({"x": ("y", [10, 20, np.nan]), "y": y}) assert_identical(expected, actual) - actual = ds.reindex(y=y, method='backfill', tolerance=0.1) - expected = Dataset({'x': ('y', 3 * [np.nan]), 'y': y}) + actual = ds.reindex(y=y, method="backfill", tolerance=0.1) + expected = Dataset({"x": ("y", 3 * [np.nan]), "y": y}) assert_identical(expected, actual) - actual = ds.reindex(y=y, method='pad') - expected = Dataset({'x': ('y', [np.nan, 10, 20]), 'y': y}) + actual = ds.reindex(y=y, method="pad") + expected = Dataset({"x": ("y", [np.nan, 10, 20]), "y": y}) assert_identical(expected, actual) - alt = Dataset({'y': y}) - actual = ds.reindex_like(alt, method='pad') + alt = Dataset({"y": y}) + actual = ds.reindex_like(alt, method="pad") assert_identical(expected, actual) - @pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0]) + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) def test_reindex_fill_value(self, fill_value): - ds = Dataset({'x': ('y', [10, 20]), 'y': [0, 1]}) + ds = Dataset({"x": ("y", [10, 20]), "y": [0, 1]}) y = [0, 1, 2] actual = ds.reindex(y=y, fill_value=fill_value) if fill_value == dtypes.NA: # if we supply the default, we expect the missing value for a # float array fill_value = np.nan - expected = Dataset({'x': ('y', [10, 20, fill_value]), 'y': y}) + expected = Dataset({"x": ("y", [10, 20, fill_value]), "y": y}) assert_identical(expected, actual) - @pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0]) + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) def test_reindex_like_fill_value(self, fill_value): - ds = Dataset({'x': ('y', [10, 20]), 'y': [0, 1]}) + ds = Dataset({"x": ("y", [10, 20]), "y": [0, 1]}) y = [0, 1, 2] - alt = Dataset({'y': y}) + alt = Dataset({"y": y}) actual = ds.reindex_like(alt, fill_value=fill_value) if fill_value == dtypes.NA: # if we supply the default, we expect the missing value for a # float array fill_value = np.nan - expected = Dataset({'x': ('y', [10, 20, fill_value]), 'y': y}) + expected = Dataset({"x": ("y", [10, 20, fill_value]), "y": y}) assert_identical(expected, actual) - @pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0]) + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) def test_align_fill_value(self, fill_value): - x = Dataset({'foo': DataArray([1, 2], dims=['x'], - coords={'x': [1, 2]})}) - y = Dataset({'bar': DataArray([1, 2], dims=['x'], - coords={'x': [1, 3]})}) - x2, y2 = align(x, y, join='outer', fill_value=fill_value) + x = Dataset({"foo": 
DataArray([1, 2], dims=["x"], coords={"x": [1, 2]})}) + y = Dataset({"bar": DataArray([1, 2], dims=["x"], coords={"x": [1, 3]})}) + x2, y2 = align(x, y, join="outer", fill_value=fill_value) if fill_value == dtypes.NA: # if we supply the default, we expect the missing value for a # float array fill_value = np.nan expected_x2 = Dataset( - {'foo': DataArray([1, 2, fill_value], - dims=['x'], - coords={'x': [1, 2, 3]})}) + {"foo": DataArray([1, 2, fill_value], dims=["x"], coords={"x": [1, 2, 3]})} + ) expected_y2 = Dataset( - {'bar': DataArray([1, fill_value, 2], - dims=['x'], - coords={'x': [1, 2, 3]})}) + {"bar": DataArray([1, fill_value, 2], dims=["x"], coords={"x": [1, 2, 3]})} + ) assert_identical(expected_x2, x2) assert_identical(expected_y2, y2) def test_align(self): left = create_test_data() right = left.copy(deep=True) - right['dim3'] = ('dim3', list('cdefghijkl')) - right['var3'][:-2] = right['var3'][2:].values - right['var3'][-2:] = np.random.randn(*right['var3'][-2:].shape) - right['numbers'][:-2] = right['numbers'][2:].values - right['numbers'][-2:] = -10 + right["dim3"] = ("dim3", list("cdefghijkl")) + right["var3"][:-2] = right["var3"][2:].values + right["var3"][-2:] = np.random.randn(*right["var3"][-2:].shape) + right["numbers"][:-2] = right["numbers"][2:].values + right["numbers"][-2:] = -10 - intersection = list('cdefghij') - union = list('abcdefghijkl') + intersection = list("cdefghij") + union = list("abcdefghijkl") - left2, right2 = align(left, right, join='inner') - assert_array_equal(left2['dim3'], intersection) + left2, right2 = align(left, right, join="inner") + assert_array_equal(left2["dim3"], intersection) assert_identical(left2, right2) - left2, right2 = align(left, right, join='outer') + left2, right2 = align(left, right, join="outer") - assert_array_equal(left2['dim3'], union) - assert_equal(left2['dim3'].variable, right2['dim3'].variable) + assert_array_equal(left2["dim3"], union) + assert_equal(left2["dim3"].variable, right2["dim3"].variable) - assert_identical(left2.sel(dim3=intersection), - right2.sel(dim3=intersection)) - assert np.isnan(left2['var3'][-2:]).all() - assert np.isnan(right2['var3'][:2]).all() + assert_identical(left2.sel(dim3=intersection), right2.sel(dim3=intersection)) + assert np.isnan(left2["var3"][-2:]).all() + assert np.isnan(right2["var3"][:2]).all() - left2, right2 = align(left, right, join='left') - assert_equal(left2['dim3'].variable, right2['dim3'].variable) - assert_equal(left2['dim3'].variable, left['dim3'].variable) + left2, right2 = align(left, right, join="left") + assert_equal(left2["dim3"].variable, right2["dim3"].variable) + assert_equal(left2["dim3"].variable, left["dim3"].variable) - assert_identical(left2.sel(dim3=intersection), - right2.sel(dim3=intersection)) - assert np.isnan(right2['var3'][:2]).all() + assert_identical(left2.sel(dim3=intersection), right2.sel(dim3=intersection)) + assert np.isnan(right2["var3"][:2]).all() - left2, right2 = align(left, right, join='right') - assert_equal(left2['dim3'].variable, right2['dim3'].variable) - assert_equal(left2['dim3'].variable, right['dim3'].variable) + left2, right2 = align(left, right, join="right") + assert_equal(left2["dim3"].variable, right2["dim3"].variable) + assert_equal(left2["dim3"].variable, right["dim3"].variable) - assert_identical(left2.sel(dim3=intersection), - right2.sel(dim3=intersection)) + assert_identical(left2.sel(dim3=intersection), right2.sel(dim3=intersection)) - assert np.isnan(left2['var3'][-2:]).all() + assert np.isnan(left2["var3"][-2:]).all() - 
with raises_regex(ValueError, 'invalid value for join'): - align(left, right, join='foobar') + with raises_regex(ValueError, "invalid value for join"): + align(left, right, join="foobar") with pytest.raises(TypeError): - align(left, right, foo='bar') + align(left, right, foo="bar") def test_align_exact(self): - left = xr.Dataset(coords={'x': [0, 1]}) - right = xr.Dataset(coords={'x': [1, 2]}) + left = xr.Dataset(coords={"x": [0, 1]}) + right = xr.Dataset(coords={"x": [1, 2]}) - left1, left2 = xr.align(left, left, join='exact') + left1, left2 = xr.align(left, left, join="exact") assert_identical(left1, left) assert_identical(left2, left) - with raises_regex(ValueError, 'indexes .* not equal'): - xr.align(left, right, join='exact') + with raises_regex(ValueError, "indexes .* not equal"): + xr.align(left, right, join="exact") def test_align_exclude(self): - x = Dataset({'foo': DataArray([[1, 2], [3, 4]], dims=['x', 'y'], - coords={'x': [1, 2], 'y': [3, 4]})}) - y = Dataset({'bar': DataArray([[1, 2], [3, 4]], dims=['x', 'y'], - coords={'x': [1, 3], 'y': [5, 6]})}) - x2, y2 = align(x, y, exclude=['y'], join='outer') + x = Dataset( + { + "foo": DataArray( + [[1, 2], [3, 4]], dims=["x", "y"], coords={"x": [1, 2], "y": [3, 4]} + ) + } + ) + y = Dataset( + { + "bar": DataArray( + [[1, 2], [3, 4]], dims=["x", "y"], coords={"x": [1, 3], "y": [5, 6]} + ) + } + ) + x2, y2 = align(x, y, exclude=["y"], join="outer") expected_x2 = Dataset( - {'foo': DataArray([[1, 2], [3, 4], [np.nan, np.nan]], - dims=['x', 'y'], - coords={'x': [1, 2, 3], 'y': [3, 4]})}) + { + "foo": DataArray( + [[1, 2], [3, 4], [np.nan, np.nan]], + dims=["x", "y"], + coords={"x": [1, 2, 3], "y": [3, 4]}, + ) + } + ) expected_y2 = Dataset( - {'bar': DataArray([[1, 2], [np.nan, np.nan], [3, 4]], - dims=['x', 'y'], - coords={'x': [1, 2, 3], 'y': [5, 6]})}) + { + "bar": DataArray( + [[1, 2], [np.nan, np.nan], [3, 4]], + dims=["x", "y"], + coords={"x": [1, 2, 3], "y": [5, 6]}, + ) + } + ) assert_identical(expected_x2, x2) assert_identical(expected_y2, y2) def test_align_nocopy(self): - x = Dataset({'foo': DataArray([1, 2, 3], coords=[('x', [1, 2, 3])])}) - y = Dataset({'foo': DataArray([1, 2], coords=[('x', [1, 2])])}) + x = Dataset({"foo": DataArray([1, 2, 3], coords=[("x", [1, 2, 3])])}) + y = Dataset({"foo": DataArray([1, 2], coords=[("x", [1, 2])])}) expected_x2 = x - expected_y2 = Dataset({'foo': DataArray([1, 2, np.nan], - coords=[('x', [1, 2, 3])])}) + expected_y2 = Dataset( + {"foo": DataArray([1, 2, np.nan], coords=[("x", [1, 2, 3])])} + ) - x2, y2 = align(x, y, copy=False, join='outer') + x2, y2 = align(x, y, copy=False, join="outer") assert_identical(expected_x2, x2) assert_identical(expected_y2, y2) - assert source_ndarray(x['foo'].data) is source_ndarray(x2['foo'].data) + assert source_ndarray(x["foo"].data) is source_ndarray(x2["foo"].data) - x2, y2 = align(x, y, copy=True, join='outer') - assert source_ndarray(x['foo'].data) is not \ - source_ndarray(x2['foo'].data) + x2, y2 = align(x, y, copy=True, join="outer") + assert source_ndarray(x["foo"].data) is not source_ndarray(x2["foo"].data) assert_identical(expected_x2, x2) assert_identical(expected_y2, y2) def test_align_indexes(self): - x = Dataset({'foo': DataArray([1, 2, 3], dims='x', - coords=[('x', [1, 2, 3])])}) - x2, = align(x, indexes={'x': [2, 3, 1]}) - expected_x2 = Dataset({'foo': DataArray([2, 3, 1], dims='x', - coords={'x': [2, 3, 1]})}) + x = Dataset({"foo": DataArray([1, 2, 3], dims="x", coords=[("x", [1, 2, 3])])}) + x2, = align(x, indexes={"x": [2, 3, 1]}) + 
expected_x2 = Dataset( + {"foo": DataArray([2, 3, 1], dims="x", coords={"x": [2, 3, 1]})} + ) assert_identical(expected_x2, x2) def test_align_non_unique(self): - x = Dataset({'foo': ('x', [3, 4, 5]), 'x': [0, 0, 1]}) + x = Dataset({"foo": ("x", [3, 4, 5]), "x": [0, 0, 1]}) x1, x2 = align(x, x) assert x1.identical(x) and x2.identical(x) - y = Dataset({'bar': ('x', [6, 7]), 'x': [0, 1]}) - with raises_regex(ValueError, 'cannot reindex or align'): + y = Dataset({"bar": ("x", [6, 7]), "x": [0, 1]}) + with raises_regex(ValueError, "cannot reindex or align"): align(x, y) def test_broadcast(self): - ds = Dataset({'foo': 0, 'bar': ('x', [1]), 'baz': ('y', [2, 3])}, - {'c': ('x', [4])}) - expected = Dataset({'foo': (('x', 'y'), [[0, 0]]), - 'bar': (('x', 'y'), [[1, 1]]), - 'baz': (('x', 'y'), [[2, 3]])}, - {'c': ('x', [4])}) + ds = Dataset( + {"foo": 0, "bar": ("x", [1]), "baz": ("y", [2, 3])}, {"c": ("x", [4])} + ) + expected = Dataset( + { + "foo": (("x", "y"), [[0, 0]]), + "bar": (("x", "y"), [[1, 1]]), + "baz": (("x", "y"), [[2, 3]]), + }, + {"c": ("x", [4])}, + ) actual, = broadcast(ds) assert_identical(expected, actual) - ds_x = Dataset({'foo': ('x', [1])}) - ds_y = Dataset({'bar': ('y', [2, 3])}) - expected_x = Dataset({'foo': (('x', 'y'), [[1, 1]])}) - expected_y = Dataset({'bar': (('x', 'y'), [[2, 3]])}) + ds_x = Dataset({"foo": ("x", [1])}) + ds_y = Dataset({"bar": ("y", [2, 3])}) + expected_x = Dataset({"foo": (("x", "y"), [[1, 1]])}) + expected_y = Dataset({"bar": (("x", "y"), [[2, 3]])}) actual_x, actual_y = broadcast(ds_x, ds_y) assert_identical(expected_x, actual_x) assert_identical(expected_y, actual_y) - array_y = ds_y['bar'] - expected_y = expected_y['bar'] + array_y = ds_y["bar"] + expected_y = expected_y["bar"] actual_x, actual_y = broadcast(ds_x, array_y) assert_identical(expected_x, actual_x) assert_identical(expected_y, actual_y) def test_broadcast_nocopy(self): # Test that data is not copied if not needed - x = Dataset({'foo': (('x', 'y'), [[1, 1]])}) - y = Dataset({'bar': ('y', [2, 3])}) + x = Dataset({"foo": (("x", "y"), [[1, 1]])}) + y = Dataset({"bar": ("y", [2, 3])}) actual_x, = broadcast(x) assert_identical(x, actual_x) - assert source_ndarray(actual_x['foo'].data) is source_ndarray( - x['foo'].data) + assert source_ndarray(actual_x["foo"].data) is source_ndarray(x["foo"].data) actual_x, actual_y = broadcast(x, y) assert_identical(x, actual_x) - assert source_ndarray(actual_x['foo'].data) is source_ndarray( - x['foo'].data) + assert source_ndarray(actual_x["foo"].data) is source_ndarray(x["foo"].data) def test_broadcast_exclude(self): - x = Dataset({ - 'foo': DataArray([[1, 2], [3, 4]], dims=['x', 'y'], - coords={'x': [1, 2], 'y': [3, 4]}), - 'bar': DataArray(5), - }) - y = Dataset({ - 'foo': DataArray([[1, 2]], dims=['z', 'y'], - coords={'z': [1], 'y': [5, 6]}), - }) - x2, y2 = broadcast(x, y, exclude=['y']) - - expected_x2 = Dataset({ - 'foo': DataArray([[[1, 2]], [[3, 4]]], dims=['x', 'z', 'y'], - coords={'z': [1], 'x': [1, 2], 'y': [3, 4]}), - 'bar': DataArray([[5], [5]], dims=['x', 'z'], - coords={'x': [1, 2], 'z': [1]}), - }) - expected_y2 = Dataset({ - 'foo': DataArray([[[1, 2]], [[1, 2]]], dims=['x', 'z', 'y'], - coords={'z': [1], 'x': [1, 2], 'y': [5, 6]}), - }) + x = Dataset( + { + "foo": DataArray( + [[1, 2], [3, 4]], dims=["x", "y"], coords={"x": [1, 2], "y": [3, 4]} + ), + "bar": DataArray(5), + } + ) + y = Dataset( + { + "foo": DataArray( + [[1, 2]], dims=["z", "y"], coords={"z": [1], "y": [5, 6]} + ) + } + ) + x2, y2 = broadcast(x, y, exclude=["y"]) + + 
expected_x2 = Dataset( + { + "foo": DataArray( + [[[1, 2]], [[3, 4]]], + dims=["x", "z", "y"], + coords={"z": [1], "x": [1, 2], "y": [3, 4]}, + ), + "bar": DataArray( + [[5], [5]], dims=["x", "z"], coords={"x": [1, 2], "z": [1]} + ), + } + ) + expected_y2 = Dataset( + { + "foo": DataArray( + [[[1, 2]], [[1, 2]]], + dims=["x", "z", "y"], + coords={"z": [1], "x": [1, 2], "y": [5, 6]}, + ) + } + ) assert_identical(expected_x2, x2) assert_identical(expected_y2, y2) def test_broadcast_misaligned(self): - x = Dataset({'foo': DataArray([1, 2, 3], - coords=[('x', [-1, -2, -3])])}) - y = Dataset({'bar': DataArray([[1, 2], [3, 4]], dims=['y', 'x'], - coords={'y': [1, 2], 'x': [10, -3]})}) + x = Dataset({"foo": DataArray([1, 2, 3], coords=[("x", [-1, -2, -3])])}) + y = Dataset( + { + "bar": DataArray( + [[1, 2], [3, 4]], + dims=["y", "x"], + coords={"y": [1, 2], "x": [10, -3]}, + ) + } + ) x2, y2 = broadcast(x, y) expected_x2 = Dataset( - {'foo': DataArray([[3, 3], [2, 2], [1, 1], [np.nan, np.nan]], - dims=['x', 'y'], - coords={'y': [1, 2], 'x': [-3, -2, -1, 10]})}) + { + "foo": DataArray( + [[3, 3], [2, 2], [1, 1], [np.nan, np.nan]], + dims=["x", "y"], + coords={"y": [1, 2], "x": [-3, -2, -1, 10]}, + ) + } + ) expected_y2 = Dataset( - {'bar': DataArray( - [[2, 4], [np.nan, np.nan], [np.nan, np.nan], [1, 3]], - dims=['x', 'y'], coords={'y': [1, 2], 'x': [-3, -2, -1, 10]})}) + { + "bar": DataArray( + [[2, 4], [np.nan, np.nan], [np.nan, np.nan], [1, 3]], + dims=["x", "y"], + coords={"y": [1, 2], "x": [-3, -2, -1, 10]}, + ) + } + ) assert_identical(expected_x2, x2) assert_identical(expected_y2, y2) def test_variable_indexing(self): data = create_test_data() - v = data['var1'] - d1 = data['dim1'] - d2 = data['dim2'] + v = data["var1"] + d1 = data["dim1"] + d2 = data["dim2"] assert_equal(v, v[d1.values]) assert_equal(v, v[d1]) assert_equal(v[:3], v[d1 < 3]) @@ -1955,35 +2107,32 @@ def test_drop_variables(self): assert_identical(data, data.drop([])) - expected = Dataset( - {k: data[k] for k in data.variables if k != 'time'} - ) - actual = data.drop('time') + expected = Dataset({k: data[k] for k in data.variables if k != "time"}) + actual = data.drop("time") assert_identical(expected, actual) - actual = data.drop(['time']) + actual = data.drop(["time"]) assert_identical(expected, actual) - with raises_regex(ValueError, 'cannot be found'): - data.drop('not_found_here') + with raises_regex(ValueError, "cannot be found"): + data.drop("not_found_here") - actual = data.drop('not_found_here', errors='ignore') + actual = data.drop("not_found_here", errors="ignore") assert_identical(data, actual) - actual = data.drop(['not_found_here'], errors='ignore') + actual = data.drop(["not_found_here"], errors="ignore") assert_identical(data, actual) - actual = data.drop(['time', 'not_found_here'], errors='ignore') + actual = data.drop(["time", "not_found_here"], errors="ignore") assert_identical(expected, actual) def test_drop_index_labels(self): - data = Dataset({'A': (['x', 'y'], np.random.randn(2, 3)), - 'x': ['a', 'b']}) + data = Dataset({"A": (["x", "y"], np.random.randn(2, 3)), "x": ["a", "b"]}) - actual = data.drop(['a'], 'x') + actual = data.drop(["a"], "x") expected = data.isel(x=[1]) assert_identical(expected, actual) - actual = data.drop(['a', 'b'], 'x') + actual = data.drop(["a", "b"], "x") expected = data.isel(x=slice(0, 0)) assert_identical(expected, actual) @@ -1991,67 +2140,71 @@ def test_drop_index_labels(self): # in pandas 0.23. 
with pytest.raises((ValueError, KeyError)): # not contained in axis - data.drop(['c'], dim='x') + data.drop(["c"], dim="x") - actual = data.drop(['c'], dim='x', errors='ignore') + actual = data.drop(["c"], dim="x", errors="ignore") assert_identical(data, actual) with pytest.raises(ValueError): - data.drop(['c'], dim='x', errors='wrong_value') + data.drop(["c"], dim="x", errors="wrong_value") - actual = data.drop(['a', 'b', 'c'], 'x', errors='ignore') + actual = data.drop(["a", "b", "c"], "x", errors="ignore") expected = data.isel(x=slice(0, 0)) assert_identical(expected, actual) # DataArrays as labels are a nasty corner case as they are not # Iterable[Hashable] - DataArray.__iter__ yields scalar DataArrays. - actual = data.drop(DataArray(['a', 'b', 'c']), 'x', errors='ignore') + actual = data.drop(DataArray(["a", "b", "c"]), "x", errors="ignore") expected = data.isel(x=slice(0, 0)) assert_identical(expected, actual) - with raises_regex( - ValueError, 'does not have coordinate labels'): - data.drop(1, 'y') + with raises_regex(ValueError, "does not have coordinate labels"): + data.drop(1, "y") def test_drop_dims(self): - data = xr.Dataset({'A': (['x', 'y'], np.random.randn(2, 3)), - 'B': ('x', np.random.randn(2)), - 'x': ['a', 'b'], 'z': np.pi}) + data = xr.Dataset( + { + "A": (["x", "y"], np.random.randn(2, 3)), + "B": ("x", np.random.randn(2)), + "x": ["a", "b"], + "z": np.pi, + } + ) - actual = data.drop_dims('x') - expected = data.drop(['A', 'B', 'x']) + actual = data.drop_dims("x") + expected = data.drop(["A", "B", "x"]) assert_identical(expected, actual) - actual = data.drop_dims('y') - expected = data.drop('A') + actual = data.drop_dims("y") + expected = data.drop("A") assert_identical(expected, actual) - actual = data.drop_dims(['x', 'y']) - expected = data.drop(['A', 'B', 'x']) + actual = data.drop_dims(["x", "y"]) + expected = data.drop(["A", "B", "x"]) assert_identical(expected, actual) with pytest.raises((ValueError, KeyError)): - data.drop_dims('z') # not a dimension + data.drop_dims("z") # not a dimension with pytest.raises((ValueError, KeyError)): data.drop_dims(None) - actual = data.drop_dims('z', errors='ignore') + actual = data.drop_dims("z", errors="ignore") assert_identical(data, actual) - actual = data.drop_dims(None, errors='ignore') + actual = data.drop_dims(None, errors="ignore") assert_identical(data, actual) with pytest.raises(ValueError): - actual = data.drop_dims('z', errors='wrong_value') + actual = data.drop_dims("z", errors="wrong_value") - actual = data.drop_dims(['x', 'y', 'z'], errors='ignore') - expected = data.drop(['A', 'B', 'x']) + actual = data.drop_dims(["x", "y", "z"], errors="ignore") + expected = data.drop(["A", "B", "x"]) assert_identical(expected, actual) def test_copy(self): data = create_test_data() - data.attrs['Test'] = [1, 2, 3] + data.attrs["Test"] = [1, 2, 3] for copied in [data.copy(deep=False), copy(data)]: assert_identical(data, copied) @@ -2063,12 +2216,12 @@ def test_copy(self): v0 = data.variables[k] v1 = copied.variables[k] assert source_ndarray(v0.data) is source_ndarray(v1.data) - copied['foo'] = ('z', np.arange(5)) - assert 'foo' not in data + copied["foo"] = ("z", np.arange(5)) + assert "foo" not in data - copied.attrs['foo'] = 'bar' - assert 'foo' not in data.attrs - assert data.attrs['Test'] is copied.attrs['Test'] + copied.attrs["foo"] = "bar" + assert "foo" not in data.attrs + assert data.attrs["Test"] is copied.attrs["Test"] for copied in [data.copy(deep=True), deepcopy(data)]: assert_identical(data, copied) @@ -2076,12 
+2229,11 @@ def test_copy(self): v1 = copied.variables[k] assert v0 is not v1 - assert data.attrs['Test'] is not copied.attrs['Test'] + assert data.attrs["Test"] is not copied.attrs["Test"] def test_copy_with_data(self): orig = create_test_data() - new_data = {k: np.random.randn(*v.shape) - for k, v in orig.data_vars.items()} + new_data = {k: np.random.randn(*v.shape) for k, v in orig.data_vars.items()} actual = orig.copy(data=new_data) expected = orig.copy() @@ -2090,45 +2242,62 @@ def test_copy_with_data(self): assert_identical(expected, actual) @pytest.mark.xfail(raises=AssertionError) - @pytest.mark.parametrize('deep, expected_orig', [ - [True, - xr.DataArray(xr.IndexVariable('a', np.array([1, 2])), - coords={'a': [1, 2]}, dims=['a'])], - [False, - xr.DataArray(xr.IndexVariable('a', np.array([999, 2])), - coords={'a': [999, 2]}, dims=['a'])]]) + @pytest.mark.parametrize( + "deep, expected_orig", + [ + [ + True, + xr.DataArray( + xr.IndexVariable("a", np.array([1, 2])), + coords={"a": [1, 2]}, + dims=["a"], + ), + ], + [ + False, + xr.DataArray( + xr.IndexVariable("a", np.array([999, 2])), + coords={"a": [999, 2]}, + dims=["a"], + ), + ], + ], + ) def test_copy_coords(self, deep, expected_orig): """The test fails for the shallow copy, and apparently only on Windows for some reason. In windows coords seem to be immutable unless it's one dataset deep copied from another.""" ds = xr.DataArray( np.ones([2, 2, 2]), - coords={'a': [1, 2], 'b': ['x', 'y'], 'c': [0, 1]}, - dims=['a', 'b', 'c'], - name='value').to_dataset() + coords={"a": [1, 2], "b": ["x", "y"], "c": [0, 1]}, + dims=["a", "b", "c"], + name="value", + ).to_dataset() ds_cp = ds.copy(deep=deep) - ds_cp.coords['a'].data[0] = 999 + ds_cp.coords["a"].data[0] = 999 expected_cp = xr.DataArray( - xr.IndexVariable('a', np.array([999, 2])), - coords={'a': [999, 2]}, dims=['a']) - assert_identical(ds_cp.coords['a'], expected_cp) + xr.IndexVariable("a", np.array([999, 2])), + coords={"a": [999, 2]}, + dims=["a"], + ) + assert_identical(ds_cp.coords["a"], expected_cp) - assert_identical(ds.coords['a'], expected_orig) + assert_identical(ds.coords["a"], expected_orig) def test_copy_with_data_errors(self): orig = create_test_data() - new_var1 = np.arange(orig['var1'].size).reshape(orig['var1'].shape) - with raises_regex(ValueError, 'Data must be dict-like'): + new_var1 = np.arange(orig["var1"].size).reshape(orig["var1"].shape) + with raises_regex(ValueError, "Data must be dict-like"): orig.copy(data=new_var1) - with raises_regex(ValueError, 'only contain variables in original'): - orig.copy(data={'not_in_original': new_var1}) - with raises_regex(ValueError, 'contain all variables in original'): - orig.copy(data={'var1': new_var1}) + with raises_regex(ValueError, "only contain variables in original"): + orig.copy(data={"not_in_original": new_var1}) + with raises_regex(ValueError, "contain all variables in original"): + orig.copy(data={"var1": new_var1}) def test_rename(self): data = create_test_data() - newnames = {'var1': 'renamed_var1', 'dim2': 'renamed_dim2'} + newnames = {"var1": "renamed_var1", "dim2": "renamed_dim2"} renamed = data.rename(newnames) variables = OrderedDict(data.variables) @@ -2141,26 +2310,28 @@ def test_rename(self): if name in dims: dims[dims.index(name)] = newname - assert_equal(Variable(dims, v.values, v.attrs), - renamed[k].variable.to_base_variable()) + assert_equal( + Variable(dims, v.values, v.attrs), + renamed[k].variable.to_base_variable(), + ) assert v.encoding == renamed[k].encoding assert type(v) == 
type(renamed.variables[k]) # noqa: E721 - assert 'var1' not in renamed - assert 'dim2' not in renamed + assert "var1" not in renamed + assert "dim2" not in renamed with raises_regex(ValueError, "cannot rename 'not_a_var'"): - data.rename({'not_a_var': 'nada'}) + data.rename({"not_a_var": "nada"}) with raises_regex(ValueError, "'var1' conflicts"): - data.rename({'var2': 'var1'}) + data.rename({"var2": "var1"}) # verify that we can rename a variable without accessing the data - var1 = data['var1'] - data['var1'] = (var1.dims, InaccessibleArray(var1.values)) + var1 = data["var1"] + data["var1"] = (var1.dims, InaccessibleArray(var1.values)) renamed = data.rename(newnames) with pytest.raises(UnexpectedDataAccess): - renamed['renamed_var1'].values + renamed["renamed_var1"].values renamed_kwargs = data.rename(**newnames) assert_identical(renamed, renamed_kwargs) @@ -2170,197 +2341,231 @@ def test_rename_old_name(self): data = create_test_data() with raises_regex(ValueError, "'samecol' conflicts"): - data.rename({'var1': 'samecol', 'var2': 'samecol'}) + data.rename({"var1": "samecol", "var2": "samecol"}) # This shouldn't cause any problems. - data.rename({'var1': 'var2', 'var2': 'var1'}) + data.rename({"var1": "var2", "var2": "var1"}) def test_rename_same_name(self): data = create_test_data() - newnames = {'var1': 'var1', 'dim2': 'dim2'} + newnames = {"var1": "var1", "dim2": "dim2"} renamed = data.rename(newnames) assert_identical(renamed, data) - @pytest.mark.filterwarnings('ignore:The inplace argument') + @pytest.mark.filterwarnings("ignore:The inplace argument") def test_rename_inplace(self): - times = pd.date_range('2000-01-01', periods=3) - data = Dataset({'z': ('x', [2, 3, 4]), 't': ('t', times)}) + times = pd.date_range("2000-01-01", periods=3) + data = Dataset({"z": ("x", [2, 3, 4]), "t": ("t", times)}) copied = data.copy() - renamed = data.rename({'x': 'y'}) - data.rename({'x': 'y'}, inplace=True) + renamed = data.rename({"x": "y"}) + data.rename({"x": "y"}, inplace=True) assert_identical(data, renamed) assert not data.equals(copied) - assert data.dims == {'y': 3, 't': 3} + assert data.dims == {"y": 3, "t": 3} # check virtual variables - assert_array_equal(data['t.dayofyear'], [1, 2, 3]) + assert_array_equal(data["t.dayofyear"], [1, 2, 3]) def test_rename_dims(self): - original = Dataset( - {'x': ('x', [0, 1, 2]), 'y': ('x', [10, 11, 12]), 'z': 42}) + original = Dataset({"x": ("x", [0, 1, 2]), "y": ("x", [10, 11, 12]), "z": 42}) expected = Dataset( - {'x': ('x_new', [0, 1, 2]), 'y': ('x_new', [10, 11, 12]), 'z': 42}) - expected = expected.set_coords('x') - dims_dict = {'x': 'x_new'} + {"x": ("x_new", [0, 1, 2]), "y": ("x_new", [10, 11, 12]), "z": 42} + ) + expected = expected.set_coords("x") + dims_dict = {"x": "x_new"} actual = original.rename_dims(dims_dict) assert_identical(expected, actual) actual_2 = original.rename_dims(**dims_dict) assert_identical(expected, actual_2) # Test to raise ValueError - dims_dict_bad = {'x_bad': 'x_new'} + dims_dict_bad = {"x_bad": "x_new"} with pytest.raises(ValueError): original.rename_dims(dims_dict_bad) def test_rename_vars(self): - original = Dataset( - {'x': ('x', [0, 1, 2]), 'y': ('x', [10, 11, 12]), 'z': 42}) + original = Dataset({"x": ("x", [0, 1, 2]), "y": ("x", [10, 11, 12]), "z": 42}) expected = Dataset( - {'x_new': ('x', [0, 1, 2]), 'y': ('x', [10, 11, 12]), 'z': 42}) - expected = expected.set_coords('x_new') - name_dict = {'x': 'x_new'} + {"x_new": ("x", [0, 1, 2]), "y": ("x", [10, 11, 12]), "z": 42} + ) + expected = 
expected.set_coords("x_new") + name_dict = {"x": "x_new"} actual = original.rename_vars(name_dict) assert_identical(expected, actual) actual_2 = original.rename_vars(**name_dict) assert_identical(expected, actual_2) # Test to raise ValueError - names_dict_bad = {'x_bad': 'x_new'} + names_dict_bad = {"x_bad": "x_new"} with pytest.raises(ValueError): original.rename_vars(names_dict_bad) def test_swap_dims(self): - original = Dataset({'x': [1, 2, 3], 'y': ('x', list('abc')), 'z': 42}) - expected = Dataset({'z': 42}, - {'x': ('y', [1, 2, 3]), 'y': list('abc')}) - actual = original.swap_dims({'x': 'y'}) + original = Dataset({"x": [1, 2, 3], "y": ("x", list("abc")), "z": 42}) + expected = Dataset({"z": 42}, {"x": ("y", [1, 2, 3]), "y": list("abc")}) + actual = original.swap_dims({"x": "y"}) assert_identical(expected, actual) - assert isinstance(actual.variables['y'], IndexVariable) - assert isinstance(actual.variables['x'], Variable) - assert actual.indexes['y'].equals(pd.Index(list('abc'))) + assert isinstance(actual.variables["y"], IndexVariable) + assert isinstance(actual.variables["x"], Variable) + assert actual.indexes["y"].equals(pd.Index(list("abc"))) - roundtripped = actual.swap_dims({'y': 'x'}) - assert_identical(original.set_coords('y'), roundtripped) + roundtripped = actual.swap_dims({"y": "x"}) + assert_identical(original.set_coords("y"), roundtripped) - with raises_regex(ValueError, 'cannot swap'): - original.swap_dims({'y': 'x'}) - with raises_regex(ValueError, 'replacement dimension'): - original.swap_dims({'x': 'z'}) + with raises_regex(ValueError, "cannot swap"): + original.swap_dims({"y": "x"}) + with raises_regex(ValueError, "replacement dimension"): + original.swap_dims({"x": "z"}) def test_expand_dims_error(self): - original = Dataset({'x': ('a', np.random.randn(3)), - 'y': (['b', 'a'], np.random.randn(4, 3)), - 'z': ('a', np.random.randn(3))}, - coords={'a': np.linspace(0, 1, 3), - 'b': np.linspace(0, 1, 4), - 'c': np.linspace(0, 1, 5)}, - attrs={'key': 'entry'}) + original = Dataset( + { + "x": ("a", np.random.randn(3)), + "y": (["b", "a"], np.random.randn(4, 3)), + "z": ("a", np.random.randn(3)), + }, + coords={ + "a": np.linspace(0, 1, 3), + "b": np.linspace(0, 1, 4), + "c": np.linspace(0, 1, 5), + }, + attrs={"key": "entry"}, + ) - with raises_regex(ValueError, 'already exists'): - original.expand_dims(dim=['x']) + with raises_regex(ValueError, "already exists"): + original.expand_dims(dim=["x"]) # Make sure it raises true error also for non-dimensional coordinates # which has dimension. 
- original = original.set_coords('z') - with raises_regex(ValueError, 'already exists'): - original.expand_dims(dim=['z']) - - original = Dataset({'x': ('a', np.random.randn(3)), - 'y': (['b', 'a'], np.random.randn(4, 3)), - 'z': ('a', np.random.randn(3))}, - coords={'a': np.linspace(0, 1, 3), - 'b': np.linspace(0, 1, 4), - 'c': np.linspace(0, 1, 5)}, - attrs={'key': 'entry'}) - with raises_regex(TypeError, 'value of new dimension'): + original = original.set_coords("z") + with raises_regex(ValueError, "already exists"): + original.expand_dims(dim=["z"]) + + original = Dataset( + { + "x": ("a", np.random.randn(3)), + "y": (["b", "a"], np.random.randn(4, 3)), + "z": ("a", np.random.randn(3)), + }, + coords={ + "a": np.linspace(0, 1, 3), + "b": np.linspace(0, 1, 4), + "c": np.linspace(0, 1, 5), + }, + attrs={"key": "entry"}, + ) + with raises_regex(TypeError, "value of new dimension"): original.expand_dims(OrderedDict((("d", 3.2),))) # TODO: only the code under the if-statement is needed when python 3.5 # is no longer supported. python36_plus = sys.version_info[0] == 3 and sys.version_info[1] > 5 if python36_plus: - with raises_regex(ValueError, 'both keyword and positional'): + with raises_regex(ValueError, "both keyword and positional"): original.expand_dims(OrderedDict((("d", 4),)), e=4) def test_expand_dims_int(self): - original = Dataset({'x': ('a', np.random.randn(3)), - 'y': (['b', 'a'], np.random.randn(4, 3))}, - coords={'a': np.linspace(0, 1, 3), - 'b': np.linspace(0, 1, 4), - 'c': np.linspace(0, 1, 5)}, - attrs={'key': 'entry'}) - - actual = original.expand_dims(['z'], [1]) - expected = Dataset({'x': original['x'].expand_dims('z', 1), - 'y': original['y'].expand_dims('z', 1)}, - coords={'a': np.linspace(0, 1, 3), - 'b': np.linspace(0, 1, 4), - 'c': np.linspace(0, 1, 5)}, - attrs={'key': 'entry'}) + original = Dataset( + {"x": ("a", np.random.randn(3)), "y": (["b", "a"], np.random.randn(4, 3))}, + coords={ + "a": np.linspace(0, 1, 3), + "b": np.linspace(0, 1, 4), + "c": np.linspace(0, 1, 5), + }, + attrs={"key": "entry"}, + ) + + actual = original.expand_dims(["z"], [1]) + expected = Dataset( + { + "x": original["x"].expand_dims("z", 1), + "y": original["y"].expand_dims("z", 1), + }, + coords={ + "a": np.linspace(0, 1, 3), + "b": np.linspace(0, 1, 4), + "c": np.linspace(0, 1, 5), + }, + attrs={"key": "entry"}, + ) assert_identical(expected, actual) # make sure squeeze restores the original data set. - roundtripped = actual.squeeze('z') + roundtripped = actual.squeeze("z") assert_identical(original, roundtripped) # another test with a negative axis - actual = original.expand_dims(['z'], [-1]) - expected = Dataset({'x': original['x'].expand_dims('z', -1), - 'y': original['y'].expand_dims('z', -1)}, - coords={'a': np.linspace(0, 1, 3), - 'b': np.linspace(0, 1, 4), - 'c': np.linspace(0, 1, 5)}, - attrs={'key': 'entry'}) + actual = original.expand_dims(["z"], [-1]) + expected = Dataset( + { + "x": original["x"].expand_dims("z", -1), + "y": original["y"].expand_dims("z", -1), + }, + coords={ + "a": np.linspace(0, 1, 3), + "b": np.linspace(0, 1, 4), + "c": np.linspace(0, 1, 5), + }, + attrs={"key": "entry"}, + ) assert_identical(expected, actual) # make sure squeeze restores the original data set. 
- roundtripped = actual.squeeze('z') + roundtripped = actual.squeeze("z") assert_identical(original, roundtripped) def test_expand_dims_coords(self): - original = Dataset({'x': ('a', np.array([1, 2, 3]))}) + original = Dataset({"x": ("a", np.array([1, 2, 3]))}) expected = Dataset( - {'x': (('b', 'a'), np.array([[1, 2, 3], [1, 2, 3]]))}, - coords={'b': [1, 2]}, + {"x": (("b", "a"), np.array([[1, 2, 3], [1, 2, 3]]))}, coords={"b": [1, 2]} ) actual = original.expand_dims(OrderedDict(b=[1, 2])) assert_identical(expected, actual) - assert 'b' not in original._coord_names + assert "b" not in original._coord_names def test_expand_dims_existing_scalar_coord(self): - original = Dataset({'x': 1}, {'a': 2}) - expected = Dataset({'x': (('a',), [1])}, {'a': [2]}) - actual = original.expand_dims('a') + original = Dataset({"x": 1}, {"a": 2}) + expected = Dataset({"x": (("a",), [1])}, {"a": [2]}) + actual = original.expand_dims("a") assert_identical(expected, actual) def test_isel_expand_dims_roundtrip(self): - original = Dataset({'x': (('a',), [1])}, {'a': [2]}) - actual = original.isel(a=0).expand_dims('a') + original = Dataset({"x": (("a",), [1])}, {"a": [2]}) + actual = original.isel(a=0).expand_dims("a") assert_identical(actual, original) def test_expand_dims_mixed_int_and_coords(self): # Test expanding one dimension to have size > 1 that doesn't have # coordinates, and also expanding another dimension to have size > 1 # that DOES have coordinates. - original = Dataset({'x': ('a', np.random.randn(3)), - 'y': (['b', 'a'], np.random.randn(4, 3))}, - coords={'a': np.linspace(0, 1, 3), - 'b': np.linspace(0, 1, 4), - 'c': np.linspace(0, 1, 5)}) + original = Dataset( + {"x": ("a", np.random.randn(3)), "y": (["b", "a"], np.random.randn(4, 3))}, + coords={ + "a": np.linspace(0, 1, 3), + "b": np.linspace(0, 1, 4), + "c": np.linspace(0, 1, 5), + }, + ) - actual = original.expand_dims( - OrderedDict((("d", 4), ("e", ["l", "m", "n"])))) + actual = original.expand_dims(OrderedDict((("d", 4), ("e", ["l", "m", "n"])))) expected = Dataset( - {'x': xr.DataArray(original['x'].values * np.ones([4, 3, 3]), - coords=dict(d=range(4), - e=['l', 'm', 'n'], - a=np.linspace(0, 1, 3)), - dims=['d', 'e', 'a']).drop('d'), - 'y': xr.DataArray(original['y'].values * np.ones([4, 3, 4, 3]), - coords=dict(d=range(4), - e=['l', 'm', 'n'], - b=np.linspace(0, 1, 4), - a=np.linspace(0, 1, 3)), - dims=['d', 'e', 'b', 'a']).drop('d')}, - coords={'c': np.linspace(0, 1, 5)}) + { + "x": xr.DataArray( + original["x"].values * np.ones([4, 3, 3]), + coords=dict(d=range(4), e=["l", "m", "n"], a=np.linspace(0, 1, 3)), + dims=["d", "e", "a"], + ).drop("d"), + "y": xr.DataArray( + original["y"].values * np.ones([4, 3, 4, 3]), + coords=dict( + d=range(4), + e=["l", "m", "n"], + b=np.linspace(0, 1, 4), + a=np.linspace(0, 1, 3), + ), + dims=["d", "e", "b", "a"], + ).drop("d"), + }, + coords={"c": np.linspace(0, 1, 5)}, + ) assert_identical(actual, expected) @pytest.mark.skipif( @@ -2368,188 +2573,202 @@ def test_expand_dims_mixed_int_and_coords(self): reason="we only raise these errors for Python 3.5", ) def test_expand_dims_kwargs_python35(self): - original = Dataset({'x': ('a', np.random.randn(3))}) + original = Dataset({"x": ("a", np.random.randn(3))}) with raises_regex(ValueError, "dim_kwargs isn't"): original.expand_dims(e=["l", "m", "n"]) with raises_regex(TypeError, "must be an OrderedDict"): - original.expand_dims({'e': ["l", "m", "n"]}) + original.expand_dims({"e": ["l", "m", "n"]}) @pytest.mark.skipif( sys.version_info[:2] < (3, 6), - 
reason='keyword arguments are only ordered on Python 3.6+', + reason="keyword arguments are only ordered on Python 3.6+", ) def test_expand_dims_kwargs_python36plus(self): - original = Dataset({'x': ('a', np.random.randn(3)), - 'y': (['b', 'a'], np.random.randn(4, 3))}, - coords={'a': np.linspace(0, 1, 3), - 'b': np.linspace(0, 1, 4), - 'c': np.linspace(0, 1, 5)}, - attrs={'key': 'entry'}) + original = Dataset( + {"x": ("a", np.random.randn(3)), "y": (["b", "a"], np.random.randn(4, 3))}, + coords={ + "a": np.linspace(0, 1, 3), + "b": np.linspace(0, 1, 4), + "c": np.linspace(0, 1, 5), + }, + attrs={"key": "entry"}, + ) other_way = original.expand_dims(e=["l", "m", "n"]) other_way_expected = Dataset( - {'x': xr.DataArray(original['x'].values * np.ones([3, 3]), - coords=dict(e=['l', 'm', 'n'], - a=np.linspace(0, 1, 3)), - dims=['e', 'a']), - 'y': xr.DataArray(original['y'].values * np.ones([3, 4, 3]), - coords=dict(e=['l', 'm', 'n'], - b=np.linspace(0, 1, 4), - a=np.linspace(0, 1, 3)), - dims=['e', 'b', 'a'])}, - coords={'c': np.linspace(0, 1, 5)}, - attrs={'key': 'entry'}) + { + "x": xr.DataArray( + original["x"].values * np.ones([3, 3]), + coords=dict(e=["l", "m", "n"], a=np.linspace(0, 1, 3)), + dims=["e", "a"], + ), + "y": xr.DataArray( + original["y"].values * np.ones([3, 4, 3]), + coords=dict( + e=["l", "m", "n"], + b=np.linspace(0, 1, 4), + a=np.linspace(0, 1, 3), + ), + dims=["e", "b", "a"], + ), + }, + coords={"c": np.linspace(0, 1, 5)}, + attrs={"key": "entry"}, + ) assert_identical(other_way_expected, other_way) def test_set_index(self): expected = create_test_multiindex() - mindex = expected['x'].to_index() + mindex = expected["x"].to_index() indexes = [mindex.get_level_values(n) for n in mindex.names] - coords = {idx.name: ('x', idx) for idx in indexes} + coords = {idx.name: ("x", idx) for idx in indexes} ds = Dataset({}, coords=coords) obj = ds.set_index(x=mindex.names) assert_identical(obj, expected) - with pytest.warns(FutureWarning, match='The inplace argument'): + with pytest.warns(FutureWarning, match="The inplace argument"): ds.set_index(x=mindex.names, inplace=True) assert_identical(ds, expected) # ensure set_index with no existing index and a single data var given # doesn't return multi-index - ds = Dataset(data_vars={'x_var': ('x', [0, 1, 2])}) - expected = Dataset(coords={'x': [0, 1, 2]}) - assert_identical(ds.set_index(x='x_var'), expected) + ds = Dataset(data_vars={"x_var": ("x", [0, 1, 2])}) + expected = Dataset(coords={"x": [0, 1, 2]}) + assert_identical(ds.set_index(x="x_var"), expected) def test_reset_index(self): ds = create_test_multiindex() - mindex = ds['x'].to_index() + mindex = ds["x"].to_index() indexes = [mindex.get_level_values(n) for n in mindex.names] - coords = {idx.name: ('x', idx) for idx in indexes} + coords = {idx.name: ("x", idx) for idx in indexes} expected = Dataset({}, coords=coords) - obj = ds.reset_index('x') + obj = ds.reset_index("x") assert_identical(obj, expected) - with pytest.warns(FutureWarning, match='The inplace argument'): - ds.reset_index('x', inplace=True) + with pytest.warns(FutureWarning, match="The inplace argument"): + ds.reset_index("x", inplace=True) assert_identical(ds, expected) def test_reorder_levels(self): ds = create_test_multiindex() - mindex = ds['x'].to_index() - midx = mindex.reorder_levels(['level_2', 'level_1']) - expected = Dataset({}, coords={'x': midx}) + mindex = ds["x"].to_index() + midx = mindex.reorder_levels(["level_2", "level_1"]) + expected = Dataset({}, coords={"x": midx}) - reindexed = 
ds.reorder_levels(x=['level_2', 'level_1']) + reindexed = ds.reorder_levels(x=["level_2", "level_1"]) assert_identical(reindexed, expected) - with pytest.warns(FutureWarning, match='The inplace argument'): - ds.reorder_levels(x=['level_2', 'level_1'], inplace=True) + with pytest.warns(FutureWarning, match="The inplace argument"): + ds.reorder_levels(x=["level_2", "level_1"], inplace=True) assert_identical(ds, expected) - ds = Dataset({}, coords={'x': [1, 2]}) - with raises_regex(ValueError, 'has no MultiIndex'): - ds.reorder_levels(x=['level_1', 'level_2']) + ds = Dataset({}, coords={"x": [1, 2]}) + with raises_regex(ValueError, "has no MultiIndex"): + ds.reorder_levels(x=["level_1", "level_2"]) def test_stack(self): - ds = Dataset({'a': ('x', [0, 1]), - 'b': (('x', 'y'), [[0, 1], [2, 3]]), - 'y': ['a', 'b']}) + ds = Dataset( + {"a": ("x", [0, 1]), "b": (("x", "y"), [[0, 1], [2, 3]]), "y": ["a", "b"]} + ) - exp_index = pd.MultiIndex.from_product([[0, 1], ['a', 'b']], - names=['x', 'y']) - expected = Dataset({'a': ('z', [0, 0, 1, 1]), - 'b': ('z', [0, 1, 2, 3]), - 'z': exp_index}) - actual = ds.stack(z=['x', 'y']) + exp_index = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["x", "y"]) + expected = Dataset( + {"a": ("z", [0, 0, 1, 1]), "b": ("z", [0, 1, 2, 3]), "z": exp_index} + ) + actual = ds.stack(z=["x", "y"]) assert_identical(expected, actual) - exp_index = pd.MultiIndex.from_product([['a', 'b'], [0, 1]], - names=['y', 'x']) - expected = Dataset({'a': ('z', [0, 1, 0, 1]), - 'b': ('z', [0, 2, 1, 3]), - 'z': exp_index}) - actual = ds.stack(z=['y', 'x']) + exp_index = pd.MultiIndex.from_product([["a", "b"], [0, 1]], names=["y", "x"]) + expected = Dataset( + {"a": ("z", [0, 1, 0, 1]), "b": ("z", [0, 2, 1, 3]), "z": exp_index} + ) + actual = ds.stack(z=["y", "x"]) assert_identical(expected, actual) def test_unstack(self): - index = pd.MultiIndex.from_product([[0, 1], ['a', 'b']], - names=['x', 'y']) - ds = Dataset({'b': ('z', [0, 1, 2, 3]), 'z': index}) - expected = Dataset({'b': (('x', 'y'), [[0, 1], [2, 3]]), - 'x': [0, 1], - 'y': ['a', 'b']}) - for dim in ['z', ['z'], None]: + index = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["x", "y"]) + ds = Dataset({"b": ("z", [0, 1, 2, 3]), "z": index}) + expected = Dataset( + {"b": (("x", "y"), [[0, 1], [2, 3]]), "x": [0, 1], "y": ["a", "b"]} + ) + for dim in ["z", ["z"], None]: actual = ds.unstack(dim) assert_identical(actual, expected) def test_unstack_errors(self): - ds = Dataset({'x': [1, 2, 3]}) - with raises_regex(ValueError, 'does not contain the dimensions'): - ds.unstack('foo') - with raises_regex(ValueError, 'do not have a MultiIndex'): - ds.unstack('x') + ds = Dataset({"x": [1, 2, 3]}) + with raises_regex(ValueError, "does not contain the dimensions"): + ds.unstack("foo") + with raises_regex(ValueError, "do not have a MultiIndex"): + ds.unstack("x") def test_stack_unstack_fast(self): - ds = Dataset({'a': ('x', [0, 1]), - 'b': (('x', 'y'), [[0, 1], [2, 3]]), - 'x': [0, 1], - 'y': ['a', 'b']}) - actual = ds.stack(z=['x', 'y']).unstack('z') + ds = Dataset( + { + "a": ("x", [0, 1]), + "b": (("x", "y"), [[0, 1], [2, 3]]), + "x": [0, 1], + "y": ["a", "b"], + } + ) + actual = ds.stack(z=["x", "y"]).unstack("z") assert actual.broadcast_equals(ds) - actual = ds[['b']].stack(z=['x', 'y']).unstack('z') - assert actual.identical(ds[['b']]) + actual = ds[["b"]].stack(z=["x", "y"]).unstack("z") + assert actual.identical(ds[["b"]]) def test_stack_unstack_slow(self): - ds = Dataset({'a': ('x', [0, 1]), - 'b': (('x', 'y'), [[0, 1], 
[2, 3]]), - 'x': [0, 1], - 'y': ['a', 'b']}) - stacked = ds.stack(z=['x', 'y']) - actual = stacked.isel(z=slice(None, None, -1)).unstack('z') + ds = Dataset( + { + "a": ("x", [0, 1]), + "b": (("x", "y"), [[0, 1], [2, 3]]), + "x": [0, 1], + "y": ["a", "b"], + } + ) + stacked = ds.stack(z=["x", "y"]) + actual = stacked.isel(z=slice(None, None, -1)).unstack("z") assert actual.broadcast_equals(ds) - stacked = ds[['b']].stack(z=['x', 'y']) - actual = stacked.isel(z=slice(None, None, -1)).unstack('z') - assert actual.identical(ds[['b']]) + stacked = ds[["b"]].stack(z=["x", "y"]) + actual = stacked.isel(z=slice(None, None, -1)).unstack("z") + assert actual.identical(ds[["b"]]) def test_to_stacked_array_invalid_sample_dims(self): data = xr.Dataset( - data_vars={'a': (('x', 'y'), [[0, 1, 2], [3, 4, 5]]), - 'b': ('x', [6, 7])}, - coords={'y': ['u', 'v', 'w']} + data_vars={"a": (("x", "y"), [[0, 1, 2], [3, 4, 5]]), "b": ("x", [6, 7])}, + coords={"y": ["u", "v", "w"]}, ) with pytest.raises(ValueError): - data.to_stacked_array("features", sample_dims=['y']) + data.to_stacked_array("features", sample_dims=["y"]) def test_to_stacked_array_name(self): - name = 'adf9d' + name = "adf9d" # make a two dimensional dataset a, b = create_test_stacked_array() - D = xr.Dataset({'a': a, 'b': b}) - sample_dims = ['x'] + D = xr.Dataset({"a": a, "b": b}) + sample_dims = ["x"] - y = D.to_stacked_array('features', sample_dims, name=name) + y = D.to_stacked_array("features", sample_dims, name=name) assert y.name == name def test_to_stacked_array_dtype_dims(self): # make a two dimensional dataset a, b = create_test_stacked_array() - D = xr.Dataset({'a': a, 'b': b}) - sample_dims = ['x'] - y = D.to_stacked_array('features', sample_dims) - assert y.indexes['features'].levels[1].dtype == D.y.dtype - assert y.dims == ('x', 'features') + D = xr.Dataset({"a": a, "b": b}) + sample_dims = ["x"] + y = D.to_stacked_array("features", sample_dims) + assert y.indexes["features"].levels[1].dtype == D.y.dtype + assert y.dims == ("x", "features") def test_to_stacked_array_to_unstacked_dataset(self): # make a two dimensional dataset a, b = create_test_stacked_array() - D = xr.Dataset({'a': a, 'b': b}) - sample_dims = ['x'] - y = D.to_stacked_array('features', sample_dims)\ - .transpose("x", "features") + D = xr.Dataset({"a": a, "b": b}) + sample_dims = ["x"] + y = D.to_stacked_array("features", sample_dims).transpose("x", "features") x = y.to_unstacked_dataset("features") assert_identical(D, x) @@ -2562,19 +2781,19 @@ def test_to_stacked_array_to_unstacked_dataset(self): def test_to_stacked_array_to_unstacked_dataset_different_dimension(self): # test when variables have different dimensionality a, b = create_test_stacked_array() - sample_dims = ['x'] - D = xr.Dataset({'a': a, 'b': b.isel(y=0)}) + sample_dims = ["x"] + D = xr.Dataset({"a": a, "b": b.isel(y=0)}) - y = D.to_stacked_array('features', sample_dims) - x = y.to_unstacked_dataset('features') + y = D.to_stacked_array("features", sample_dims) + x = y.to_unstacked_dataset("features") assert_identical(D, x) def test_update(self): data = create_test_data(seed=0) expected = data.copy() - var2 = Variable('dim1', np.arange(8)) - actual = data.update({'var2': var2}) - expected['var2'] = var2 + var2 = Variable("dim1", np.arange(8)) + actual = data.update({"var2": var2}) + expected["var2"] = var2 assert_identical(expected, actual) actual = data.copy() @@ -2582,70 +2801,72 @@ def test_update(self): assert actual_result is actual assert_identical(expected, actual) - with 
pytest.warns(FutureWarning, match='The inplace argument'): + with pytest.warns(FutureWarning, match="The inplace argument"): actual = data.update(data, inplace=False) expected = data assert actual is not expected assert_identical(expected, actual) - other = Dataset(attrs={'new': 'attr'}) + other = Dataset(attrs={"new": "attr"}) actual = data.copy() actual.update(other) assert_identical(expected, actual) def test_update_overwrite_coords(self): - data = Dataset({'a': ('x', [1, 2])}, {'b': 3}) - data.update(Dataset(coords={'b': 4})) - expected = Dataset({'a': ('x', [1, 2])}, {'b': 4}) + data = Dataset({"a": ("x", [1, 2])}, {"b": 3}) + data.update(Dataset(coords={"b": 4})) + expected = Dataset({"a": ("x", [1, 2])}, {"b": 4}) assert_identical(data, expected) - data = Dataset({'a': ('x', [1, 2])}, {'b': 3}) - data.update(Dataset({'c': 5}, coords={'b': 4})) - expected = Dataset({'a': ('x', [1, 2]), 'c': 5}, {'b': 4}) + data = Dataset({"a": ("x", [1, 2])}, {"b": 3}) + data.update(Dataset({"c": 5}, coords={"b": 4})) + expected = Dataset({"a": ("x", [1, 2]), "c": 5}, {"b": 4}) assert_identical(data, expected) - data = Dataset({'a': ('x', [1, 2])}, {'b': 3}) - data.update({'c': DataArray(5, coords={'b': 4})}) - expected = Dataset({'a': ('x', [1, 2]), 'c': 5}, {'b': 3}) + data = Dataset({"a": ("x", [1, 2])}, {"b": 3}) + data.update({"c": DataArray(5, coords={"b": 4})}) + expected = Dataset({"a": ("x", [1, 2]), "c": 5}, {"b": 3}) assert_identical(data, expected) def test_update_auto_align(self): - ds = Dataset({'x': ('t', [3, 4])}, {'t': [0, 1]}) + ds = Dataset({"x": ("t", [3, 4])}, {"t": [0, 1]}) - expected = Dataset({'x': ('t', [3, 4]), 'y': ('t', [np.nan, 5])}, - {'t': [0, 1]}) + expected = Dataset({"x": ("t", [3, 4]), "y": ("t", [np.nan, 5])}, {"t": [0, 1]}) actual = ds.copy() - other = {'y': ('t', [5]), 't': [1]} - with raises_regex(ValueError, 'conflicting sizes'): + other = {"y": ("t", [5]), "t": [1]} + with raises_regex(ValueError, "conflicting sizes"): actual.update(other) actual.update(Dataset(other)) assert_identical(expected, actual) actual = ds.copy() - other = Dataset({'y': ('t', [5]), 't': [100]}) + other = Dataset({"y": ("t", [5]), "t": [100]}) actual.update(other) - expected = Dataset({'x': ('t', [3, 4]), 'y': ('t', [np.nan] * 2)}, - {'t': [0, 1]}) + expected = Dataset( + {"x": ("t", [3, 4]), "y": ("t", [np.nan] * 2)}, {"t": [0, 1]} + ) assert_identical(expected, actual) def test_getitem(self): data = create_test_data() - assert isinstance(data['var1'], DataArray) - assert_equal(data['var1'].variable, data.variables['var1']) + assert isinstance(data["var1"], DataArray) + assert_equal(data["var1"].variable, data.variables["var1"]) with pytest.raises(KeyError): - data['notfound'] + data["notfound"] with pytest.raises(KeyError): - data[['var1', 'notfound']] + data[["var1", "notfound"]] - actual = data[['var1', 'var2']] - expected = Dataset({'var1': data['var1'], 'var2': data['var2']}) + actual = data[["var1", "var2"]] + expected = Dataset({"var1": data["var1"], "var2": data["var2"]}) assert_equal(expected, actual) - actual = data['numbers'] - expected = DataArray(data['numbers'].variable, - {'dim3': data['dim3'], - 'numbers': data['numbers']}, - dims='dim3', name='numbers') + actual = data["numbers"] + expected = DataArray( + data["numbers"].variable, + {"dim3": data["dim3"], "numbers": data["numbers"]}, + dims="dim3", + name="numbers", + ) assert_identical(expected, actual) actual = data[dict(dim1=0)] @@ -2654,235 +2875,246 @@ def test_getitem(self): def test_getitem_hashable(self): data 
= create_test_data() - data[(3, 4)] = data['var1'] + 1 - expected = data['var1'] + 1 + data[(3, 4)] = data["var1"] + 1 + expected = data["var1"] + 1 expected.name = (3, 4) assert_identical(expected, data[(3, 4)]) with raises_regex(KeyError, "('var1', 'var2')"): - data[('var1', 'var2')] + data[("var1", "var2")] def test_virtual_variables_default_coords(self): - dataset = Dataset({'foo': ('x', range(10))}) - expected = DataArray(range(10), dims='x', name='x') - actual = dataset['x'] + dataset = Dataset({"foo": ("x", range(10))}) + expected = DataArray(range(10), dims="x", name="x") + actual = dataset["x"] assert_identical(expected, actual) assert isinstance(actual.variable, IndexVariable) - actual = dataset[['x', 'foo']] + actual = dataset[["x", "foo"]] expected = dataset.assign_coords(x=range(10)) assert_identical(expected, actual) def test_virtual_variables_time(self): # access virtual variables data = create_test_data() - expected = DataArray(1 + np.arange(20), coords=[data['time']], - dims='time', name='dayofyear') + expected = DataArray( + 1 + np.arange(20), coords=[data["time"]], dims="time", name="dayofyear" + ) - assert_array_equal(data['time.month'].values, - data.variables['time'].to_index().month) - assert_array_equal(data['time.season'].values, 'DJF') + assert_array_equal( + data["time.month"].values, data.variables["time"].to_index().month + ) + assert_array_equal(data["time.season"].values, "DJF") # test virtual variable math - assert_array_equal(data['time.dayofyear'] + 1, 2 + np.arange(20)) - assert_array_equal(np.sin(data['time.dayofyear']), - np.sin(1 + np.arange(20))) + assert_array_equal(data["time.dayofyear"] + 1, 2 + np.arange(20)) + assert_array_equal(np.sin(data["time.dayofyear"]), np.sin(1 + np.arange(20))) # ensure they become coordinates - expected = Dataset({}, {'dayofyear': data['time.dayofyear']}) - actual = data[['time.dayofyear']] + expected = Dataset({}, {"dayofyear": data["time.dayofyear"]}) + actual = data[["time.dayofyear"]] assert_equal(expected, actual) # non-coordinate variables - ds = Dataset({'t': ('x', pd.date_range('2000-01-01', periods=3))}) - assert (ds['t.year'] == 2000).all() + ds = Dataset({"t": ("x", pd.date_range("2000-01-01", periods=3))}) + assert (ds["t.year"] == 2000).all() def test_virtual_variable_same_name(self): # regression test for GH367 - times = pd.date_range('2000-01-01', freq='H', periods=5) - data = Dataset({'time': times}) - actual = data['time.time'] - expected = DataArray(times.time, [('time', times)], name='time') + times = pd.date_range("2000-01-01", freq="H", periods=5) + data = Dataset({"time": times}) + actual = data["time.time"] + expected = DataArray(times.time, [("time", times)], name="time") assert_identical(actual, expected) def test_virtual_variable_multiindex(self): # access multi-index levels as virtual variables data = create_test_multiindex() - expected = DataArray(['a', 'a', 'b', 'b'], name='level_1', - coords=[data['x'].to_index()], dims='x') - assert_identical(expected, data['level_1']) + expected = DataArray( + ["a", "a", "b", "b"], + name="level_1", + coords=[data["x"].to_index()], + dims="x", + ) + assert_identical(expected, data["level_1"]) # combine multi-index level and datetime - dr_index = pd.date_range('1/1/2011', periods=4, freq='H') - mindex = pd.MultiIndex.from_arrays([['a', 'a', 'b', 'b'], dr_index], - names=('level_str', 'level_date')) - data = Dataset({}, {'x': mindex}) - expected = DataArray(mindex.get_level_values('level_date').hour, - name='hour', coords=[mindex], dims='x') - 
assert_identical(expected, data['level_date.hour']) + dr_index = pd.date_range("1/1/2011", periods=4, freq="H") + mindex = pd.MultiIndex.from_arrays( + [["a", "a", "b", "b"], dr_index], names=("level_str", "level_date") + ) + data = Dataset({}, {"x": mindex}) + expected = DataArray( + mindex.get_level_values("level_date").hour, + name="hour", + coords=[mindex], + dims="x", + ) + assert_identical(expected, data["level_date.hour"]) # attribute style access - assert_identical(data.level_str, data['level_str']) + assert_identical(data.level_str, data["level_str"]) def test_time_season(self): - ds = Dataset({'t': pd.date_range('2000-01-01', periods=12, freq='M')}) - seas = ['DJF'] * 2 + ['MAM'] * 3 + ['JJA'] * 3 + ['SON'] * 3 + ['DJF'] - assert_array_equal(seas, ds['t.season']) + ds = Dataset({"t": pd.date_range("2000-01-01", periods=12, freq="M")}) + seas = ["DJF"] * 2 + ["MAM"] * 3 + ["JJA"] * 3 + ["SON"] * 3 + ["DJF"] + assert_array_equal(seas, ds["t.season"]) def test_slice_virtual_variable(self): data = create_test_data() - assert_equal(data['time.dayofyear'][:10].variable, - Variable(['time'], 1 + np.arange(10))) assert_equal( - data['time.dayofyear'][0].variable, Variable([], 1)) + data["time.dayofyear"][:10].variable, Variable(["time"], 1 + np.arange(10)) + ) + assert_equal(data["time.dayofyear"][0].variable, Variable([], 1)) def test_setitem(self): # assign a variable - var = Variable(['dim1'], np.random.randn(8)) + var = Variable(["dim1"], np.random.randn(8)) data1 = create_test_data() - data1['A'] = var + data1["A"] = var data2 = data1.copy() - data2['A'] = var + data2["A"] = var assert_identical(data1, data2) # assign a dataset array - dv = 2 * data2['A'] - data1['B'] = dv.variable - data2['B'] = dv + dv = 2 * data2["A"] + data1["B"] = dv.variable + data2["B"] = dv assert_identical(data1, data2) # can't assign an ND array without dimensions - with raises_regex(ValueError, - 'without explicit dimension names'): - data2['C'] = var.values.reshape(2, 4) + with raises_regex(ValueError, "without explicit dimension names"): + data2["C"] = var.values.reshape(2, 4) # but can assign a 1D array - data1['C'] = var.values - data2['C'] = ('C', var.values) + data1["C"] = var.values + data2["C"] = ("C", var.values) assert_identical(data1, data2) # can assign a scalar - data1['scalar'] = 0 - data2['scalar'] = ([], 0) + data1["scalar"] = 0 + data2["scalar"] = ([], 0) assert_identical(data1, data2) # can't use the same dimension name as a scalar var - with raises_regex(ValueError, 'already exists as a scalar'): - data1['newvar'] = ('scalar', [3, 4, 5]) + with raises_regex(ValueError, "already exists as a scalar"): + data1["newvar"] = ("scalar", [3, 4, 5]) # can't resize a used dimension - with raises_regex(ValueError, 'arguments without labels'): - data1['dim1'] = data1['dim1'][:5] + with raises_regex(ValueError, "arguments without labels"): + data1["dim1"] = data1["dim1"][:5] # override an existing value - data1['A'] = 3 * data2['A'] - assert_equal(data1['A'], 3 * data2['A']) + data1["A"] = 3 * data2["A"] + assert_equal(data1["A"], 3 * data2["A"]) with pytest.raises(NotImplementedError): - data1[{'x': 0}] = 0 + data1[{"x": 0}] = 0 def test_setitem_pandas(self): ds = self.make_example_math_dataset() - ds['x'] = np.arange(3) + ds["x"] = np.arange(3) ds_copy = ds.copy() - ds_copy['bar'] = ds['bar'].to_pandas() + ds_copy["bar"] = ds["bar"].to_pandas() assert_equal(ds, ds_copy) def test_setitem_auto_align(self): ds = Dataset() - ds['x'] = ('y', range(3)) - ds['y'] = 1 + np.arange(3) - expected = 
Dataset({'x': ('y', range(3)), 'y': 1 + np.arange(3)}) + ds["x"] = ("y", range(3)) + ds["y"] = 1 + np.arange(3) + expected = Dataset({"x": ("y", range(3)), "y": 1 + np.arange(3)}) assert_identical(ds, expected) - ds['y'] = DataArray(range(3), dims='y') - expected = Dataset({'x': ('y', range(3))}, {'y': range(3)}) + ds["y"] = DataArray(range(3), dims="y") + expected = Dataset({"x": ("y", range(3))}, {"y": range(3)}) assert_identical(ds, expected) - ds['x'] = DataArray([1, 2], coords=[('y', [0, 1])]) - expected = Dataset({'x': ('y', [1, 2, np.nan])}, {'y': range(3)}) + ds["x"] = DataArray([1, 2], coords=[("y", [0, 1])]) + expected = Dataset({"x": ("y", [1, 2, np.nan])}, {"y": range(3)}) assert_identical(ds, expected) - ds['x'] = 42 - expected = Dataset({'x': 42, 'y': range(3)}) + ds["x"] = 42 + expected = Dataset({"x": 42, "y": range(3)}) assert_identical(ds, expected) - ds['x'] = DataArray([4, 5, 6, 7], coords=[('y', [0, 1, 2, 3])]) - expected = Dataset({'x': ('y', [4, 5, 6])}, {'y': range(3)}) + ds["x"] = DataArray([4, 5, 6, 7], coords=[("y", [0, 1, 2, 3])]) + expected = Dataset({"x": ("y", [4, 5, 6])}, {"y": range(3)}) assert_identical(ds, expected) def test_setitem_with_coords(self): # Regression test for GH:2068 ds = create_test_data() - other = DataArray(np.arange(10), dims='dim3', - coords={'numbers': ('dim3', np.arange(10))}) + other = DataArray( + np.arange(10), dims="dim3", coords={"numbers": ("dim3", np.arange(10))} + ) expected = ds.copy() - expected['var3'] = other.drop('numbers') + expected["var3"] = other.drop("numbers") actual = ds.copy() - actual['var3'] = other + actual["var3"] = other assert_identical(expected, actual) - assert 'numbers' in other.coords # should not change other + assert "numbers" in other.coords # should not change other # with alignment - other = ds['var3'].isel(dim3=slice(1, -1)) - other['numbers'] = ('dim3', np.arange(8)) + other = ds["var3"].isel(dim3=slice(1, -1)) + other["numbers"] = ("dim3", np.arange(8)) actual = ds.copy() - actual['var3'] = other - assert 'numbers' in other.coords # should not change other + actual["var3"] = other + assert "numbers" in other.coords # should not change other expected = ds.copy() - expected['var3'] = ds['var3'].isel(dim3=slice(1, -1)) + expected["var3"] = ds["var3"].isel(dim3=slice(1, -1)) assert_identical(expected, actual) # with non-duplicate coords - other = ds['var3'].isel(dim3=slice(1, -1)) - other['numbers'] = ('dim3', np.arange(8)) - other['position'] = ('dim3', np.arange(8)) + other = ds["var3"].isel(dim3=slice(1, -1)) + other["numbers"] = ("dim3", np.arange(8)) + other["position"] = ("dim3", np.arange(8)) actual = ds.copy() - actual['var3'] = other - assert 'position' in actual - assert 'position' in other.coords + actual["var3"] = other + assert "position" in actual + assert "position" in other.coords # assigning a coordinate-only dataarray actual = ds.copy() - other = actual['numbers'] + other = actual["numbers"] other[0] = 10 - actual['numbers'] = other - assert actual['numbers'][0] == 10 + actual["numbers"] = other + assert actual["numbers"][0] == 10 # GH: 2099 - ds = Dataset({'var': ('x', [1, 2, 3])}, - coords={'x': [0, 1, 2], 'z1': ('x', [1, 2, 3]), - 'z2': ('x', [1, 2, 3])}) - ds['var'] = ds['var'] * 2 - assert np.allclose(ds['var'], [2, 4, 6]) + ds = Dataset( + {"var": ("x", [1, 2, 3])}, + coords={"x": [0, 1, 2], "z1": ("x", [1, 2, 3]), "z2": ("x", [1, 2, 3])}, + ) + ds["var"] = ds["var"] * 2 + assert np.allclose(ds["var"], [2, 4, 6]) def test_setitem_align_new_indexes(self): - ds = Dataset({'foo': 
('x', [1, 2, 3])}, {'x': [0, 1, 2]}) - ds['bar'] = DataArray([2, 3, 4], [('x', [1, 2, 3])]) - expected = Dataset({'foo': ('x', [1, 2, 3]), - 'bar': ('x', [np.nan, 2, 3])}, - {'x': [0, 1, 2]}) + ds = Dataset({"foo": ("x", [1, 2, 3])}, {"x": [0, 1, 2]}) + ds["bar"] = DataArray([2, 3, 4], [("x", [1, 2, 3])]) + expected = Dataset( + {"foo": ("x", [1, 2, 3]), "bar": ("x", [np.nan, 2, 3])}, {"x": [0, 1, 2]} + ) assert_identical(ds, expected) def test_assign(self): ds = Dataset() actual = ds.assign(x=[0, 1, 2], y=2) - expected = Dataset({'x': [0, 1, 2], 'y': 2}) + expected = Dataset({"x": [0, 1, 2], "y": 2}) assert_identical(actual, expected) - assert list(actual.variables) == ['x', 'y'] + assert list(actual.variables) == ["x", "y"] assert_identical(ds, Dataset()) actual = actual.assign(y=lambda ds: ds.x ** 2) - expected = Dataset({'y': ('x', [0, 1, 4]), 'x': [0, 1, 2]}) + expected = Dataset({"y": ("x", [0, 1, 4]), "x": [0, 1, 2]}) assert_identical(actual, expected) actual = actual.assign_coords(z=2) - expected = Dataset({'y': ('x', [0, 1, 4])}, {'z': 2, 'x': [0, 1, 2]}) + expected = Dataset({"y": ("x", [0, 1, 4])}, {"z": 2, "x": [0, 1, 2]}) assert_identical(actual, expected) - ds = Dataset({'a': ('x', range(3))}, {'b': ('x', ['A'] * 2 + ['B'])}) - actual = ds.groupby('b').assign(c=lambda ds: 2 * ds.a) - expected = ds.merge({'c': ('x', [0, 2, 4])}) + ds = Dataset({"a": ("x", range(3))}, {"b": ("x", ["A"] * 2 + ["B"])}) + actual = ds.groupby("b").assign(c=lambda ds: 2 * ds.a) + expected = ds.merge({"c": ("x", [0, 2, 4])}) assert_identical(actual, expected) - actual = ds.groupby('b').assign(c=lambda ds: ds.a.sum()) - expected = ds.merge({'c': ('x', [1, 1, 2])}) + actual = ds.groupby("b").assign(c=lambda ds: ds.a.sum()) + expected = ds.merge({"c": ("x", [1, 1, 2])}) assert_identical(actual, expected) - actual = ds.groupby('b').assign_coords(c=lambda ds: ds.a.sum()) - expected = expected.set_coords('c') + actual = ds.groupby("b").assign_coords(c=lambda ds: ds.a.sum()) + expected = expected.set_coords("c") assert_identical(actual, expected) def test_assign_attrs(self): @@ -2892,130 +3124,131 @@ def test_assign_attrs(self): assert_identical(actual, expected) assert new.attrs == {} - expected.attrs['c'] = 3 - new_actual = actual.assign_attrs({'c': 3}) + expected.attrs["c"] = 3 + new_actual = actual.assign_attrs({"c": 3}) assert_identical(new_actual, expected) assert actual.attrs == dict(a=1, b=2) def test_assign_multiindex_level(self): data = create_test_multiindex() - with raises_regex(ValueError, 'conflicting MultiIndex'): + with raises_regex(ValueError, "conflicting MultiIndex"): data.assign(level_1=range(4)) data.assign_coords(level_1=range(4)) # raise an Error when any level name is used as dimension GH:2299 with pytest.raises(ValueError): - data['y'] = ('level_1', [0, 1]) + data["y"] = ("level_1", [0, 1]) def test_merge_multiindex_level(self): data = create_test_multiindex() - other = Dataset({'z': ('level_1', [0, 1])}) # conflict dimension + other = Dataset({"z": ("level_1", [0, 1])}) # conflict dimension with pytest.raises(ValueError): data.merge(other) - other = Dataset({'level_1': ('x', [0, 1])}) # conflict variable name + other = Dataset({"level_1": ("x", [0, 1])}) # conflict variable name with pytest.raises(ValueError): data.merge(other) def test_setitem_original_non_unique_index(self): # regression test for GH943 - original = Dataset({'data': ('x', np.arange(5))}, - coords={'x': [0, 1, 2, 0, 1]}) - expected = Dataset({'data': ('x', np.arange(5))}, {'x': range(5)}) + original = 
Dataset({"data": ("x", np.arange(5))}, coords={"x": [0, 1, 2, 0, 1]}) + expected = Dataset({"data": ("x", np.arange(5))}, {"x": range(5)}) actual = original.copy() - actual['x'] = list(range(5)) + actual["x"] = list(range(5)) assert_identical(actual, expected) actual = original.copy() - actual['x'] = ('x', list(range(5))) + actual["x"] = ("x", list(range(5))) assert_identical(actual, expected) actual = original.copy() - actual.coords['x'] = list(range(5)) + actual.coords["x"] = list(range(5)) assert_identical(actual, expected) def test_setitem_both_non_unique_index(self): # regression test for GH956 - names = ['joaquin', 'manolo', 'joaquin'] + names = ["joaquin", "manolo", "joaquin"] values = np.random.randint(0, 256, (3, 4, 4)) - array = DataArray(values, dims=['name', 'row', 'column'], - coords=[names, range(4), range(4)]) - expected = Dataset({'first': array, 'second': array}) - actual = array.rename('first').to_dataset() - actual['second'] = array + array = DataArray( + values, dims=["name", "row", "column"], coords=[names, range(4), range(4)] + ) + expected = Dataset({"first": array, "second": array}) + actual = array.rename("first").to_dataset() + actual["second"] = array assert_identical(expected, actual) def test_setitem_multiindex_level(self): data = create_test_multiindex() - with raises_regex(ValueError, 'conflicting MultiIndex'): - data['level_1'] = range(4) + with raises_regex(ValueError, "conflicting MultiIndex"): + data["level_1"] = range(4) def test_delitem(self): data = create_test_data() all_items = set(data.variables) assert set(data.variables) == all_items - del data['var1'] - assert set(data.variables) == all_items - {'var1'} - del data['numbers'] - assert set(data.variables) == all_items - {'var1', 'numbers'} - assert 'numbers' not in data.coords + del data["var1"] + assert set(data.variables) == all_items - {"var1"} + del data["numbers"] + assert set(data.variables) == all_items - {"var1", "numbers"} + assert "numbers" not in data.coords expected = Dataset() - actual = Dataset({'y': ('x', [1, 2])}) - del actual['y'] + actual = Dataset({"y": ("x", [1, 2])}) + del actual["y"] assert_identical(expected, actual) def test_squeeze(self): - data = Dataset({'foo': (['x', 'y', 'z'], [[[1], [2]]])}) - for args in [[], [['x']], [['x', 'z']]]: + data = Dataset({"foo": (["x", "y", "z"], [[[1], [2]]])}) + for args in [[], [["x"]], [["x", "z"]]]: + def get_args(v): return [set(args[0]) & set(v.dims)] if args else [] + expected = Dataset( - { - k: v.squeeze(*get_args(v)) - for k, v in data.variables.items() - } + {k: v.squeeze(*get_args(v)) for k, v in data.variables.items()} ) expected = expected.set_coords(data.coords) assert_identical(expected, data.squeeze(*args)) # invalid squeeze - with raises_regex(ValueError, 'cannot select a dimension'): - data.squeeze('y') + with raises_regex(ValueError, "cannot select a dimension"): + data.squeeze("y") def test_squeeze_drop(self): - data = Dataset({'foo': ('x', [1])}, {'x': [0]}) - expected = Dataset({'foo': 1}) + data = Dataset({"foo": ("x", [1])}, {"x": [0]}) + expected = Dataset({"foo": 1}) selected = data.squeeze(drop=True) assert_identical(expected, selected) - expected = Dataset({'foo': 1}, {'x': 0}) + expected = Dataset({"foo": 1}, {"x": 0}) selected = data.squeeze(drop=False) assert_identical(expected, selected) - data = Dataset({'foo': (('x', 'y'), [[1]])}, {'x': [0], 'y': [0]}) - expected = Dataset({'foo': 1}) + data = Dataset({"foo": (("x", "y"), [[1]])}, {"x": [0], "y": [0]}) + expected = Dataset({"foo": 1}) selected = 
data.squeeze(drop=True) assert_identical(expected, selected) - expected = Dataset({'foo': ('x', [1])}, {'x': [0]}) - selected = data.squeeze(dim='y', drop=True) + expected = Dataset({"foo": ("x", [1])}, {"x": [0]}) + selected = data.squeeze(dim="y", drop=True) assert_identical(expected, selected) - data = Dataset({'foo': (('x',), [])}, {'x': []}) + data = Dataset({"foo": (("x",), [])}, {"x": []}) selected = data.squeeze(drop=True) assert_identical(data, selected) def test_groupby(self): - data = Dataset({'z': (['x', 'y'], np.random.randn(3, 5))}, - {'x': ('x', list('abc')), - 'c': ('x', [0, 1, 0]), - 'y': range(5)}) - groupby = data.groupby('x') + data = Dataset( + {"z": (["x", "y"], np.random.randn(3, 5))}, + {"x": ("x", list("abc")), "c": ("x", [0, 1, 0]), "y": range(5)}, + ) + groupby = data.groupby("x") assert len(groupby) == 3 - expected_groups = {'a': 0, 'b': 1, 'c': 2} + expected_groups = {"a": 0, "b": 1, "c": 2} assert groupby.groups == expected_groups - expected_items = [('a', data.isel(x=0)), - ('b', data.isel(x=1)), - ('c', data.isel(x=2))] + expected_items = [ + ("a", data.isel(x=0)), + ("b", data.isel(x=1)), + ("c", data.isel(x=2)), + ] for actual, expected in zip(groupby, expected_items): assert actual[0] == expected[0] assert_equal(actual[1], expected[1]) @@ -3023,82 +3256,93 @@ def test_groupby(self): def identity(x): return x - for k in ['x', 'c', 'y']: + for k in ["x", "c", "y"]: actual = data.groupby(k, squeeze=False).apply(identity) assert_equal(data, actual) def test_groupby_returns_new_type(self): - data = Dataset({'z': (['x', 'y'], np.random.randn(3, 5))}) + data = Dataset({"z": (["x", "y"], np.random.randn(3, 5))}) - actual = data.groupby('x').apply(lambda ds: ds['z']) - expected = data['z'] + actual = data.groupby("x").apply(lambda ds: ds["z"]) + expected = data["z"] assert_identical(expected, actual) - actual = data['z'].groupby('x').apply(lambda x: x.to_dataset()) + actual = data["z"].groupby("x").apply(lambda x: x.to_dataset()) expected = data assert_identical(expected, actual) def test_groupby_iter(self): data = create_test_data() - for n, (t, sub) in enumerate(list(data.groupby('dim1'))[:3]): - assert data['dim1'][n] == t - assert_equal(data['var1'][n], sub['var1']) - assert_equal(data['var2'][n], sub['var2']) - assert_equal(data['var3'][:, n], sub['var3']) + for n, (t, sub) in enumerate(list(data.groupby("dim1"))[:3]): + assert data["dim1"][n] == t + assert_equal(data["var1"][n], sub["var1"]) + assert_equal(data["var2"][n], sub["var2"]) + assert_equal(data["var3"][:, n], sub["var3"]) def test_groupby_errors(self): data = create_test_data() - with raises_regex(TypeError, '`group` must be'): + with raises_regex(TypeError, "`group` must be"): data.groupby(np.arange(10)) - with raises_regex(ValueError, 'length does not match'): - data.groupby(data['dim1'][:3]) + with raises_regex(ValueError, "length does not match"): + data.groupby(data["dim1"][:3]) with raises_regex(TypeError, "`group` must be"): - data.groupby(data.coords['dim1'].to_index()) + data.groupby(data.coords["dim1"].to_index()) def test_groupby_reduce(self): - data = Dataset({'xy': (['x', 'y'], np.random.randn(3, 4)), - 'xonly': ('x', np.random.randn(3)), - 'yonly': ('y', np.random.randn(4)), - 'letters': ('y', ['a', 'a', 'b', 'b'])}) - - expected = data.mean('y') - expected['yonly'] = expected['yonly'].variable.set_dims({'x': 3}) - actual = data.groupby('x').mean(ALL_DIMS) + data = Dataset( + { + "xy": (["x", "y"], np.random.randn(3, 4)), + "xonly": ("x", np.random.randn(3)), + "yonly": ("y", 
np.random.randn(4)), + "letters": ("y", ["a", "a", "b", "b"]), + } + ) + + expected = data.mean("y") + expected["yonly"] = expected["yonly"].variable.set_dims({"x": 3}) + actual = data.groupby("x").mean(ALL_DIMS) assert_allclose(expected, actual) - actual = data.groupby('x').mean('y') + actual = data.groupby("x").mean("y") assert_allclose(expected, actual) - letters = data['letters'] - expected = Dataset({'xy': data['xy'].groupby(letters).mean(ALL_DIMS), - 'xonly': (data['xonly'].mean().variable - .set_dims({'letters': 2})), - 'yonly': data['yonly'].groupby(letters).mean()}) - actual = data.groupby('letters').mean(ALL_DIMS) + letters = data["letters"] + expected = Dataset( + { + "xy": data["xy"].groupby(letters).mean(ALL_DIMS), + "xonly": (data["xonly"].mean().variable.set_dims({"letters": 2})), + "yonly": data["yonly"].groupby(letters).mean(), + } + ) + actual = data.groupby("letters").mean(ALL_DIMS) assert_allclose(expected, actual) def test_groupby_warn(self): - data = Dataset({'xy': (['x', 'y'], np.random.randn(3, 4)), - 'xonly': ('x', np.random.randn(3)), - 'yonly': ('y', np.random.randn(4)), - 'letters': ('y', ['a', 'a', 'b', 'b'])}) + data = Dataset( + { + "xy": (["x", "y"], np.random.randn(3, 4)), + "xonly": ("x", np.random.randn(3)), + "yonly": ("y", np.random.randn(4)), + "letters": ("y", ["a", "a", "b", "b"]), + } + ) with pytest.warns(FutureWarning): - data.groupby('x').mean() + data.groupby("x").mean() def test_groupby_math(self): def reorder_dims(x): - return x.transpose('dim1', 'dim2', 'dim3', 'time') + return x.transpose("dim1", "dim2", "dim3", "time") ds = create_test_data() - ds['dim1'] = ds['dim1'] + ds["dim1"] = ds["dim1"] for squeeze in [True, False]: - grouped = ds.groupby('dim1', squeeze=squeeze) + grouped = ds.groupby("dim1", squeeze=squeeze) - expected = reorder_dims(ds + ds.coords['dim1']) - actual = grouped + ds.coords['dim1'] + expected = reorder_dims(ds + ds.coords["dim1"]) + actual = grouped + ds.coords["dim1"] assert_identical(expected, reorder_dims(actual)) - actual = ds.coords['dim1'] + grouped + actual = ds.coords["dim1"] + grouped assert_identical(expected, reorder_dims(actual)) ds2 = 2 * ds @@ -3109,56 +3353,60 @@ def reorder_dims(x): actual = ds2 + grouped assert_identical(expected, reorder_dims(actual)) - grouped = ds.groupby('numbers') - zeros = DataArray([0, 0, 0, 0], [('numbers', range(4))]) - expected = ((ds + Variable('dim3', np.zeros(10))) - .transpose('dim3', 'dim1', 'dim2', 'time')) + grouped = ds.groupby("numbers") + zeros = DataArray([0, 0, 0, 0], [("numbers", range(4))]) + expected = (ds + Variable("dim3", np.zeros(10))).transpose( + "dim3", "dim1", "dim2", "time" + ) actual = grouped + zeros assert_equal(expected, actual) actual = zeros + grouped assert_equal(expected, actual) - with raises_regex(ValueError, 'incompat.* grouped binary'): + with raises_regex(ValueError, "incompat.* grouped binary"): grouped + ds - with raises_regex(ValueError, 'incompat.* grouped binary'): + with raises_regex(ValueError, "incompat.* grouped binary"): ds + grouped - with raises_regex(TypeError, 'only support binary ops'): + with raises_regex(TypeError, "only support binary ops"): grouped + 1 - with raises_regex(TypeError, 'only support binary ops'): + with raises_regex(TypeError, "only support binary ops"): grouped + grouped - with raises_regex(TypeError, 'in-place operations'): + with raises_regex(TypeError, "in-place operations"): ds += grouped - ds = Dataset({'x': ('time', np.arange(100)), - 'time': pd.date_range('2000-01-01', periods=100)}) - with 
raises_regex(ValueError, 'incompat.* grouped binary'): - ds + ds.groupby('time.month') + ds = Dataset( + { + "x": ("time", np.arange(100)), + "time": pd.date_range("2000-01-01", periods=100), + } + ) + with raises_regex(ValueError, "incompat.* grouped binary"): + ds + ds.groupby("time.month") def test_groupby_math_virtual(self): - ds = Dataset({'x': ('t', [1, 2, 3])}, - {'t': pd.date_range('20100101', periods=3)}) - grouped = ds.groupby('t.day') + ds = Dataset( + {"x": ("t", [1, 2, 3])}, {"t": pd.date_range("20100101", periods=3)} + ) + grouped = ds.groupby("t.day") actual = grouped - grouped.mean(ALL_DIMS) - expected = Dataset({'x': ('t', [0, 0, 0])}, - ds[['t', 't.day']]) + expected = Dataset({"x": ("t", [0, 0, 0])}, ds[["t", "t.day"]]) assert_identical(actual, expected) def test_groupby_nan(self): # nan should be excluded from groupby - ds = Dataset({'foo': ('x', [1, 2, 3, 4])}, - {'bar': ('x', [1, 1, 2, np.nan])}) - actual = ds.groupby('bar').mean(ALL_DIMS) - expected = Dataset({'foo': ('bar', [1.5, 3]), 'bar': [1, 2]}) + ds = Dataset({"foo": ("x", [1, 2, 3, 4])}, {"bar": ("x", [1, 1, 2, np.nan])}) + actual = ds.groupby("bar").mean(ALL_DIMS) + expected = Dataset({"foo": ("bar", [1.5, 3]), "bar": [1, 2]}) assert_identical(actual, expected) def test_groupby_order(self): # groupby should preserve variables order ds = Dataset() - for vn in ['a', 'b', 'c']: - ds[vn] = DataArray(np.arange(10), dims=['t']) + for vn in ["a", "b", "c"]: + ds[vn] = DataArray(np.arange(10), dims=["t"]) data_vars_ref = list(ds.data_vars.keys()) - ds = ds.groupby('t').mean(ALL_DIMS) + ds = ds.groupby("t").mean(ALL_DIMS) data_vars = list(ds.data_vars.keys()) assert data_vars == data_vars_ref # coords are now at the end of the list, so the test below fails @@ -3167,51 +3415,67 @@ def test_groupby_order(self): # self.assertEqual(all_vars, all_vars_ref) def test_resample_and_first(self): - times = pd.date_range('2000-01-01', freq='6H', periods=10) - ds = Dataset({'foo': (['time', 'x', 'y'], np.random.randn(10, 5, 3)), - 'bar': ('time', np.random.randn(10), {'meta': 'data'}), - 'time': times}) + times = pd.date_range("2000-01-01", freq="6H", periods=10) + ds = Dataset( + { + "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) - actual = ds.resample(time='1D').first(keep_attrs=True) + actual = ds.resample(time="1D").first(keep_attrs=True) expected = ds.isel(time=[0, 4, 8]) assert_identical(expected, actual) # upsampling - expected_time = pd.date_range('2000-01-01', freq='3H', periods=19) + expected_time = pd.date_range("2000-01-01", freq="3H", periods=19) expected = ds.reindex(time=expected_time) - actual = ds.resample(time='3H') - for how in ['mean', 'sum', 'first', 'last', ]: + actual = ds.resample(time="3H") + for how in ["mean", "sum", "first", "last"]: method = getattr(actual, how) result = method() assert_equal(expected, result) - for method in [np.mean, ]: + for method in [np.mean]: result = actual.reduce(method) assert_equal(expected, result) def test_resample_min_count(self): - times = pd.date_range('2000-01-01', freq='6H', periods=10) - ds = Dataset({'foo': (['time', 'x', 'y'], np.random.randn(10, 5, 3)), - 'bar': ('time', np.random.randn(10), {'meta': 'data'}), - 'time': times}) + times = pd.date_range("2000-01-01", freq="6H", periods=10) + ds = Dataset( + { + "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) # inject nan - ds['foo'] = 
xr.where(ds['foo'] > 2.0, np.nan, ds['foo']) - - actual = ds.resample(time='1D').sum(min_count=1) - expected = xr.concat([ - ds.isel(time=slice(i * 4, (i + 1) * 4)).sum('time', min_count=1) - for i in range(3)], dim=actual['time']) + ds["foo"] = xr.where(ds["foo"] > 2.0, np.nan, ds["foo"]) + + actual = ds.resample(time="1D").sum(min_count=1) + expected = xr.concat( + [ + ds.isel(time=slice(i * 4, (i + 1) * 4)).sum("time", min_count=1) + for i in range(3) + ], + dim=actual["time"], + ) assert_equal(expected, actual) def test_resample_by_mean_with_keep_attrs(self): - times = pd.date_range('2000-01-01', freq='6H', periods=10) - ds = Dataset({'foo': (['time', 'x', 'y'], np.random.randn(10, 5, 3)), - 'bar': ('time', np.random.randn(10), {'meta': 'data'}), - 'time': times}) - ds.attrs['dsmeta'] = 'dsdata' - - resampled_ds = ds.resample(time='1D').mean(keep_attrs=True) - actual = resampled_ds['bar'].attrs - expected = ds['bar'].attrs + times = pd.date_range("2000-01-01", freq="6H", periods=10) + ds = Dataset( + { + "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) + ds.attrs["dsmeta"] = "dsdata" + + resampled_ds = ds.resample(time="1D").mean(keep_attrs=True) + actual = resampled_ds["bar"].attrs + expected = ds["bar"].attrs assert expected == actual actual = resampled_ds.attrs @@ -3219,129 +3483,142 @@ def test_resample_by_mean_with_keep_attrs(self): assert expected == actual def test_resample_loffset(self): - times = pd.date_range('2000-01-01', freq='6H', periods=10) - ds = Dataset({'foo': (['time', 'x', 'y'], np.random.randn(10, 5, 3)), - 'bar': ('time', np.random.randn(10), {'meta': 'data'}), - 'time': times}) - ds.attrs['dsmeta'] = 'dsdata' + times = pd.date_range("2000-01-01", freq="6H", periods=10) + ds = Dataset( + { + "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) + ds.attrs["dsmeta"] = "dsdata" - actual = ds.resample(time='24H', loffset='-12H').mean('time').time - expected = xr.DataArray(ds.bar.to_series() - .resample('24H', loffset='-12H').mean()).time + actual = ds.resample(time="24H", loffset="-12H").mean("time").time + expected = xr.DataArray( + ds.bar.to_series().resample("24H", loffset="-12H").mean() + ).time assert_identical(expected, actual) def test_resample_by_mean_discarding_attrs(self): - times = pd.date_range('2000-01-01', freq='6H', periods=10) - ds = Dataset({'foo': (['time', 'x', 'y'], np.random.randn(10, 5, 3)), - 'bar': ('time', np.random.randn(10), {'meta': 'data'}), - 'time': times}) - ds.attrs['dsmeta'] = 'dsdata' + times = pd.date_range("2000-01-01", freq="6H", periods=10) + ds = Dataset( + { + "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) + ds.attrs["dsmeta"] = "dsdata" - resampled_ds = ds.resample(time='1D').mean(keep_attrs=False) + resampled_ds = ds.resample(time="1D").mean(keep_attrs=False) - assert resampled_ds['bar'].attrs == {} + assert resampled_ds["bar"].attrs == {} assert resampled_ds.attrs == {} def test_resample_by_last_discarding_attrs(self): - times = pd.date_range('2000-01-01', freq='6H', periods=10) - ds = Dataset({'foo': (['time', 'x', 'y'], np.random.randn(10, 5, 3)), - 'bar': ('time', np.random.randn(10), {'meta': 'data'}), - 'time': times}) - ds.attrs['dsmeta'] = 'dsdata' + times = pd.date_range("2000-01-01", freq="6H", periods=10) + ds = Dataset( + { + "foo": (["time", "x", 
"y"], np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) + ds.attrs["dsmeta"] = "dsdata" - resampled_ds = ds.resample(time='1D').last(keep_attrs=False) + resampled_ds = ds.resample(time="1D").last(keep_attrs=False) - assert resampled_ds['bar'].attrs == {} + assert resampled_ds["bar"].attrs == {} assert resampled_ds.attrs == {} @requires_scipy def test_resample_drop_nondim_coords(self): xs = np.arange(6) ys = np.arange(3) - times = pd.date_range('2000-01-01', freq='6H', periods=5) + times = pd.date_range("2000-01-01", freq="6H", periods=5) data = np.tile(np.arange(5), (6, 3, 1)) xx, yy = np.meshgrid(xs * 5, ys * 2.5) tt = np.arange(len(times), dtype=int) - array = DataArray(data, - {'time': times, 'x': xs, 'y': ys}, - ('x', 'y', 'time')) - xcoord = DataArray(xx.T, {'x': xs, 'y': ys}, ('x', 'y')) - ycoord = DataArray(yy.T, {'x': xs, 'y': ys}, ('x', 'y')) - tcoord = DataArray(tt, {'time': times}, ('time', )) - ds = Dataset({'data': array, 'xc': xcoord, - 'yc': ycoord, 'tc': tcoord}) - ds = ds.set_coords(['xc', 'yc', 'tc']) + array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) + xcoord = DataArray(xx.T, {"x": xs, "y": ys}, ("x", "y")) + ycoord = DataArray(yy.T, {"x": xs, "y": ys}, ("x", "y")) + tcoord = DataArray(tt, {"time": times}, ("time",)) + ds = Dataset({"data": array, "xc": xcoord, "yc": ycoord, "tc": tcoord}) + ds = ds.set_coords(["xc", "yc", "tc"]) # Re-sample - actual = ds.resample(time="12H").mean('time') - assert 'tc' not in actual.coords + actual = ds.resample(time="12H").mean("time") + assert "tc" not in actual.coords # Up-sample - filling actual = ds.resample(time="1H").ffill() - assert 'tc' not in actual.coords + assert "tc" not in actual.coords # Up-sample - interpolation - actual = ds.resample(time="1H").interpolate('linear') - assert 'tc' not in actual.coords + actual = ds.resample(time="1H").interpolate("linear") + assert "tc" not in actual.coords def test_resample_old_api(self): - times = pd.date_range('2000-01-01', freq='6H', periods=10) - ds = Dataset({'foo': (['time', 'x', 'y'], np.random.randn(10, 5, 3)), - 'bar': ('time', np.random.randn(10), {'meta': 'data'}), - 'time': times}) + times = pd.date_range("2000-01-01", freq="6H", periods=10) + ds = Dataset( + { + "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) - with raises_regex(TypeError, r'resample\(\) no longer supports'): - ds.resample('1D', 'time') + with raises_regex(TypeError, r"resample\(\) no longer supports"): + ds.resample("1D", "time") - with raises_regex(TypeError, r'resample\(\) no longer supports'): - ds.resample('1D', dim='time', how='mean') + with raises_regex(TypeError, r"resample\(\) no longer supports"): + ds.resample("1D", dim="time", how="mean") - with raises_regex(TypeError, r'resample\(\) no longer supports'): - ds.resample('1D', dim='time') + with raises_regex(TypeError, r"resample\(\) no longer supports"): + ds.resample("1D", dim="time") def test_ds_resample_apply_func_args(self): + def func(arg1, arg2, arg3=0.0): + return arg1.mean("time") + arg2 + arg3 - def func(arg1, arg2, arg3=0.): - return arg1.mean('time') + arg2 + arg3 - - times = pd.date_range('2000', freq='D', periods=3) - ds = xr.Dataset({'foo': ('time', [1., 1., 1.]), - 'time': times}) - expected = xr.Dataset({'foo': ('time', [3., 3., 3.]), - 'time': times}) - actual = ds.resample(time='D').apply(func, args=(1.,), arg3=1.) 
+ times = pd.date_range("2000", freq="D", periods=3) + ds = xr.Dataset({"foo": ("time", [1.0, 1.0, 1.0]), "time": times}) + expected = xr.Dataset({"foo": ("time", [3.0, 3.0, 3.0]), "time": times}) + actual = ds.resample(time="D").apply(func, args=(1.0,), arg3=1.0) assert_identical(expected, actual) def test_to_array(self): - ds = Dataset(OrderedDict([('a', 1), ('b', ('x', [1, 2, 3]))]), - coords={'c': 42}, attrs={'Conventions': 'None'}) + ds = Dataset( + OrderedDict([("a", 1), ("b", ("x", [1, 2, 3]))]), + coords={"c": 42}, + attrs={"Conventions": "None"}, + ) data = [[1, 1, 1], [1, 2, 3]] - coords = {'c': 42, 'variable': ['a', 'b']} - dims = ('variable', 'x') + coords = {"c": 42, "variable": ["a", "b"]} + dims = ("variable", "x") expected = DataArray(data, coords, dims, attrs=ds.attrs) actual = ds.to_array() assert_identical(expected, actual) - actual = ds.to_array('abc', name='foo') - expected = expected.rename({'variable': 'abc'}).rename('foo') + actual = ds.to_array("abc", name="foo") + expected = expected.rename({"variable": "abc"}).rename("foo") assert_identical(expected, actual) def test_to_and_from_dataframe(self): x = np.random.randn(10) y = np.random.randn(10) - t = list('abcdefghij') - ds = Dataset(OrderedDict([('a', ('t', x)), - ('b', ('t', y)), - ('t', ('t', t))])) - expected = pd.DataFrame(np.array([x, y]).T, columns=['a', 'b'], - index=pd.Index(t, name='t')) + t = list("abcdefghij") + ds = Dataset(OrderedDict([("a", ("t", x)), ("b", ("t", y)), ("t", ("t", t))])) + expected = pd.DataFrame( + np.array([x, y]).T, columns=["a", "b"], index=pd.Index(t, name="t") + ) actual = ds.to_dataframe() # use the .equals method to check all DataFrame metadata assert expected.equals(actual), (expected, actual) # verify coords are included - actual = ds.set_coords('b').to_dataframe() + actual = ds.set_coords("b").to_dataframe() assert expected.equals(actual), (expected, actual) # check roundtrip @@ -3349,60 +3626,59 @@ def test_to_and_from_dataframe(self): # test a case with a MultiIndex w = np.random.randn(2, 3) - ds = Dataset({'w': (('x', 'y'), w)}) - ds['y'] = ('y', list('abc')) + ds = Dataset({"w": (("x", "y"), w)}) + ds["y"] = ("y", list("abc")) exp_index = pd.MultiIndex.from_arrays( - [[0, 0, 0, 1, 1, 1], ['a', 'b', 'c', 'a', 'b', 'c']], - names=['x', 'y']) - expected = pd.DataFrame(w.reshape(-1), columns=['w'], index=exp_index) + [[0, 0, 0, 1, 1, 1], ["a", "b", "c", "a", "b", "c"]], names=["x", "y"] + ) + expected = pd.DataFrame(w.reshape(-1), columns=["w"], index=exp_index) actual = ds.to_dataframe() assert expected.equals(actual) # check roundtrip - assert_identical(ds.assign_coords(x=[0, 1]), - Dataset.from_dataframe(actual)) + assert_identical(ds.assign_coords(x=[0, 1]), Dataset.from_dataframe(actual)) # check pathological cases df = pd.DataFrame([1]) actual = Dataset.from_dataframe(df) - expected = Dataset({0: ('index', [1])}, {'index': [0]}) + expected = Dataset({0: ("index", [1])}, {"index": [0]}) assert_identical(expected, actual) df = pd.DataFrame() actual = Dataset.from_dataframe(df) - expected = Dataset(coords={'index': []}) + expected = Dataset(coords={"index": []}) assert_identical(expected, actual) # GH697 - df = pd.DataFrame({'A': []}) + df = pd.DataFrame({"A": []}) actual = Dataset.from_dataframe(df) - expected = Dataset({'A': DataArray([], dims=('index',))}, - {'index': []}) + expected = Dataset({"A": DataArray([], dims=("index",))}, {"index": []}) assert_identical(expected, actual) # regression test for GH278 # use int64 to ensure consistent results for the pandas 
.equals method # on windows (which requires the same dtype) - ds = Dataset({'x': pd.Index(['bar']), - 'a': ('y', np.array([1], 'int64'))}).isel(x=0) + ds = Dataset({"x": pd.Index(["bar"]), "a": ("y", np.array([1], "int64"))}).isel( + x=0 + ) # use .loc to ensure consistent results on Python 3 - actual = ds.to_dataframe().loc[:, ['a', 'x']] - expected = pd.DataFrame([[1, 'bar']], index=pd.Index([0], name='y'), - columns=['a', 'x']) + actual = ds.to_dataframe().loc[:, ["a", "x"]] + expected = pd.DataFrame( + [[1, "bar"]], index=pd.Index([0], name="y"), columns=["a", "x"] + ) assert expected.equals(actual), (expected, actual) - ds = Dataset({'x': np.array([0], 'int64'), - 'y': np.array([1], 'int64')}) + ds = Dataset({"x": np.array([0], "int64"), "y": np.array([1], "int64")}) actual = ds.to_dataframe() - idx = pd.MultiIndex.from_arrays([[0], [1]], names=['x', 'y']) + idx = pd.MultiIndex.from_arrays([[0], [1]], names=["x", "y"]) expected = pd.DataFrame([[]], index=idx) assert expected.equals(actual), (expected, actual) def test_to_and_from_empty_dataframe(self): # GH697 - expected = pd.DataFrame({'foo': []}) + expected = pd.DataFrame({"foo": []}) ds = Dataset.from_dataframe(expected) - assert len(ds['foo']) == 0 + assert len(ds["foo"]) == 0 actual = ds.to_dataframe() assert len(actual) == 0 assert expected.equals(actual) @@ -3410,24 +3686,25 @@ def test_to_and_from_empty_dataframe(self): def test_from_dataframe_non_unique_columns(self): # regression test for GH449 df = pd.DataFrame(np.zeros((2, 2))) - df.columns = ['foo', 'foo'] - with raises_regex(ValueError, 'non-unique columns'): + df.columns = ["foo", "foo"] + with raises_regex(ValueError, "non-unique columns"): Dataset.from_dataframe(df) def test_convert_dataframe_with_many_types_and_multiindex(self): # regression test for GH737 - df = pd.DataFrame({'a': list('abc'), - 'b': list(range(1, 4)), - 'c': np.arange(3, 6).astype('u1'), - 'd': np.arange(4.0, 7.0, dtype='float64'), - 'e': [True, False, True], - 'f': pd.Categorical(list('abc')), - 'g': pd.date_range('20130101', periods=3), - 'h': pd.date_range('20130101', - periods=3, - tz='US/Eastern')}) - df.index = pd.MultiIndex.from_product([['a'], range(3)], - names=['one', 'two']) + df = pd.DataFrame( + { + "a": list("abc"), + "b": list(range(1, 4)), + "c": np.arange(3, 6).astype("u1"), + "d": np.arange(4.0, 7.0, dtype="float64"), + "e": [True, False, True], + "f": pd.Categorical(list("abc")), + "g": pd.date_range("20130101", periods=3), + "h": pd.date_range("20130101", periods=3, tz="US/Eastern"), + } + ) + df.index = pd.MultiIndex.from_product([["a"], range(3)], names=["one", "two"]) roundtripped = Dataset.from_dataframe(df).to_dataframe() # we can't do perfectly, but we should be at least as faithful as # np.asarray @@ -3444,21 +3721,17 @@ def test_to_and_from_dict(self): # b (t) float64 1.32 0.1954 1.91 1.39 0.519 -0.2772 ... 
x = np.random.randn(10) y = np.random.randn(10) - t = list('abcdefghij') - ds = Dataset(OrderedDict([('a', ('t', x)), - ('b', ('t', y)), - ('t', ('t', t))])) - expected = {'coords': {'t': {'dims': ('t',), - 'data': t, - 'attrs': {}}}, - 'attrs': {}, - 'dims': {'t': 10}, - 'data_vars': {'a': {'dims': ('t',), - 'data': x.tolist(), - 'attrs': {}}, - 'b': {'dims': ('t',), - 'data': y.tolist(), - 'attrs': {}}}} + t = list("abcdefghij") + ds = Dataset(OrderedDict([("a", ("t", x)), ("b", ("t", y)), ("t", ("t", t))])) + expected = { + "coords": {"t": {"dims": ("t",), "data": t, "attrs": {}}}, + "attrs": {}, + "dims": {"t": 10}, + "data_vars": { + "a": {"dims": ("t",), "data": x.tolist(), "attrs": {}}, + "b": {"dims": ("t",), "data": y.tolist(), "attrs": {}}, + }, + } actual = ds.to_dict() @@ -3470,21 +3743,18 @@ def test_to_and_from_dict(self): # check the data=False option expected_no_data = expected.copy() - del expected_no_data['coords']['t']['data'] - del expected_no_data['data_vars']['a']['data'] - del expected_no_data['data_vars']['b']['data'] - endiantype = 'U1' - expected_no_data['coords']['t'].update({'dtype': endiantype, - 'shape': (10,)}) - expected_no_data['data_vars']['a'].update({'dtype': 'float64', - 'shape': (10,)}) - expected_no_data['data_vars']['b'].update({'dtype': 'float64', - 'shape': (10,)}) + del expected_no_data["coords"]["t"]["data"] + del expected_no_data["data_vars"]["a"]["data"] + del expected_no_data["data_vars"]["b"]["data"] + endiantype = "U1" + expected_no_data["coords"]["t"].update({"dtype": endiantype, "shape": (10,)}) + expected_no_data["data_vars"]["a"].update({"dtype": "float64", "shape": (10,)}) + expected_no_data["data_vars"]["b"].update({"dtype": "float64", "shape": (10,)}) actual_no_data = ds.to_dict(data=False) assert expected_no_data == actual_no_data # verify coords are included roundtrip - expected_ds = ds.set_coords('b') + expected_ds = ds.set_coords("b") actual = Dataset.from_dict(expected_ds.to_dict()) assert_identical(expected_ds, actual) @@ -3493,35 +3763,45 @@ def test_to_and_from_dict(self): # this one has no attrs field, the dims are strings, and x, y are # np.arrays - d = {'coords': {'t': {'dims': 't', 'data': t}}, - 'dims': 't', - 'data_vars': {'a': {'dims': 't', 'data': x}, - 'b': {'dims': 't', 'data': y}}} + d = { + "coords": {"t": {"dims": "t", "data": t}}, + "dims": "t", + "data_vars": {"a": {"dims": "t", "data": x}, "b": {"dims": "t", "data": y}}, + } assert_identical(ds, Dataset.from_dict(d)) # this is kind of a flattened version with no coords, or data_vars - d = {'a': {'dims': 't', 'data': x}, - 't': {'data': t, 'dims': 't'}, - 'b': {'dims': 't', 'data': y}} + d = { + "a": {"dims": "t", "data": x}, + "t": {"data": t, "dims": "t"}, + "b": {"dims": "t", "data": y}, + } assert_identical(ds, Dataset.from_dict(d)) # this one is missing some necessary information - d = {'a': {'data': x}, - 't': {'data': t, 'dims': 't'}, - 'b': {'dims': 't', 'data': y}} - with raises_regex(ValueError, "cannot convert dict " - "without the key 'dims'"): + d = { + "a": {"data": x}, + "t": {"data": t, "dims": "t"}, + "b": {"dims": "t", "data": y}, + } + with raises_regex(ValueError, "cannot convert dict " "without the key 'dims'"): Dataset.from_dict(d) def test_to_and_from_dict_with_time_dim(self): x = np.random.randn(10, 3) y = np.random.randn(10, 3) - t = pd.date_range('20130101', periods=10) + t = pd.date_range("20130101", periods=10) lat = [77.7, 83.2, 76] - ds = Dataset(OrderedDict([('a', (['t', 'lat'], x)), - ('b', (['t', 'lat'], y)), - ('t', ('t', t)), - 
('lat', ('lat', lat))])) + ds = Dataset( + OrderedDict( + [ + ("a", (["t", "lat"], x)), + ("b", (["t", "lat"], y)), + ("t", ("t", t)), + ("lat", ("lat", lat)), + ] + ) + ) roundtripped = Dataset.from_dict(ds.to_dict()) assert_identical(ds, roundtripped) @@ -3529,14 +3809,20 @@ def test_to_and_from_dict_with_nan_nat(self): x = np.random.randn(10, 3) y = np.random.randn(10, 3) y[2] = np.nan - t = pd.Series(pd.date_range('20130101', periods=10)) + t = pd.Series(pd.date_range("20130101", periods=10)) t[2] = np.nan lat = [77.7, 83.2, 76] - ds = Dataset(OrderedDict([('a', (['t', 'lat'], x)), - ('b', (['t', 'lat'], y)), - ('t', ('t', t)), - ('lat', ('lat', lat))])) + ds = Dataset( + OrderedDict( + [ + ("a", (["t", "lat"], x)), + ("b", (["t", "lat"], y)), + ("t", ("t", t)), + ("lat", ("lat", lat)), + ] + ) + ) roundtripped = Dataset.from_dict(ds.to_dict()) assert_identical(ds, roundtripped) @@ -3544,20 +3830,26 @@ def test_to_dict_with_numpy_attrs(self): # this doesn't need to roundtrip x = np.random.randn(10) y = np.random.randn(10) - t = list('abcdefghij') - attrs = {'created': np.float64(1998), - 'coords': np.array([37, -110.1, 100]), - 'maintainer': 'bar'} - ds = Dataset(OrderedDict([('a', ('t', x, attrs)), - ('b', ('t', y, attrs)), - ('t', ('t', t))])) - expected_attrs = {'created': attrs['created'].item(), - 'coords': attrs['coords'].tolist(), - 'maintainer': 'bar'} + t = list("abcdefghij") + attrs = { + "created": np.float64(1998), + "coords": np.array([37, -110.1, 100]), + "maintainer": "bar", + } + ds = Dataset( + OrderedDict( + [("a", ("t", x, attrs)), ("b", ("t", y, attrs)), ("t", ("t", t))] + ) + ) + expected_attrs = { + "created": attrs["created"].item(), + "coords": attrs["coords"].tolist(), + "maintainer": "bar", + } actual = ds.to_dict() # check that they are identical - assert expected_attrs == actual['data_vars']['a']['attrs'] + assert expected_attrs == actual["data_vars"]["a"]["attrs"] def test_pickle(self): data = create_test_data() @@ -3575,7 +3867,7 @@ def test_lazy_load(self): with pytest.raises(UnexpectedDataAccess): ds.load() with pytest.raises(UnexpectedDataAccess): - ds['var1'].values + ds["var1"].values # these should not raise UnexpectedDataAccess: ds.isel(time=10) @@ -3586,84 +3878,83 @@ def test_dropna(self): x[::2, 0] = np.nan y = np.random.randn(4) y[-1] = np.nan - ds = Dataset({'foo': (('a', 'b'), x), 'bar': (('b', y))}) + ds = Dataset({"foo": (("a", "b"), x), "bar": (("b", y))}) expected = ds.isel(a=slice(1, None, 2)) - actual = ds.dropna('a') + actual = ds.dropna("a") assert_identical(actual, expected) expected = ds.isel(b=slice(1, 3)) - actual = ds.dropna('b') + actual = ds.dropna("b") assert_identical(actual, expected) - actual = ds.dropna('b', subset=['foo', 'bar']) + actual = ds.dropna("b", subset=["foo", "bar"]) assert_identical(actual, expected) expected = ds.isel(b=slice(1, None)) - actual = ds.dropna('b', subset=['foo']) + actual = ds.dropna("b", subset=["foo"]) assert_identical(actual, expected) expected = ds.isel(b=slice(3)) - actual = ds.dropna('b', subset=['bar']) + actual = ds.dropna("b", subset=["bar"]) assert_identical(actual, expected) - actual = ds.dropna('a', subset=[]) + actual = ds.dropna("a", subset=[]) assert_identical(actual, ds) - actual = ds.dropna('a', subset=['bar']) + actual = ds.dropna("a", subset=["bar"]) assert_identical(actual, ds) - actual = ds.dropna('a', how='all') + actual = ds.dropna("a", how="all") assert_identical(actual, ds) - actual = ds.dropna('b', how='all', subset=['bar']) + actual = ds.dropna("b", how="all", 
subset=["bar"]) expected = ds.isel(b=[0, 1, 2]) assert_identical(actual, expected) - actual = ds.dropna('b', thresh=1, subset=['bar']) + actual = ds.dropna("b", thresh=1, subset=["bar"]) assert_identical(actual, expected) - actual = ds.dropna('b', thresh=2) + actual = ds.dropna("b", thresh=2) assert_identical(actual, ds) - actual = ds.dropna('b', thresh=4) + actual = ds.dropna("b", thresh=4) expected = ds.isel(b=[1, 2, 3]) assert_identical(actual, expected) - actual = ds.dropna('a', thresh=3) + actual = ds.dropna("a", thresh=3) expected = ds.isel(a=[1, 3]) assert_identical(actual, ds) - with raises_regex(ValueError, 'a single dataset dimension'): - ds.dropna('foo') - with raises_regex(ValueError, 'invalid how'): - ds.dropna('a', how='somehow') - with raises_regex(TypeError, 'must specify how or thresh'): - ds.dropna('a', how=None) + with raises_regex(ValueError, "a single dataset dimension"): + ds.dropna("foo") + with raises_regex(ValueError, "invalid how"): + ds.dropna("a", how="somehow") + with raises_regex(TypeError, "must specify how or thresh"): + ds.dropna("a", how=None) def test_fillna(self): - ds = Dataset({'a': ('x', [np.nan, 1, np.nan, 3])}, - {'x': [0, 1, 2, 3]}) + ds = Dataset({"a": ("x", [np.nan, 1, np.nan, 3])}, {"x": [0, 1, 2, 3]}) # fill with -1 actual = ds.fillna(-1) - expected = Dataset({'a': ('x', [-1, 1, -1, 3])}, {'x': [0, 1, 2, 3]}) + expected = Dataset({"a": ("x", [-1, 1, -1, 3])}, {"x": [0, 1, 2, 3]}) assert_identical(expected, actual) - actual = ds.fillna({'a': -1}) + actual = ds.fillna({"a": -1}) assert_identical(expected, actual) - other = Dataset({'a': -1}) + other = Dataset({"a": -1}) actual = ds.fillna(other) assert_identical(expected, actual) - actual = ds.fillna({'a': other.a}) + actual = ds.fillna({"a": other.a}) assert_identical(expected, actual) # fill with range(4) - b = DataArray(range(4), coords=[('x', range(4))]) + b = DataArray(range(4), coords=[("x", range(4))]) actual = ds.fillna(b) - expected = b.rename('a').to_dataset() + expected = b.rename("a").to_dataset() assert_identical(expected, actual) actual = ds.fillna(expected) @@ -3676,56 +3967,57 @@ def test_fillna(self): assert_identical(expected, actual) # okay to only include some data variables - ds['b'] = np.nan - actual = ds.fillna({'a': -1}) - expected = Dataset({'a': ('x', [-1, 1, -1, 3]), 'b': np.nan}, - {'x': [0, 1, 2, 3]}) + ds["b"] = np.nan + actual = ds.fillna({"a": -1}) + expected = Dataset( + {"a": ("x", [-1, 1, -1, 3]), "b": np.nan}, {"x": [0, 1, 2, 3]} + ) assert_identical(expected, actual) # but new data variables is not okay - with raises_regex(ValueError, 'must be contained'): - ds.fillna({'x': 0}) + with raises_regex(ValueError, "must be contained"): + ds.fillna({"x": 0}) # empty argument should be OK result = ds.fillna({}) assert_identical(ds, result) - result = ds.fillna(Dataset(coords={'c': 42})) + result = ds.fillna(Dataset(coords={"c": 42})) expected = ds.assign_coords(c=42) assert_identical(expected, result) # groupby - expected = Dataset({'a': ('x', range(4))}, {'x': [0, 1, 2, 3]}) + expected = Dataset({"a": ("x", range(4))}, {"x": [0, 1, 2, 3]}) for target in [ds, expected]: - target.coords['b'] = ('x', [0, 0, 1, 1]) - actual = ds.groupby('b').fillna(DataArray([0, 2], dims='b')) + target.coords["b"] = ("x", [0, 0, 1, 1]) + actual = ds.groupby("b").fillna(DataArray([0, 2], dims="b")) assert_identical(expected, actual) - actual = ds.groupby('b').fillna(Dataset({'a': ('b', [0, 2])})) + actual = ds.groupby("b").fillna(Dataset({"a": ("b", [0, 2])})) assert_identical(expected, 
actual) # attrs with groupby - ds.attrs['attr'] = 'ds' - ds.a.attrs['attr'] = 'da' - actual = ds.groupby('b').fillna(Dataset({'a': ('b', [0, 2])})) + ds.attrs["attr"] = "ds" + ds.a.attrs["attr"] = "da" + actual = ds.groupby("b").fillna(Dataset({"a": ("b", [0, 2])})) assert actual.attrs == ds.attrs - assert actual.a.name == 'a' + assert actual.a.name == "a" assert actual.a.attrs == ds.a.attrs - da = DataArray(range(5), name='a', attrs={'attr': 'da'}) + da = DataArray(range(5), name="a", attrs={"attr": "da"}) actual = da.fillna(1) - assert actual.name == 'a' + assert actual.name == "a" assert actual.attrs == da.attrs - ds = Dataset({'a': da}, attrs={'attr': 'ds'}) - actual = ds.fillna({'a': 1}) + ds = Dataset({"a": da}, attrs={"attr": "ds"}) + actual = ds.fillna({"a": 1}) assert actual.attrs == ds.attrs - assert actual.a.name == 'a' + assert actual.a.name == "a" assert actual.a.attrs == ds.a.attrs def test_where(self): - ds = Dataset({'a': ('x', range(5))}) - expected = Dataset({'a': ('x', [np.nan, np.nan, 2, 3, 4])}) + ds = Dataset({"a": ("x", range(5))}) + expected = Dataset({"a": ("x", [np.nan, np.nan, 2, 3, 4])}) actual = ds.where(ds > 1) assert_identical(expected, actual) @@ -3739,47 +4031,47 @@ def test_where(self): assert_identical(ds, actual) expected = ds.copy(deep=True) - expected['a'].values = [np.nan] * 5 + expected["a"].values = [np.nan] * 5 actual = ds.where(False) assert_identical(expected, actual) # 2d - ds = Dataset({'a': (('x', 'y'), [[0, 1], [2, 3]])}) - expected = Dataset({'a': (('x', 'y'), [[np.nan, 1], [2, 3]])}) + ds = Dataset({"a": (("x", "y"), [[0, 1], [2, 3]])}) + expected = Dataset({"a": (("x", "y"), [[np.nan, 1], [2, 3]])}) actual = ds.where(ds > 0) assert_identical(expected, actual) # groupby - ds = Dataset({'a': ('x', range(5))}, {'c': ('x', [0, 0, 1, 1, 1])}) - cond = Dataset({'a': ('c', [True, False])}) + ds = Dataset({"a": ("x", range(5))}, {"c": ("x", [0, 0, 1, 1, 1])}) + cond = Dataset({"a": ("c", [True, False])}) expected = ds.copy(deep=True) - expected['a'].values = [0, 1] + [np.nan] * 3 - actual = ds.groupby('c').where(cond) + expected["a"].values = [0, 1] + [np.nan] * 3 + actual = ds.groupby("c").where(cond) assert_identical(expected, actual) # attrs with groupby - ds.attrs['attr'] = 'ds' - ds.a.attrs['attr'] = 'da' - actual = ds.groupby('c').where(cond) + ds.attrs["attr"] = "ds" + ds.a.attrs["attr"] = "da" + actual = ds.groupby("c").where(cond) assert actual.attrs == ds.attrs - assert actual.a.name == 'a' + assert actual.a.name == "a" assert actual.a.attrs == ds.a.attrs # attrs - da = DataArray(range(5), name='a', attrs={'attr': 'da'}) + da = DataArray(range(5), name="a", attrs={"attr": "da"}) actual = da.where(da.values > 1) - assert actual.name == 'a' + assert actual.name == "a" assert actual.attrs == da.attrs - ds = Dataset({'a': da}, attrs={'attr': 'ds'}) + ds = Dataset({"a": da}, attrs={"attr": "ds"}) actual = ds.where(ds > 0) assert actual.attrs == ds.attrs - assert actual.a.name == 'a' + assert actual.a.name == "a" assert actual.a.attrs == ds.a.attrs def test_where_other(self): - ds = Dataset({'a': ('x', range(5))}, {'x': range(5)}) - expected = Dataset({'a': ('x', [-1, -1, 2, 3, 4])}, {'x': range(5)}) + ds = Dataset({"a": ("x", range(5))}, {"x": range(5)}) + expected = Dataset({"a": ("x", [-1, -1, 2, 3, 4])}, {"x": range(5)}) actual = ds.where(ds > 1, -1) assert_equal(expected, actual) assert actual.a.dtype == int @@ -3798,14 +4090,14 @@ def test_where_drop(self): # 1d # data array case - array = DataArray(range(5), coords=[range(5)], 
dims=['x']) - expected = DataArray(range(5)[2:], coords=[range(5)[2:]], dims=['x']) + array = DataArray(range(5), coords=[range(5)], dims=["x"]) + expected = DataArray(range(5)[2:], coords=[range(5)[2:]], dims=["x"]) actual = array.where(array > 1, drop=True) assert_identical(expected, actual) # dataset case - ds = Dataset({'a': array}) - expected = Dataset({'a': expected}) + ds = Dataset({"a": array}) + expected = Dataset({"a": expected}) actual = ds.where(ds > 1, drop=True) assert_identical(expected, actual) @@ -3813,60 +4105,73 @@ def test_where_drop(self): actual = ds.where(ds.a > 1, drop=True) assert_identical(expected, actual) - with raises_regex(TypeError, 'must be a'): + with raises_regex(TypeError, "must be a"): ds.where(np.arange(5) > 1, drop=True) # 1d with odd coordinates - array = DataArray(np.array([2, 7, 1, 8, 3]), - coords=[np.array([3, 1, 4, 5, 9])], dims=['x']) - expected = DataArray(np.array([7, 8, 3]), coords=[np.array([1, 5, 9])], - dims=['x']) + array = DataArray( + np.array([2, 7, 1, 8, 3]), coords=[np.array([3, 1, 4, 5, 9])], dims=["x"] + ) + expected = DataArray( + np.array([7, 8, 3]), coords=[np.array([1, 5, 9])], dims=["x"] + ) actual = array.where(array > 2, drop=True) assert_identical(expected, actual) # 1d multiple variables - ds = Dataset({'a': (('x'), [0, 1, 2, 3]), 'b': (('x'), [4, 5, 6, 7])}) - expected = Dataset({'a': (('x'), [np.nan, 1, 2, 3]), - 'b': (('x'), [4, 5, 6, np.nan])}) + ds = Dataset({"a": (("x"), [0, 1, 2, 3]), "b": (("x"), [4, 5, 6, 7])}) + expected = Dataset( + {"a": (("x"), [np.nan, 1, 2, 3]), "b": (("x"), [4, 5, 6, np.nan])} + ) actual = ds.where((ds > 0) & (ds < 7), drop=True) assert_identical(expected, actual) # 2d - ds = Dataset({'a': (('x', 'y'), [[0, 1], [2, 3]])}) - expected = Dataset({'a': (('x', 'y'), [[np.nan, 1], [2, 3]])}) + ds = Dataset({"a": (("x", "y"), [[0, 1], [2, 3]])}) + expected = Dataset({"a": (("x", "y"), [[np.nan, 1], [2, 3]])}) actual = ds.where(ds > 0, drop=True) assert_identical(expected, actual) # 2d with odd coordinates - ds = Dataset({'a': (('x', 'y'), [[0, 1], [2, 3]])}, coords={ - 'x': [4, 3], 'y': [1, 2], - 'z': (['x', 'y'], [[np.e, np.pi], [np.pi * np.e, np.pi * 3]])}) - expected = Dataset({'a': (('x', 'y'), [[3]])}, - coords={'x': [3], 'y': [2], - 'z': (['x', 'y'], [[np.pi * 3]])}) + ds = Dataset( + {"a": (("x", "y"), [[0, 1], [2, 3]])}, + coords={ + "x": [4, 3], + "y": [1, 2], + "z": (["x", "y"], [[np.e, np.pi], [np.pi * np.e, np.pi * 3]]), + }, + ) + expected = Dataset( + {"a": (("x", "y"), [[3]])}, + coords={"x": [3], "y": [2], "z": (["x", "y"], [[np.pi * 3]])}, + ) actual = ds.where(ds > 2, drop=True) assert_identical(expected, actual) # 2d multiple variables - ds = Dataset({'a': (('x', 'y'), [[0, 1], [2, 3]]), - 'b': (('x', 'y'), [[4, 5], [6, 7]])}) - expected = Dataset({'a': (('x', 'y'), [[np.nan, 1], [2, 3]]), - 'b': (('x', 'y'), [[4, 5], [6, 7]])}) + ds = Dataset( + {"a": (("x", "y"), [[0, 1], [2, 3]]), "b": (("x", "y"), [[4, 5], [6, 7]])} + ) + expected = Dataset( + { + "a": (("x", "y"), [[np.nan, 1], [2, 3]]), + "b": (("x", "y"), [[4, 5], [6, 7]]), + } + ) actual = ds.where(ds > 0, drop=True) assert_identical(expected, actual) def test_where_drop_empty(self): # regression test for GH1341 - array = DataArray(np.random.rand(100, 10), - dims=['nCells', 'nVertLevels']) - mask = DataArray(np.zeros((100,), dtype='bool'), dims='nCells') + array = DataArray(np.random.rand(100, 10), dims=["nCells", "nVertLevels"]) + mask = DataArray(np.zeros((100,), dtype="bool"), dims="nCells") actual = array.where(mask, 
drop=True) - expected = DataArray(np.zeros((0, 10)), dims=['nCells', 'nVertLevels']) + expected = DataArray(np.zeros((0, 10)), dims=["nCells", "nVertLevels"]) assert_identical(expected, actual) def test_where_drop_no_indexes(self): - ds = Dataset({'foo': ('x', [0.0, 1.0])}) - expected = Dataset({'foo': ('x', [1.0])}) + ds = Dataset({"foo": ("x", [0.0, 1.0])}) + expected = Dataset({"foo": ("x", [1.0])}) actual = ds.where(ds == 1, drop=True) assert_identical(expected, actual) @@ -3876,18 +4181,17 @@ def test_reduce(self): assert len(data.mean().coords) == 0 actual = data.max() - expected = Dataset( - {k: v.max() for k, v in data.data_vars.items()} - ) + expected = Dataset({k: v.max() for k, v in data.data_vars.items()}) assert_equal(expected, actual) - assert_equal(data.min(dim=['dim1']), - data.min(dim='dim1')) + assert_equal(data.min(dim=["dim1"]), data.min(dim="dim1")) - for reduct, expected in [('dim2', ['dim1', 'dim3', 'time']), - (['dim2', 'time'], ['dim1', 'dim3']), - (('dim2', 'time'), ['dim1', 'dim3']), - ((), ['dim1', 'dim2', 'dim3', 'time'])]: + for reduct, expected in [ + ("dim2", ["dim1", "dim3", "time"]), + (["dim2", "time"], ["dim1", "dim3"]), + (("dim2", "time"), ["dim1", "dim3"]), + ((), ["dim1", "dim2", "dim3", "time"]), + ]: actual = list(data.min(dim=reduct).dims) assert actual == expected @@ -3895,52 +4199,56 @@ def test_reduce(self): def test_reduce_coords(self): # regression test for GH1470 - data = xr.Dataset({'a': ('x', [1, 2, 3])}, coords={'b': 4}) - expected = xr.Dataset({'a': 2}, coords={'b': 4}) - actual = data.mean('x') + data = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"b": 4}) + expected = xr.Dataset({"a": 2}, coords={"b": 4}) + actual = data.mean("x") assert_identical(actual, expected) # should be consistent - actual = data['a'].mean('x').to_dataset() + actual = data["a"].mean("x").to_dataset() assert_identical(actual, expected) def test_mean_uint_dtype(self): - data = xr.Dataset({'a': (('x', 'y'), - np.arange(6).reshape(3, 2).astype('uint')), - 'b': (('x', ), np.array([0.1, 0.2, np.nan]))}) - actual = data.mean('x', skipna=True) - expected = xr.Dataset({'a': data['a'].mean('x'), - 'b': data['b'].mean('x', skipna=True)}) + data = xr.Dataset( + { + "a": (("x", "y"), np.arange(6).reshape(3, 2).astype("uint")), + "b": (("x",), np.array([0.1, 0.2, np.nan])), + } + ) + actual = data.mean("x", skipna=True) + expected = xr.Dataset( + {"a": data["a"].mean("x"), "b": data["b"].mean("x", skipna=True)} + ) assert_identical(actual, expected) def test_reduce_bad_dim(self): data = create_test_data() - with raises_regex(ValueError, 'Dataset does not contain'): - data.mean(dim='bad_dim') + with raises_regex(ValueError, "Dataset does not contain"): + data.mean(dim="bad_dim") def test_reduce_cumsum(self): - data = xr.Dataset({'a': 1, - 'b': ('x', [1, 2]), - 'c': (('x', 'y'), [[np.nan, 3], [0, 4]])}) - assert_identical(data.fillna(0), data.cumsum('y')) - - expected = xr.Dataset({'a': 1, - 'b': ('x', [1, 3]), - 'c': (('x', 'y'), [[0, 3], [0, 7]])}) + data = xr.Dataset( + {"a": 1, "b": ("x", [1, 2]), "c": (("x", "y"), [[np.nan, 3], [0, 4]])} + ) + assert_identical(data.fillna(0), data.cumsum("y")) + + expected = xr.Dataset( + {"a": 1, "b": ("x", [1, 3]), "c": (("x", "y"), [[0, 3], [0, 7]])} + ) assert_identical(expected, data.cumsum()) def test_reduce_cumsum_test_dims(self): data = create_test_data() - for cumfunc in ['cumsum', 'cumprod']: - with raises_regex(ValueError, 'Dataset does not contain'): - getattr(data, cumfunc)(dim='bad_dim') + for cumfunc in ["cumsum", "cumprod"]: 
+ with raises_regex(ValueError, "Dataset does not contain"): + getattr(data, cumfunc)(dim="bad_dim") # ensure dimensions are correct for reduct, expected in [ - ('dim1', ['dim1', 'dim2', 'dim3', 'time']), - ('dim2', ['dim1', 'dim2', 'dim3', 'time']), - ('dim3', ['dim1', 'dim2', 'dim3', 'time']), - ('time', ['dim1', 'dim2', 'dim3']) + ("dim1", ["dim1", "dim2", "dim3", "time"]), + ("dim2", ["dim1", "dim2", "dim3", "time"]), + ("dim3", ["dim1", "dim2", "dim3", "time"]), + ("time", ["dim1", "dim2", "dim3"]), ]: actual = getattr(data, cumfunc)(dim=reduct).dims assert list(actual) == expected @@ -3948,63 +4256,62 @@ def test_reduce_cumsum_test_dims(self): def test_reduce_non_numeric(self): data1 = create_test_data(seed=44) data2 = create_test_data(seed=44) - add_vars = {'var4': ['dim1', 'dim2']} + add_vars = {"var4": ["dim1", "dim2"]} for v, dims in sorted(add_vars.items()): size = tuple(data1.dims[d] for d in dims) data = np.random.randint(0, 100, size=size).astype(np.str_) - data1[v] = (dims, data, {'foo': 'variable'}) + data1[v] = (dims, data, {"foo": "variable"}) - assert 'var4' not in data1.mean() + assert "var4" not in data1.mean() assert_equal(data1.mean(), data2.mean()) - assert_equal(data1.mean(dim='dim1'), - data2.mean(dim='dim1')) + assert_equal(data1.mean(dim="dim1"), data2.mean(dim="dim1")) def test_reduce_strings(self): - expected = Dataset({'x': 'a'}) - ds = Dataset({'x': ('y', ['a', 'b'])}) + expected = Dataset({"x": "a"}) + ds = Dataset({"x": ("y", ["a", "b"])}) actual = ds.min() assert_identical(expected, actual) - expected = Dataset({'x': 'b'}) + expected = Dataset({"x": "b"}) actual = ds.max() assert_identical(expected, actual) - expected = Dataset({'x': 0}) + expected = Dataset({"x": 0}) actual = ds.argmin() assert_identical(expected, actual) - expected = Dataset({'x': 1}) + expected = Dataset({"x": 1}) actual = ds.argmax() assert_identical(expected, actual) - expected = Dataset({'x': b'a'}) - ds = Dataset({'x': ('y', np.array(['a', 'b'], 'S1'))}) + expected = Dataset({"x": b"a"}) + ds = Dataset({"x": ("y", np.array(["a", "b"], "S1"))}) actual = ds.min() assert_identical(expected, actual) - expected = Dataset({'x': 'a'}) - ds = Dataset({'x': ('y', np.array(['a', 'b'], 'U1'))}) + expected = Dataset({"x": "a"}) + ds = Dataset({"x": ("y", np.array(["a", "b"], "U1"))}) actual = ds.min() assert_identical(expected, actual) def test_reduce_dtypes(self): # regression test for GH342 - expected = Dataset({'x': 1}) - actual = Dataset({'x': True}).sum() + expected = Dataset({"x": 1}) + actual = Dataset({"x": True}).sum() assert_identical(expected, actual) # regression test for GH505 - expected = Dataset({'x': 3}) - actual = Dataset({'x': ('y', np.array([1, 2], 'uint16'))}).sum() + expected = Dataset({"x": 3}) + actual = Dataset({"x": ("y", np.array([1, 2], "uint16"))}).sum() assert_identical(expected, actual) - expected = Dataset({'x': 1 + 1j}) - actual = Dataset({'x': ('y', [1, 1j])}).sum() + expected = Dataset({"x": 1 + 1j}) + actual = Dataset({"x": ("y", [1, 1j])}).sum() assert_identical(expected, actual) def test_reduce_keep_attrs(self): data = create_test_data() - _attrs = {'attr1': 'value1', 'attr2': 2929} + _attrs = {"attr1": "value1", "attr2": 2929} attrs = OrderedDict(_attrs) data.attrs = attrs @@ -4023,50 +4330,49 @@ def test_reduce_keep_attrs(self): def test_reduce_argmin(self): # regression test for #205 - ds = Dataset({'a': ('x', [0, 1])}) - expected = Dataset({'a': ([], 0)}) + ds = Dataset({"a": ("x", [0, 1])}) + expected = Dataset({"a": ([], 0)}) actual = ds.argmin() 
assert_identical(expected, actual) - actual = ds.argmin('x') + actual = ds.argmin("x") assert_identical(expected, actual) def test_reduce_scalars(self): - ds = Dataset({'x': ('a', [2, 2]), 'y': 2, 'z': ('b', [2])}) - expected = Dataset({'x': 0, 'y': 0, 'z': 0}) + ds = Dataset({"x": ("a", [2, 2]), "y": 2, "z": ("b", [2])}) + expected = Dataset({"x": 0, "y": 0, "z": 0}) actual = ds.var() assert_identical(expected, actual) - expected = Dataset({'x': 0, 'y': 0, 'z': ('b', [0])}) - actual = ds.var('a') + expected = Dataset({"x": 0, "y": 0, "z": ("b", [0])}) + actual = ds.var("a") assert_identical(expected, actual) def test_reduce_only_one_axis(self): - def mean_only_one_axis(x, axis): if not isinstance(axis, integer_types): - raise TypeError('non-integer axis') + raise TypeError("non-integer axis") return x.mean(axis) - ds = Dataset({'a': (['x', 'y'], [[0, 1, 2, 3, 4]])}) - expected = Dataset({'a': ('x', [2])}) - actual = ds.reduce(mean_only_one_axis, 'y') + ds = Dataset({"a": (["x", "y"], [[0, 1, 2, 3, 4]])}) + expected = Dataset({"a": ("x", [2])}) + actual = ds.reduce(mean_only_one_axis, "y") assert_identical(expected, actual) - with raises_regex(TypeError, "missing 1 required positional argument: " - "'axis'"): + with raises_regex( + TypeError, "missing 1 required positional argument: " "'axis'" + ): ds.reduce(mean_only_one_axis) - with raises_regex(TypeError, 'non-integer axis'): - ds.reduce(mean_only_one_axis, axis=['x', 'y']) + with raises_regex(TypeError, "non-integer axis"): + ds.reduce(mean_only_one_axis, axis=["x", "y"]) def test_reduce_no_axis(self): - def total_sum(x): return np.sum(x.flatten()) - ds = Dataset({'a': (['x', 'y'], [[0, 1, 2, 3, 4]])}) - expected = Dataset({'a': ((), 10)}) + ds = Dataset({"a": (["x", "y"], [[0, 1, 2, 3, 4]])}) + expected = Dataset({"a": ((), 10)}) actual = ds.reduce(total_sum) assert_identical(expected, actual) @@ -4074,25 +4380,32 @@ def total_sum(x): ds.reduce(total_sum, axis=0) with raises_regex(TypeError, "unexpected keyword argument 'axis'"): - ds.reduce(total_sum, dim='x') + ds.reduce(total_sum, dim="x") def test_reduce_keepdims(self): - ds = Dataset({'a': (['x', 'y'], [[0, 1, 2, 3, 4]])}, - coords={'y': [0, 1, 2, 3, 4], 'x': [0], - 'lat': (['x', 'y'], [[0, 1, 2, 3, 4]]), - 'c': -999.0}) + ds = Dataset( + {"a": (["x", "y"], [[0, 1, 2, 3, 4]])}, + coords={ + "y": [0, 1, 2, 3, 4], + "x": [0], + "lat": (["x", "y"], [[0, 1, 2, 3, 4]]), + "c": -999.0, + }, + ) # Shape should match behaviour of numpy reductions with keepdims=True # Coordinates involved in the reduction should be removed actual = ds.mean(keepdims=True) - expected = Dataset({'a': (['x', 'y'], np.mean(ds.a, keepdims=True))}, - coords={'c': ds.c}) + expected = Dataset( + {"a": (["x", "y"], np.mean(ds.a, keepdims=True))}, coords={"c": ds.c} + ) assert_identical(expected, actual) - actual = ds.mean('x', keepdims=True) - expected = Dataset({'a': (['x', 'y'], - np.mean(ds.a, axis=0, keepdims=True))}, - coords={'y': ds.y, 'c': ds.c}) + actual = ds.mean("x", keepdims=True) + expected = Dataset( + {"a": (["x", "y"], np.mean(ds.a, axis=0, keepdims=True))}, + coords={"y": ds.y, "c": ds.c}, + ) assert_identical(expected, actual) def test_quantile(self): @@ -4100,44 +4413,43 @@ def test_quantile(self): ds = create_test_data(seed=123) for q in [0.25, [0.50], [0.25, 0.75]]: - for dim in [None, 'dim1', ['dim1']]: + for dim in [None, "dim1", ["dim1"]]: ds_quantile = ds.quantile(q, dim=dim) - assert 'quantile' in ds_quantile + assert "quantile" in ds_quantile for var, dar in ds.data_vars.items(): assert 
var in ds_quantile - assert_identical( - ds_quantile[var], dar.quantile(q, dim=dim)) - dim = ['dim1', 'dim2'] + assert_identical(ds_quantile[var], dar.quantile(q, dim=dim)) + dim = ["dim1", "dim2"] ds_quantile = ds.quantile(q, dim=dim) - assert 'dim3' in ds_quantile.dims + assert "dim3" in ds_quantile.dims assert all(d not in ds_quantile.dims for d in dim) @requires_bottleneck def test_rank(self): ds = create_test_data(seed=1234) # only ds.var3 depends on dim3 - z = ds.rank('dim3') - assert ['var3'] == list(z.data_vars) + z = ds.rank("dim3") + assert ["var3"] == list(z.data_vars) # same as dataarray version x = z.var3 - y = ds.var3.rank('dim3') + y = ds.var3.rank("dim3") assert_equal(x, y) # coordinates stick assert list(z.coords) == list(ds.coords) assert list(x.coords) == list(y.coords) # invalid dim - with raises_regex(ValueError, 'does not contain'): - x.rank('invalid_dim') + with raises_regex(ValueError, "does not contain"): + x.rank("invalid_dim") def test_count(self): - ds = Dataset({'x': ('a', [np.nan, 1]), 'y': 0, 'z': np.nan}) - expected = Dataset({'x': 1, 'y': 1, 'z': 0}) + ds = Dataset({"x": ("a", [np.nan, 1]), "y": 0, "z": np.nan}) + expected = Dataset({"x": 1, "y": 1, "z": 0}) actual = ds.count() assert_identical(expected, actual) def test_apply(self): data = create_test_data() - data.attrs['foo'] = 'bar' + data.attrs["foo"] = "bar" assert_identical(data.apply(np.mean), data.mean()) @@ -4145,28 +4457,29 @@ def test_apply(self): actual = data.apply(lambda x: x.mean(keep_attrs=True), keep_attrs=True) assert_identical(expected, actual) - assert_identical(data.apply(lambda x: x, keep_attrs=True), - data.drop('time')) + assert_identical(data.apply(lambda x: x, keep_attrs=True), data.drop("time")) def scale(x, multiple=1): return multiple * x actual = data.apply(scale, multiple=2) - assert_equal(actual['var1'], 2 * data['var1']) - assert_identical(actual['numbers'], data['numbers']) + assert_equal(actual["var1"], 2 * data["var1"]) + assert_identical(actual["numbers"], data["numbers"]) actual = data.apply(np.asarray) - expected = data.drop('time') # time is not used on a data var + expected = data.drop("time") # time is not used on a data var assert_equal(expected, actual) def make_example_math_dataset(self): variables = OrderedDict( - [('bar', ('x', np.arange(100, 400, 100))), - ('foo', (('x', 'y'), 1.0 * np.arange(12).reshape(3, 4)))]) - coords = {'abc': ('x', ['a', 'b', 'c']), - 'y': 10 * np.arange(4)} + [ + ("bar", ("x", np.arange(100, 400, 100))), + ("foo", (("x", "y"), 1.0 * np.arange(12).reshape(3, 4))), + ] + ) + coords = {"abc": ("x", ["a", "b", "c"]), "y": 10 * np.arange(4)} ds = Dataset(variables, coords) - ds['foo'][0, 0] = np.nan + ds["foo"][0, 0] = np.nan return ds def test_dataset_number_math(self): @@ -4188,9 +4501,11 @@ def test_unary_ops(self): assert_identical(ds.apply(abs), abs(ds)) assert_identical(ds.apply(lambda x: x + 4), ds + 4) - for func in [lambda x: x.isnull(), - lambda x: x.round(), - lambda x: x.astype(int)]: + for func in [ + lambda x: x.isnull(), + lambda x: x.round(), + lambda x: x.astype(int), + ]: assert_identical(ds.apply(func), func(ds)) assert_identical(ds.isnull(), ~ds.notnull()) @@ -4204,30 +4519,30 @@ def test_unary_ops(self): def test_dataset_array_math(self): ds = self.make_example_math_dataset() - expected = ds.apply(lambda x: x - ds['foo']) - assert_identical(expected, ds - ds['foo']) - assert_identical(expected, -ds['foo'] + ds) - assert_identical(expected, ds - ds['foo'].variable) - assert_identical(expected, -ds['foo'].variable + ds) 
+ expected = ds.apply(lambda x: x - ds["foo"]) + assert_identical(expected, ds - ds["foo"]) + assert_identical(expected, -ds["foo"] + ds) + assert_identical(expected, ds - ds["foo"].variable) + assert_identical(expected, -ds["foo"].variable + ds) actual = ds.copy(deep=True) - actual -= ds['foo'] + actual -= ds["foo"] assert_identical(expected, actual) - expected = ds.apply(lambda x: x + ds['bar']) - assert_identical(expected, ds + ds['bar']) + expected = ds.apply(lambda x: x + ds["bar"]) + assert_identical(expected, ds + ds["bar"]) actual = ds.copy(deep=True) - actual += ds['bar'] + actual += ds["bar"] assert_identical(expected, actual) - expected = Dataset({'bar': ds['bar'] + np.arange(3)}) - assert_identical(expected, ds[['bar']] + np.arange(3)) - assert_identical(expected, np.arange(3) + ds[['bar']]) + expected = Dataset({"bar": ds["bar"] + np.arange(3)}) + assert_identical(expected, ds[["bar"]] + np.arange(3)) + assert_identical(expected, np.arange(3) + ds[["bar"]]) def test_dataset_dataset_math(self): ds = self.make_example_math_dataset() assert_identical(ds, ds + 0 * ds) - assert_identical(ds, ds + {'foo': 0, 'bar': 0}) + assert_identical(ds, ds + {"foo": 0, "bar": 0}) expected = ds.apply(lambda x: 2 * x) assert_identical(expected, 2 * ds) @@ -4256,17 +4571,17 @@ def test_dataset_math_auto_align(self): assert_identical(expected, actual) actual = ds.isel(y=slice(1)) + ds.isel(y=slice(1, None)) - expected = 2 * ds.drop(ds.y, dim='y') + expected = 2 * ds.drop(ds.y, dim="y") assert_equal(actual, expected) - actual = ds + ds[['bar']] - expected = (2 * ds[['bar']]).merge(ds.coords) + actual = ds + ds[["bar"]] + expected = (2 * ds[["bar"]]).merge(ds.coords) assert_identical(expected, actual) assert_identical(ds + Dataset(), ds.coords.to_dataset()) assert_identical(Dataset() + Dataset(), Dataset()) - ds2 = Dataset(coords={'bar': 42}) + ds2 = Dataset(coords={"bar": 42}) assert_identical(ds + ds2, ds.coords.merge(ds2)) # maybe unary arithmetic with empty datasets should raise instead? @@ -4282,37 +4597,47 @@ def test_dataset_math_errors(self): ds = self.make_example_math_dataset() with pytest.raises(TypeError): - ds['foo'] += ds + ds["foo"] += ds with pytest.raises(TypeError): - ds['foo'].variable += ds - with raises_regex(ValueError, 'must have the same'): - ds += ds[['bar']] + ds["foo"].variable += ds + with raises_regex(ValueError, "must have the same"): + ds += ds[["bar"]] # verify we can rollback in-place operations if something goes wrong # nb. 
inplace datetime64 math actually will work with an integer array # but not floats thanks to numpy's inconsistent handling - other = DataArray(np.datetime64('2000-01-01'), coords={'c': 2}) + other = DataArray(np.datetime64("2000-01-01"), coords={"c": 2}) actual = ds.copy(deep=True) with pytest.raises(TypeError): actual += other assert_identical(actual, ds) def test_dataset_transpose(self): - ds = Dataset({'a': (('x', 'y'), np.random.randn(3, 4)), - 'b': (('y', 'x'), np.random.randn(4, 3))}, - coords={'x': range(3), 'y': range(4), - 'xy': (('x', 'y'), np.random.randn(3, 4))}) + ds = Dataset( + { + "a": (("x", "y"), np.random.randn(3, 4)), + "b": (("y", "x"), np.random.randn(4, 3)), + }, + coords={ + "x": range(3), + "y": range(4), + "xy": (("x", "y"), np.random.randn(3, 4)), + }, + ) actual = ds.transpose() - expected = Dataset({'a': (('y', 'x'), ds.a.values.T), - 'b': (('x', 'y'), ds.b.values.T)}, - coords={'x': ds.x.values, 'y': ds.y.values, - 'xy': (('y', 'x'), ds.xy.values.T)}) + expected = Dataset( + {"a": (("y", "x"), ds.a.values.T), "b": (("x", "y"), ds.b.values.T)}, + coords={ + "x": ds.x.values, + "y": ds.y.values, + "xy": (("y", "x"), ds.xy.values.T), + }, + ) assert_identical(expected, actual) - actual = ds.transpose('x', 'y') - expected = ds.apply( - lambda x: x.transpose('x', 'y', transpose_coords=True)) + actual = ds.transpose("x", "y") + expected = ds.apply(lambda x: x.transpose("x", "y", transpose_coords=True)) assert_identical(expected, actual) ds = create_test_data() @@ -4320,320 +4645,346 @@ def test_dataset_transpose(self): for k in ds.variables: assert actual[k].dims[::-1] == ds[k].dims - new_order = ('dim2', 'dim3', 'dim1', 'time') + new_order = ("dim2", "dim3", "dim1", "time") actual = ds.transpose(*new_order) for k in ds.variables: expected_dims = tuple(d for d in new_order if d in ds[k].dims) assert actual[k].dims == expected_dims - with raises_regex(ValueError, 'arguments to transpose'): - ds.transpose('dim1', 'dim2', 'dim3') - with raises_regex(ValueError, 'arguments to transpose'): - ds.transpose('dim1', 'dim2', 'dim3', 'time', 'extra_dim') + with raises_regex(ValueError, "arguments to transpose"): + ds.transpose("dim1", "dim2", "dim3") + with raises_regex(ValueError, "arguments to transpose"): + ds.transpose("dim1", "dim2", "dim3", "time", "extra_dim") - assert 'T' not in dir(ds) + assert "T" not in dir(ds) def test_dataset_retains_period_index_on_transpose(self): ds = create_test_data() - ds['time'] = pd.period_range('2000-01-01', periods=20) + ds["time"] = pd.period_range("2000-01-01", periods=20) transposed = ds.transpose() assert isinstance(transposed.time.to_index(), pd.PeriodIndex) def test_dataset_diff_n1_simple(self): - ds = Dataset({'foo': ('x', [5, 5, 6, 6])}) - actual = ds.diff('x') - expected = Dataset({'foo': ('x', [0, 1, 0])}) + ds = Dataset({"foo": ("x", [5, 5, 6, 6])}) + actual = ds.diff("x") + expected = Dataset({"foo": ("x", [0, 1, 0])}) assert_equal(expected, actual) def test_dataset_diff_n1_label(self): - ds = Dataset({'foo': ('x', [5, 5, 6, 6])}, {'x': [0, 1, 2, 3]}) - actual = ds.diff('x', label='lower') - expected = Dataset({'foo': ('x', [0, 1, 0])}, {'x': [0, 1, 2]}) + ds = Dataset({"foo": ("x", [5, 5, 6, 6])}, {"x": [0, 1, 2, 3]}) + actual = ds.diff("x", label="lower") + expected = Dataset({"foo": ("x", [0, 1, 0])}, {"x": [0, 1, 2]}) assert_equal(expected, actual) - actual = ds.diff('x', label='upper') - expected = Dataset({'foo': ('x', [0, 1, 0])}, {'x': [1, 2, 3]}) + actual = ds.diff("x", label="upper") + expected = Dataset({"foo": 
("x", [0, 1, 0])}, {"x": [1, 2, 3]}) assert_equal(expected, actual) def test_dataset_diff_n1(self): ds = create_test_data(seed=1) - actual = ds.diff('dim2') + actual = ds.diff("dim2") expected = dict() - expected['var1'] = DataArray(np.diff(ds['var1'].values, axis=1), - {'dim2': ds['dim2'].values[1:]}, - ['dim1', 'dim2']) - expected['var2'] = DataArray(np.diff(ds['var2'].values, axis=1), - {'dim2': ds['dim2'].values[1:]}, - ['dim1', 'dim2']) - expected['var3'] = ds['var3'] - expected = Dataset(expected, coords={'time': ds['time'].values}) - expected.coords['numbers'] = ('dim3', ds['numbers'].values) + expected["var1"] = DataArray( + np.diff(ds["var1"].values, axis=1), + {"dim2": ds["dim2"].values[1:]}, + ["dim1", "dim2"], + ) + expected["var2"] = DataArray( + np.diff(ds["var2"].values, axis=1), + {"dim2": ds["dim2"].values[1:]}, + ["dim1", "dim2"], + ) + expected["var3"] = ds["var3"] + expected = Dataset(expected, coords={"time": ds["time"].values}) + expected.coords["numbers"] = ("dim3", ds["numbers"].values) assert_equal(expected, actual) def test_dataset_diff_n2(self): ds = create_test_data(seed=1) - actual = ds.diff('dim2', n=2) + actual = ds.diff("dim2", n=2) expected = dict() - expected['var1'] = DataArray(np.diff(ds['var1'].values, axis=1, n=2), - {'dim2': ds['dim2'].values[2:]}, - ['dim1', 'dim2']) - expected['var2'] = DataArray(np.diff(ds['var2'].values, axis=1, n=2), - {'dim2': ds['dim2'].values[2:]}, - ['dim1', 'dim2']) - expected['var3'] = ds['var3'] - expected = Dataset(expected, coords={'time': ds['time'].values}) - expected.coords['numbers'] = ('dim3', ds['numbers'].values) + expected["var1"] = DataArray( + np.diff(ds["var1"].values, axis=1, n=2), + {"dim2": ds["dim2"].values[2:]}, + ["dim1", "dim2"], + ) + expected["var2"] = DataArray( + np.diff(ds["var2"].values, axis=1, n=2), + {"dim2": ds["dim2"].values[2:]}, + ["dim1", "dim2"], + ) + expected["var3"] = ds["var3"] + expected = Dataset(expected, coords={"time": ds["time"].values}) + expected.coords["numbers"] = ("dim3", ds["numbers"].values) assert_equal(expected, actual) def test_dataset_diff_exception_n_neg(self): ds = create_test_data(seed=1) - with raises_regex(ValueError, 'must be non-negative'): - ds.diff('dim2', n=-1) + with raises_regex(ValueError, "must be non-negative"): + ds.diff("dim2", n=-1) def test_dataset_diff_exception_label_str(self): ds = create_test_data(seed=1) - with raises_regex(ValueError, '\'label\' argument has to'): - ds.diff('dim2', label='raise_me') + with raises_regex(ValueError, "'label' argument has to"): + ds.diff("dim2", label="raise_me") - @pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0]) + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) def test_shift(self, fill_value): - coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]} - attrs = {'meta': 'data'} - ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs) + coords = {"bar": ("x", list("abc")), "x": [-4, 3, 2]} + attrs = {"meta": "data"} + ds = Dataset({"foo": ("x", [1, 2, 3])}, coords, attrs) actual = ds.shift(x=1, fill_value=fill_value) if fill_value == dtypes.NA: # if we supply the default, we expect the missing value for a # float array fill_value = np.nan - expected = Dataset({'foo': ('x', [fill_value, 1, 2])}, coords, attrs) + expected = Dataset({"foo": ("x", [fill_value, 1, 2])}, coords, attrs) assert_identical(expected, actual) - with raises_regex(ValueError, 'dimensions'): + with raises_regex(ValueError, "dimensions"): ds.shift(foo=123) def test_roll_coords(self): - coords = {'bar': ('x', list('abc')), 
'x': [-4, 3, 2]} - attrs = {'meta': 'data'} - ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs) + coords = {"bar": ("x", list("abc")), "x": [-4, 3, 2]} + attrs = {"meta": "data"} + ds = Dataset({"foo": ("x", [1, 2, 3])}, coords, attrs) actual = ds.roll(x=1, roll_coords=True) - ex_coords = {'bar': ('x', list('cab')), 'x': [2, -4, 3]} - expected = Dataset({'foo': ('x', [3, 1, 2])}, ex_coords, attrs) + ex_coords = {"bar": ("x", list("cab")), "x": [2, -4, 3]} + expected = Dataset({"foo": ("x", [3, 1, 2])}, ex_coords, attrs) assert_identical(expected, actual) - with raises_regex(ValueError, 'dimensions'): + with raises_regex(ValueError, "dimensions"): ds.roll(foo=123, roll_coords=True) def test_roll_no_coords(self): - coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]} - attrs = {'meta': 'data'} - ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs) + coords = {"bar": ("x", list("abc")), "x": [-4, 3, 2]} + attrs = {"meta": "data"} + ds = Dataset({"foo": ("x", [1, 2, 3])}, coords, attrs) actual = ds.roll(x=1, roll_coords=False) - expected = Dataset({'foo': ('x', [3, 1, 2])}, coords, attrs) + expected = Dataset({"foo": ("x", [3, 1, 2])}, coords, attrs) assert_identical(expected, actual) - with raises_regex(ValueError, 'dimensions'): + with raises_regex(ValueError, "dimensions"): ds.roll(abc=321, roll_coords=False) def test_roll_coords_none(self): - coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]} - attrs = {'meta': 'data'} - ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs) + coords = {"bar": ("x", list("abc")), "x": [-4, 3, 2]} + attrs = {"meta": "data"} + ds = Dataset({"foo": ("x", [1, 2, 3])}, coords, attrs) with pytest.warns(FutureWarning): actual = ds.roll(x=1, roll_coords=None) - ex_coords = {'bar': ('x', list('cab')), 'x': [2, -4, 3]} - expected = Dataset({'foo': ('x', [3, 1, 2])}, ex_coords, attrs) + ex_coords = {"bar": ("x", list("cab")), "x": [2, -4, 3]} + expected = Dataset({"foo": ("x", [3, 1, 2])}, ex_coords, attrs) assert_identical(expected, actual) def test_roll_multidim(self): # regression test for 2445 arr = xr.DataArray( - [[1, 2, 3], [4, 5, 6]], coords={'x': range(3), 'y': range(2)}, - dims=('y', 'x')) + [[1, 2, 3], [4, 5, 6]], + coords={"x": range(3), "y": range(2)}, + dims=("y", "x"), + ) actual = arr.roll(x=1, roll_coords=True) - expected = xr.DataArray([[3, 1, 2], [6, 4, 5]], - coords=[('y', [0, 1]), ('x', [2, 0, 1])]) + expected = xr.DataArray( + [[3, 1, 2], [6, 4, 5]], coords=[("y", [0, 1]), ("x", [2, 0, 1])] + ) assert_identical(expected, actual) def test_real_and_imag(self): - attrs = {'foo': 'bar'} - ds = Dataset({'x': ((), 1 + 2j, attrs)}, attrs=attrs) + attrs = {"foo": "bar"} + ds = Dataset({"x": ((), 1 + 2j, attrs)}, attrs=attrs) - expected_re = Dataset({'x': ((), 1, attrs)}, attrs=attrs) + expected_re = Dataset({"x": ((), 1, attrs)}, attrs=attrs) assert_identical(ds.real, expected_re) - expected_im = Dataset({'x': ((), 2, attrs)}, attrs=attrs) + expected_im = Dataset({"x": ((), 2, attrs)}, attrs=attrs) assert_identical(ds.imag, expected_im) def test_setattr_raises(self): - ds = Dataset({}, coords={'scalar': 1}, attrs={'foo': 'bar'}) - with raises_regex(AttributeError, 'cannot set attr'): + ds = Dataset({}, coords={"scalar": 1}, attrs={"foo": "bar"}) + with raises_regex(AttributeError, "cannot set attr"): ds.scalar = 2 - with raises_regex(AttributeError, 'cannot set attr'): + with raises_regex(AttributeError, "cannot set attr"): ds.foo = 2 - with raises_regex(AttributeError, 'cannot set attr'): + with raises_regex(AttributeError, "cannot set 
attr"): ds.other = 2 def test_filter_by_attrs(self): - precip = dict(standard_name='convective_precipitation_flux') - temp0 = dict(standard_name='air_potential_temperature', height='0 m') - temp10 = dict(standard_name='air_potential_temperature', height='10 m') - ds = Dataset({'temperature_0': (['t'], [0], temp0), - 'temperature_10': (['t'], [0], temp10), - 'precipitation': (['t'], [0], precip)}, - coords={'time': (['t'], [0], dict(axis='T'))}) + precip = dict(standard_name="convective_precipitation_flux") + temp0 = dict(standard_name="air_potential_temperature", height="0 m") + temp10 = dict(standard_name="air_potential_temperature", height="10 m") + ds = Dataset( + { + "temperature_0": (["t"], [0], temp0), + "temperature_10": (["t"], [0], temp10), + "precipitation": (["t"], [0], precip), + }, + coords={"time": (["t"], [0], dict(axis="T"))}, + ) # Test return empty Dataset. - ds.filter_by_attrs(standard_name='invalid_standard_name') - new_ds = ds.filter_by_attrs(standard_name='invalid_standard_name') + ds.filter_by_attrs(standard_name="invalid_standard_name") + new_ds = ds.filter_by_attrs(standard_name="invalid_standard_name") assert not bool(new_ds.data_vars) # Test return one DataArray. - new_ds = ds.filter_by_attrs( - standard_name='convective_precipitation_flux') - assert (new_ds['precipitation'].standard_name == - 'convective_precipitation_flux') + new_ds = ds.filter_by_attrs(standard_name="convective_precipitation_flux") + assert new_ds["precipitation"].standard_name == "convective_precipitation_flux" - assert_equal(new_ds['precipitation'], ds['precipitation']) + assert_equal(new_ds["precipitation"], ds["precipitation"]) # Test return more than one DataArray. - new_ds = ds.filter_by_attrs(standard_name='air_potential_temperature') + new_ds = ds.filter_by_attrs(standard_name="air_potential_temperature") assert len(new_ds.data_vars) == 2 for var in new_ds.data_vars: - assert new_ds[var].standard_name == 'air_potential_temperature' + assert new_ds[var].standard_name == "air_potential_temperature" # Test callable. 
new_ds = ds.filter_by_attrs(height=lambda v: v is not None) assert len(new_ds.data_vars) == 2 for var in new_ds.data_vars: - assert new_ds[var].standard_name == 'air_potential_temperature' + assert new_ds[var].standard_name == "air_potential_temperature" - new_ds = ds.filter_by_attrs(height='10 m') + new_ds = ds.filter_by_attrs(height="10 m") assert len(new_ds.data_vars) == 1 for var in new_ds.data_vars: - assert new_ds[var].height == '10 m' + assert new_ds[var].height == "10 m" # Test return empty Dataset due to conflicting filters new_ds = ds.filter_by_attrs( - standard_name='convective_precipitation_flux', - height='0 m') + standard_name="convective_precipitation_flux", height="0 m" + ) assert not bool(new_ds.data_vars) # Test return one DataArray with two filter conditions new_ds = ds.filter_by_attrs( - standard_name='air_potential_temperature', - height='0 m') + standard_name="air_potential_temperature", height="0 m" + ) for var in new_ds.data_vars: - assert new_ds[var].standard_name == 'air_potential_temperature' - assert new_ds[var].height == '0 m' - assert new_ds[var].height != '10 m' + assert new_ds[var].standard_name == "air_potential_temperature" + assert new_ds[var].height == "0 m" + assert new_ds[var].height != "10 m" # Test return empty Dataset due to conflicting callables - new_ds = ds.filter_by_attrs(standard_name=lambda v: False, - height=lambda v: True) + new_ds = ds.filter_by_attrs( + standard_name=lambda v: False, height=lambda v: True + ) assert not bool(new_ds.data_vars) def test_binary_op_join_setting(self): # arithmetic_join applies to data array coordinates - missing_2 = xr.Dataset({'x': [0, 1]}) - missing_0 = xr.Dataset({'x': [1, 2]}) - with xr.set_options(arithmetic_join='outer'): + missing_2 = xr.Dataset({"x": [0, 1]}) + missing_0 = xr.Dataset({"x": [1, 2]}) + with xr.set_options(arithmetic_join="outer"): actual = missing_2 + missing_0 - expected = xr.Dataset({'x': [0, 1, 2]}) + expected = xr.Dataset({"x": [0, 1, 2]}) assert_equal(actual, expected) # arithmetic join also applies to data_vars - ds1 = xr.Dataset({'foo': 1, 'bar': 2}) - ds2 = xr.Dataset({'bar': 2, 'baz': 3}) - expected = xr.Dataset({'bar': 4}) # default is inner joining + ds1 = xr.Dataset({"foo": 1, "bar": 2}) + ds2 = xr.Dataset({"bar": 2, "baz": 3}) + expected = xr.Dataset({"bar": 4}) # default is inner joining actual = ds1 + ds2 assert_equal(actual, expected) - with xr.set_options(arithmetic_join='outer'): - expected = xr.Dataset({'foo': np.nan, 'bar': 4, 'baz': np.nan}) + with xr.set_options(arithmetic_join="outer"): + expected = xr.Dataset({"foo": np.nan, "bar": 4, "baz": np.nan}) actual = ds1 + ds2 assert_equal(actual, expected) - with xr.set_options(arithmetic_join='left'): - expected = xr.Dataset({'foo': np.nan, 'bar': 4}) + with xr.set_options(arithmetic_join="left"): + expected = xr.Dataset({"foo": np.nan, "bar": 4}) actual = ds1 + ds2 assert_equal(actual, expected) - with xr.set_options(arithmetic_join='right'): - expected = xr.Dataset({'bar': 4, 'baz': np.nan}) + with xr.set_options(arithmetic_join="right"): + expected = xr.Dataset({"bar": 4, "baz": np.nan}) actual = ds1 + ds2 assert_equal(actual, expected) def test_full_like(self): # For more thorough tests, see test_variable.py # Note: testing data_vars with mismatched dtypes - ds = Dataset({ - 'd1': DataArray([1, 2, 3], dims=['x'], coords={'x': [10, 20, 30]}), - 'd2': DataArray([1.1, 2.2, 3.3], dims=['y']) - }, attrs={'foo': 'bar'}) + ds = Dataset( + { + "d1": DataArray([1, 2, 3], dims=["x"], coords={"x": [10, 20, 30]}), + "d2": 
DataArray([1.1, 2.2, 3.3], dims=["y"]), + }, + attrs={"foo": "bar"}, + ) actual = full_like(ds, 2) expect = ds.copy(deep=True) - expect['d1'].values = [2, 2, 2] - expect['d2'].values = [2.0, 2.0, 2.0] - assert expect['d1'].dtype == int - assert expect['d2'].dtype == float + expect["d1"].values = [2, 2, 2] + expect["d2"].values = [2.0, 2.0, 2.0] + assert expect["d1"].dtype == int + assert expect["d2"].dtype == float assert_identical(expect, actual) # override dtype actual = full_like(ds, fill_value=True, dtype=bool) expect = ds.copy(deep=True) - expect['d1'].values = [True, True, True] - expect['d2'].values = [True, True, True] - assert expect['d1'].dtype == bool - assert expect['d2'].dtype == bool + expect["d1"].values = [True, True, True] + expect["d2"].values = [True, True, True] + assert expect["d1"].dtype == bool + assert expect["d2"].dtype == bool assert_identical(expect, actual) def test_combine_first(self): - dsx0 = DataArray([0, 0], [('x', ['a', 'b'])]).to_dataset(name='dsx0') - dsx1 = DataArray([1, 1], [('x', ['b', 'c'])]).to_dataset(name='dsx1') + dsx0 = DataArray([0, 0], [("x", ["a", "b"])]).to_dataset(name="dsx0") + dsx1 = DataArray([1, 1], [("x", ["b", "c"])]).to_dataset(name="dsx1") actual = dsx0.combine_first(dsx1) - expected = Dataset({'dsx0': ('x', [0, 0, np.nan]), - 'dsx1': ('x', [np.nan, 1, 1])}, - coords={'x': ['a', 'b', 'c']}) + expected = Dataset( + {"dsx0": ("x", [0, 0, np.nan]), "dsx1": ("x", [np.nan, 1, 1])}, + coords={"x": ["a", "b", "c"]}, + ) assert_equal(actual, expected) assert_equal(actual, xr.merge([dsx0, dsx1])) # works just like xr.merge([self, other]) - dsy2 = DataArray([2, 2, 2], - [('x', ['b', 'c', 'd'])]).to_dataset(name='dsy2') + dsy2 = DataArray([2, 2, 2], [("x", ["b", "c", "d"])]).to_dataset(name="dsy2") actual = dsx0.combine_first(dsy2) expected = xr.merge([dsy2, dsx0]) assert_equal(actual, expected) def test_sortby(self): - ds = Dataset({'A': DataArray([[1, 2], [3, 4], [5, 6]], - [('x', ['c', 'b', 'a']), - ('y', [1, 0])]), - 'B': DataArray([[5, 6], [7, 8], [9, 10]], - dims=['x', 'y'])}) - - sorted1d = Dataset({'A': DataArray([[5, 6], [3, 4], [1, 2]], - [('x', ['a', 'b', 'c']), - ('y', [1, 0])]), - 'B': DataArray([[9, 10], [7, 8], [5, 6]], - dims=['x', 'y'])}) - - sorted2d = Dataset({'A': DataArray([[6, 5], [4, 3], [2, 1]], - [('x', ['a', 'b', 'c']), - ('y', [0, 1])]), - 'B': DataArray([[10, 9], [8, 7], [6, 5]], - dims=['x', 'y'])}) + ds = Dataset( + { + "A": DataArray( + [[1, 2], [3, 4], [5, 6]], [("x", ["c", "b", "a"]), ("y", [1, 0])] + ), + "B": DataArray([[5, 6], [7, 8], [9, 10]], dims=["x", "y"]), + } + ) + + sorted1d = Dataset( + { + "A": DataArray( + [[5, 6], [3, 4], [1, 2]], [("x", ["a", "b", "c"]), ("y", [1, 0])] + ), + "B": DataArray([[9, 10], [7, 8], [5, 6]], dims=["x", "y"]), + } + ) + + sorted2d = Dataset( + { + "A": DataArray( + [[6, 5], [4, 3], [2, 1]], [("x", ["a", "b", "c"]), ("y", [0, 1])] + ), + "B": DataArray([[10, 9], [8, 7], [6, 5]], dims=["x", "y"]), + } + ) expected = sorted1d - dax = DataArray([100, 99, 98], [('x', ['c', 'b', 'a'])]) + dax = DataArray([100, 99, 98], [("x", ["c", "b", "a"])]) actual = ds.sortby(dax) assert_equal(actual, expected) @@ -4642,121 +4993,138 @@ def test_sortby(self): assert_equal(actual, ds) # test alignment (fills in nan for 'c') - dax_short = DataArray([98, 97], [('x', ['b', 'a'])]) + dax_short = DataArray([98, 97], [("x", ["b", "a"])]) actual = ds.sortby(dax_short) assert_equal(actual, expected) # test 1-D lexsort # dax0 is sorted first to give indices of [1, 2, 0] # and then dax1 would be used 
to move index 2 ahead of 1 - dax0 = DataArray([100, 95, 95], [('x', ['c', 'b', 'a'])]) - dax1 = DataArray([0, 1, 0], [('x', ['c', 'b', 'a'])]) + dax0 = DataArray([100, 95, 95], [("x", ["c", "b", "a"])]) + dax1 = DataArray([0, 1, 0], [("x", ["c", "b", "a"])]) actual = ds.sortby([dax0, dax1]) # lexsort underneath gives [2, 1, 0] assert_equal(actual, expected) expected = sorted2d # test multi-dim sort by 1D dataarray values - day = DataArray([90, 80], [('y', [1, 0])]) + day = DataArray([90, 80], [("y", [1, 0])]) actual = ds.sortby([day, dax]) assert_equal(actual, expected) # test exception-raising with pytest.raises(KeyError) as excinfo: - actual = ds.sortby('z') + actual = ds.sortby("z") with pytest.raises(ValueError) as excinfo: - actual = ds.sortby(ds['A']) + actual = ds.sortby(ds["A"]) assert "DataArray is not 1-D" in str(excinfo.value) expected = sorted1d - actual = ds.sortby('x') + actual = ds.sortby("x") assert_equal(actual, expected) # test pandas.MultiIndex - indices = (('b', 1), ('b', 0), ('a', 1), ('a', 0)) - midx = pd.MultiIndex.from_tuples(indices, names=['one', 'two']) - ds_midx = Dataset({'A': DataArray([[1, 2], [3, 4], [5, 6], [7, 8]], - [('x', midx), ('y', [1, 0])]), - 'B': DataArray([[5, 6], [7, 8], [9, 10], [11, 12]], - dims=['x', 'y'])}) - actual = ds_midx.sortby('x') - midx_reversed = pd.MultiIndex.from_tuples(tuple(reversed(indices)), - names=['one', 'two']) - expected = Dataset({'A': DataArray([[7, 8], [5, 6], [3, 4], [1, 2]], - [('x', midx_reversed), - ('y', [1, 0])]), - 'B': DataArray([[11, 12], [9, 10], [7, 8], [5, 6]], - dims=['x', 'y'])}) + indices = (("b", 1), ("b", 0), ("a", 1), ("a", 0)) + midx = pd.MultiIndex.from_tuples(indices, names=["one", "two"]) + ds_midx = Dataset( + { + "A": DataArray( + [[1, 2], [3, 4], [5, 6], [7, 8]], [("x", midx), ("y", [1, 0])] + ), + "B": DataArray([[5, 6], [7, 8], [9, 10], [11, 12]], dims=["x", "y"]), + } + ) + actual = ds_midx.sortby("x") + midx_reversed = pd.MultiIndex.from_tuples( + tuple(reversed(indices)), names=["one", "two"] + ) + expected = Dataset( + { + "A": DataArray( + [[7, 8], [5, 6], [3, 4], [1, 2]], + [("x", midx_reversed), ("y", [1, 0])], + ), + "B": DataArray([[11, 12], [9, 10], [7, 8], [5, 6]], dims=["x", "y"]), + } + ) assert_equal(actual, expected) # multi-dim sort by coordinate objects expected = sorted2d - actual = ds.sortby(['x', 'y']) + actual = ds.sortby(["x", "y"]) assert_equal(actual, expected) # test descending order sort - actual = ds.sortby(['x', 'y'], ascending=False) + actual = ds.sortby(["x", "y"], ascending=False) assert_equal(actual, ds) def test_attribute_access(self): ds = create_test_data(seed=1) - for key in ['var1', 'var2', 'var3', 'time', 'dim1', - 'dim2', 'dim3', 'numbers']: + for key in ["var1", "var2", "var3", "time", "dim1", "dim2", "dim3", "numbers"]: assert_equal(ds[key], getattr(ds, key)) assert key in dir(ds) - for key in ['dim3', 'dim1', 'numbers']: - assert_equal(ds['var3'][key], getattr(ds.var3, key)) - assert key in dir(ds['var3']) + for key in ["dim3", "dim1", "numbers"]: + assert_equal(ds["var3"][key], getattr(ds.var3, key)) + assert key in dir(ds["var3"]) # attrs - assert ds['var3'].attrs['foo'] == ds.var3.foo - assert 'foo' in dir(ds['var3']) + assert ds["var3"].attrs["foo"] == ds.var3.foo + assert "foo" in dir(ds["var3"]) def test_ipython_key_completion(self): ds = create_test_data(seed=1) actual = ds._ipython_key_completions_() - expected = ['var1', 'var2', 'var3', 'time', 'dim1', - 'dim2', 'dim3', 'numbers'] + expected = ["var1", "var2", "var3", "time", "dim1", "dim2", 
"dim3", "numbers"] for item in actual: ds[item] # should not raise assert sorted(actual) == sorted(expected) # for dataarray - actual = ds['var3']._ipython_key_completions_() - expected = ['dim3', 'dim1', 'numbers'] + actual = ds["var3"]._ipython_key_completions_() + expected = ["dim3", "dim1", "numbers"] for item in actual: - ds['var3'][item] # should not raise + ds["var3"][item] # should not raise assert sorted(actual) == sorted(expected) # MultiIndex - ds_midx = ds.stack(dim12=['dim1', 'dim2']) + ds_midx = ds.stack(dim12=["dim1", "dim2"]) actual = ds_midx._ipython_key_completions_() - expected = ['var1', 'var2', 'var3', 'time', 'dim1', - 'dim2', 'dim3', 'numbers', 'dim12'] + expected = [ + "var1", + "var2", + "var3", + "time", + "dim1", + "dim2", + "dim3", + "numbers", + "dim12", + ] for item in actual: ds_midx[item] # should not raise assert sorted(actual) == sorted(expected) # coords actual = ds.coords._ipython_key_completions_() - expected = ['time', 'dim1', 'dim2', 'dim3', 'numbers'] + expected = ["time", "dim1", "dim2", "dim3", "numbers"] for item in actual: ds.coords[item] # should not raise assert sorted(actual) == sorted(expected) - actual = ds['var3'].coords._ipython_key_completions_() - expected = ['dim1', 'dim3', 'numbers'] + actual = ds["var3"].coords._ipython_key_completions_() + expected = ["dim1", "dim3", "numbers"] for item in actual: - ds['var3'].coords[item] # should not raise + ds["var3"].coords[item] # should not raise assert sorted(actual) == sorted(expected) # data_vars actual = ds.data_vars._ipython_key_completions_() - expected = ['var1', 'var2', 'var3', 'dim1'] + expected = ["var1", "var2", "var3", "dim1"] for item in actual: ds.data_vars[item] # should not raise assert sorted(actual) == sorted(expected) + # Py.test tests @@ -4765,88 +5133,84 @@ def data_set(request): return create_test_data(request.param) -@pytest.mark.parametrize('test_elements', ( - [1, 2], - np.array([1, 2]), - DataArray([1, 2]), -)) +@pytest.mark.parametrize("test_elements", ([1, 2], np.array([1, 2]), DataArray([1, 2]))) def test_isin(test_elements): expected = Dataset( data_vars={ - 'var1': (('dim1',), [0, 1]), - 'var2': (('dim1',), [1, 1]), - 'var3': (('dim1',), [0, 1]), + "var1": (("dim1",), [0, 1]), + "var2": (("dim1",), [1, 1]), + "var3": (("dim1",), [0, 1]), } - ).astype('bool') + ).astype("bool") result = Dataset( data_vars={ - 'var1': (('dim1',), [0, 1]), - 'var2': (('dim1',), [1, 2]), - 'var3': (('dim1',), [0, 1]), + "var1": (("dim1",), [0, 1]), + "var2": (("dim1",), [1, 2]), + "var3": (("dim1",), [0, 1]), } ).isin(test_elements) assert_equal(result, expected) -@pytest.mark.skipif(not has_dask, reason='requires dask') -@pytest.mark.parametrize('test_elements', ( - [1, 2], - np.array([1, 2]), - DataArray([1, 2]), -)) +@pytest.mark.skipif(not has_dask, reason="requires dask") +@pytest.mark.parametrize("test_elements", ([1, 2], np.array([1, 2]), DataArray([1, 2]))) def test_isin_dask(test_elements): expected = Dataset( data_vars={ - 'var1': (('dim1',), [0, 1]), - 'var2': (('dim1',), [1, 1]), - 'var3': (('dim1',), [0, 1]), - } - ).astype('bool') - - result = Dataset( - data_vars={ - 'var1': (('dim1',), [0, 1]), - 'var2': (('dim1',), [1, 2]), - 'var3': (('dim1',), [0, 1]), + "var1": (("dim1",), [0, 1]), + "var2": (("dim1",), [1, 1]), + "var3": (("dim1",), [0, 1]), } - ).chunk(1).isin(test_elements).compute() + ).astype("bool") + + result = ( + Dataset( + data_vars={ + "var1": (("dim1",), [0, 1]), + "var2": (("dim1",), [1, 2]), + "var3": (("dim1",), [0, 1]), + } + ) + .chunk(1) + 
.isin(test_elements) + .compute() + ) assert_equal(result, expected) def test_isin_dataset(): - ds = Dataset({'x': [1, 2]}) + ds = Dataset({"x": [1, 2]}) with pytest.raises(TypeError): ds.isin(ds) -@pytest.mark.parametrize('unaligned_coords', ( - {'x': [2, 1, 0]}, - {'x': (['x'], np.asarray([2, 1, 0]))}, - {'x': (['x'], np.asarray([1, 2, 0]))}, - {'x': pd.Index([2, 1, 0])}, - {'x': Variable(dims='x', data=[0, 2, 1])}, - {'x': IndexVariable(dims='x', data=[0, 1, 2])}, - {'y': 42}, - {'y': ('x', [2, 1, 0])}, - {'y': ('x', np.asarray([2, 1, 0]))}, - {'y': (['x'], np.asarray([2, 1, 0]))}, -)) -@pytest.mark.parametrize('coords', ( - {'x': ('x', [0, 1, 2])}, - {'x': [0, 1, 2]}, -)) -def test_dataset_constructor_aligns_to_explicit_coords( - unaligned_coords, coords): - - a = xr.DataArray([1, 2, 3], dims=['x'], coords=unaligned_coords) +@pytest.mark.parametrize( + "unaligned_coords", + ( + {"x": [2, 1, 0]}, + {"x": (["x"], np.asarray([2, 1, 0]))}, + {"x": (["x"], np.asarray([1, 2, 0]))}, + {"x": pd.Index([2, 1, 0])}, + {"x": Variable(dims="x", data=[0, 2, 1])}, + {"x": IndexVariable(dims="x", data=[0, 1, 2])}, + {"y": 42}, + {"y": ("x", [2, 1, 0])}, + {"y": ("x", np.asarray([2, 1, 0]))}, + {"y": (["x"], np.asarray([2, 1, 0]))}, + ), +) +@pytest.mark.parametrize("coords", ({"x": ("x", [0, 1, 2])}, {"x": [0, 1, 2]})) +def test_dataset_constructor_aligns_to_explicit_coords(unaligned_coords, coords): + + a = xr.DataArray([1, 2, 3], dims=["x"], coords=unaligned_coords) expected = xr.Dataset(coords=coords) - expected['a'] = a + expected["a"] = a - result = xr.Dataset({'a': a}, coords=coords) + result = xr.Dataset({"a": a}, coords=coords) assert_equal(expected, result) @@ -4856,27 +5220,23 @@ def test_error_message_on_set_supplied(): xr.Dataset(dict(date=[1, 2, 3], sec={4})) -@pytest.mark.parametrize('unaligned_coords', ( - {'y': ('b', np.asarray([2, 1, 0]))}, -)) +@pytest.mark.parametrize("unaligned_coords", ({"y": ("b", np.asarray([2, 1, 0]))},)) def test_constructor_raises_with_invalid_coords(unaligned_coords): - with pytest.raises(ValueError, - match='not a subset of the DataArray dimensions'): - xr.DataArray([1, 2, 3], dims=['x'], coords=unaligned_coords) + with pytest.raises(ValueError, match="not a subset of the DataArray dimensions"): + xr.DataArray([1, 2, 3], dims=["x"], coords=unaligned_coords) def test_dir_expected_attrs(data_set): - some_expected_attrs = {'pipe', 'mean', 'isnull', 'var1', - 'dim2', 'numbers'} + some_expected_attrs = {"pipe", "mean", "isnull", "var1", "dim2", "numbers"} result = dir(data_set) assert set(result) >= some_expected_attrs def test_dir_non_string(data_set): # add a numbered key to ensure this doesn't break dir - data_set[5] = 'foo' + data_set[5] = "foo" result = dir(data_set) assert 5 not in result @@ -4888,196 +5248,211 @@ def test_dir_non_string(data_set): def test_dir_unicode(data_set): - data_set['unicode'] = 'uni' + data_set["unicode"] = "uni" result = dir(data_set) - assert 'unicode' in result + assert "unicode" in result @pytest.fixture(params=[1]) def ds(request): if request.param == 1: - return Dataset({'z1': (['y', 'x'], np.random.randn(2, 8)), - 'z2': (['time', 'y'], np.random.randn(10, 2))}, - {'x': ('x', np.linspace(0, 1.0, 8)), - 'time': ('time', np.linspace(0, 1.0, 10)), - 'c': ('y', ['a', 'b']), - 'y': range(2)}) + return Dataset( + { + "z1": (["y", "x"], np.random.randn(2, 8)), + "z2": (["time", "y"], np.random.randn(10, 2)), + }, + { + "x": ("x", np.linspace(0, 1.0, 8)), + "time": ("time", np.linspace(0, 1.0, 10)), + "c": ("y", ["a", "b"]), + "y": 
range(2), + }, + ) if request.param == 2: - return Dataset({'z1': (['time', 'y'], np.random.randn(10, 2)), - 'z2': (['time'], np.random.randn(10)), - 'z3': (['x', 'time'], np.random.randn(8, 10))}, - {'x': ('x', np.linspace(0, 1.0, 8)), - 'time': ('time', np.linspace(0, 1.0, 10)), - 'c': ('y', ['a', 'b']), - 'y': range(2)}) - - -@pytest.mark.parametrize('dask', [True, False]) -@pytest.mark.parametrize(('boundary', 'side'), [ - ('trim', 'left'), ('pad', 'right')]) + return Dataset( + { + "z1": (["time", "y"], np.random.randn(10, 2)), + "z2": (["time"], np.random.randn(10)), + "z3": (["x", "time"], np.random.randn(8, 10)), + }, + { + "x": ("x", np.linspace(0, 1.0, 8)), + "time": ("time", np.linspace(0, 1.0, 10)), + "c": ("y", ["a", "b"]), + "y": range(2), + }, + ) + + +@pytest.mark.parametrize("dask", [True, False]) +@pytest.mark.parametrize(("boundary", "side"), [("trim", "left"), ("pad", "right")]) def test_coarsen(ds, dask, boundary, side): if dask and has_dask: - ds = ds.chunk({'x': 4}) + ds = ds.chunk({"x": 4}) actual = ds.coarsen(time=2, x=3, boundary=boundary, side=side).max() assert_equal( - actual['z1'], - ds['z1'].coarsen(time=2, x=3, boundary=boundary, side=side).max()) + actual["z1"], ds["z1"].coarsen(time=2, x=3, boundary=boundary, side=side).max() + ) # coordinate should be mean by default - assert_equal(actual['time'], ds['time'].coarsen( - time=2, x=3, boundary=boundary, side=side).mean()) + assert_equal( + actual["time"], + ds["time"].coarsen(time=2, x=3, boundary=boundary, side=side).mean(), + ) -@pytest.mark.parametrize('dask', [True, False]) +@pytest.mark.parametrize("dask", [True, False]) def test_coarsen_coords(ds, dask): if dask and has_dask: - ds = ds.chunk({'x': 4}) + ds = ds.chunk({"x": 4}) # check if coord_func works - actual = ds.coarsen(time=2, x=3, boundary='trim', - coord_func={'time': 'max'}).max() - assert_equal(actual['z1'], - ds['z1'].coarsen(time=2, x=3, boundary='trim').max()) - assert_equal(actual['time'], - ds['time'].coarsen(time=2, x=3, boundary='trim').max()) + actual = ds.coarsen(time=2, x=3, boundary="trim", coord_func={"time": "max"}).max() + assert_equal(actual["z1"], ds["z1"].coarsen(time=2, x=3, boundary="trim").max()) + assert_equal(actual["time"], ds["time"].coarsen(time=2, x=3, boundary="trim").max()) # raise if exact with pytest.raises(ValueError): ds.coarsen(x=3).mean() # should be no error - ds.isel(x=slice(0, 3 * (len(ds['x']) // 3))).coarsen(x=3).mean() + ds.isel(x=slice(0, 3 * (len(ds["x"]) // 3))).coarsen(x=3).mean() # working test with pd.time da = xr.DataArray( - np.linspace(0, 365, num=364), dims='time', - coords={'time': pd.date_range('15/12/1999', periods=364)}) + np.linspace(0, 365, num=364), + dims="time", + coords={"time": pd.date_range("15/12/1999", periods=364)}, + ) actual = da.coarsen(time=2).mean() @requires_cftime def test_coarsen_coords_cftime(): - times = xr.cftime_range('2000', periods=6) - da = xr.DataArray(range(6), [('time', times)]) + times = xr.cftime_range("2000", periods=6) + da = xr.DataArray(range(6), [("time", times)]) actual = da.coarsen(time=3).mean() - expected_times = xr.cftime_range('2000-01-02', freq='3D', periods=2) + expected_times = xr.cftime_range("2000-01-02", freq="3D", periods=2) np.testing.assert_array_equal(actual.time, expected_times) def test_rolling_properties(ds): # catching invalid args - with pytest.raises(ValueError, match='exactly one dim/window should'): + with pytest.raises(ValueError, match="exactly one dim/window should"): ds.rolling(time=7, x=2) - with pytest.raises(ValueError, 
match='window must be > 0'): + with pytest.raises(ValueError, match="window must be > 0"): ds.rolling(time=-2) - with pytest.raises( - ValueError, match='min_periods must be greater than zero' - ): + with pytest.raises(ValueError, match="min_periods must be greater than zero"): ds.rolling(time=2, min_periods=0) - with pytest.raises(KeyError, match='time2'): + with pytest.raises(KeyError, match="time2"): ds.rolling(time2=2) -@pytest.mark.parametrize('name', - ('sum', 'mean', 'std', 'var', 'min', 'max', 'median')) -@pytest.mark.parametrize('center', (True, False, None)) -@pytest.mark.parametrize('min_periods', (1, None)) -@pytest.mark.parametrize('key', ('z1', 'z2')) +@pytest.mark.parametrize("name", ("sum", "mean", "std", "var", "min", "max", "median")) +@pytest.mark.parametrize("center", (True, False, None)) +@pytest.mark.parametrize("min_periods", (1, None)) +@pytest.mark.parametrize("key", ("z1", "z2")) def test_rolling_wrapped_bottleneck(ds, name, center, min_periods, key): - bn = pytest.importorskip('bottleneck', minversion='1.1') + bn = pytest.importorskip("bottleneck", minversion="1.1") # Test all bottleneck functions rolling_obj = ds.rolling(time=7, min_periods=min_periods) - func_name = 'move_{}'.format(name) + func_name = "move_{}".format(name) actual = getattr(rolling_obj, name)() - if key == 'z1': # z1 does not depend on 'Time' axis. Stored as it is. + if key == "z1": # z1 does not depend on 'Time' axis. Stored as it is. expected = ds[key] - elif key == 'z2': - expected = getattr(bn, func_name)(ds[key].values, window=7, axis=0, - min_count=min_periods) + elif key == "z2": + expected = getattr(bn, func_name)( + ds[key].values, window=7, axis=0, min_count=min_periods + ) assert_array_equal(actual[key].values, expected) # Test center rolling_obj = ds.rolling(time=7, center=center) - actual = getattr(rolling_obj, name)()['time'] - assert_equal(actual, ds['time']) + actual = getattr(rolling_obj, name)()["time"] + assert_equal(actual, ds["time"]) @requires_numbagg def test_rolling_exp(ds): - result = ds.rolling_exp(time=10, window_type='span').mean() + result = ds.rolling_exp(time=10, window_type="span").mean() assert isinstance(result, Dataset) -@pytest.mark.parametrize('center', (True, False)) -@pytest.mark.parametrize('min_periods', (None, 1, 2, 3)) -@pytest.mark.parametrize('window', (1, 2, 3, 4)) +@pytest.mark.parametrize("center", (True, False)) +@pytest.mark.parametrize("min_periods", (None, 1, 2, 3)) +@pytest.mark.parametrize("window", (1, 2, 3, 4)) def test_rolling_pandas_compat(center, window, min_periods): - df = pd.DataFrame({'x': np.random.randn(20), 'y': np.random.randn(20), - 'time': np.linspace(0, 1, 20)}) + df = pd.DataFrame( + { + "x": np.random.randn(20), + "y": np.random.randn(20), + "time": np.linspace(0, 1, 20), + } + ) ds = Dataset.from_dataframe(df) if min_periods is not None and window < min_periods: min_periods = window - df_rolling = df.rolling(window, center=center, - min_periods=min_periods).mean() - ds_rolling = ds.rolling(index=window, center=center, - min_periods=min_periods).mean() + df_rolling = df.rolling(window, center=center, min_periods=min_periods).mean() + ds_rolling = ds.rolling(index=window, center=center, min_periods=min_periods).mean() - np.testing.assert_allclose(df_rolling['x'].values, ds_rolling['x'].values) - np.testing.assert_allclose(df_rolling.index, ds_rolling['index']) + np.testing.assert_allclose(df_rolling["x"].values, ds_rolling["x"].values) + np.testing.assert_allclose(df_rolling.index, ds_rolling["index"]) 
-@pytest.mark.parametrize('center', (True, False)) -@pytest.mark.parametrize('window', (1, 2, 3, 4)) +@pytest.mark.parametrize("center", (True, False)) +@pytest.mark.parametrize("window", (1, 2, 3, 4)) def test_rolling_construct(center, window): - df = pd.DataFrame({'x': np.random.randn(20), 'y': np.random.randn(20), - 'time': np.linspace(0, 1, 20)}) + df = pd.DataFrame( + { + "x": np.random.randn(20), + "y": np.random.randn(20), + "time": np.linspace(0, 1, 20), + } + ) ds = Dataset.from_dataframe(df) df_rolling = df.rolling(window, center=center, min_periods=1).mean() ds_rolling = ds.rolling(index=window, center=center) - ds_rolling_mean = ds_rolling.construct('window').mean('window') - np.testing.assert_allclose(df_rolling['x'].values, - ds_rolling_mean['x'].values) - np.testing.assert_allclose(df_rolling.index, ds_rolling_mean['index']) + ds_rolling_mean = ds_rolling.construct("window").mean("window") + np.testing.assert_allclose(df_rolling["x"].values, ds_rolling_mean["x"].values) + np.testing.assert_allclose(df_rolling.index, ds_rolling_mean["index"]) # with stride - ds_rolling_mean = ds_rolling.construct('window', stride=2).mean('window') - np.testing.assert_allclose(df_rolling['x'][::2].values, - ds_rolling_mean['x'].values) - np.testing.assert_allclose(df_rolling.index[::2], ds_rolling_mean['index']) + ds_rolling_mean = ds_rolling.construct("window", stride=2).mean("window") + np.testing.assert_allclose(df_rolling["x"][::2].values, ds_rolling_mean["x"].values) + np.testing.assert_allclose(df_rolling.index[::2], ds_rolling_mean["index"]) # with fill_value - ds_rolling_mean = ds_rolling.construct( - 'window', stride=2, fill_value=0.0).mean('window') - assert (ds_rolling_mean.isnull().sum() == 0).to_array(dim='vars').all() - assert (ds_rolling_mean['x'] == 0.0).sum() >= 0 + ds_rolling_mean = ds_rolling.construct("window", stride=2, fill_value=0.0).mean( + "window" + ) + assert (ds_rolling_mean.isnull().sum() == 0).to_array(dim="vars").all() + assert (ds_rolling_mean["x"] == 0.0).sum() >= 0 @pytest.mark.slow -@pytest.mark.parametrize('ds', (1, 2), indirect=True) -@pytest.mark.parametrize('center', (True, False)) -@pytest.mark.parametrize('min_periods', (None, 1, 2, 3)) -@pytest.mark.parametrize('window', (1, 2, 3, 4)) -@pytest.mark.parametrize('name', - ('sum', 'mean', 'std', 'var', 'min', 'max', 'median')) +@pytest.mark.parametrize("ds", (1, 2), indirect=True) +@pytest.mark.parametrize("center", (True, False)) +@pytest.mark.parametrize("min_periods", (None, 1, 2, 3)) +@pytest.mark.parametrize("window", (1, 2, 3, 4)) +@pytest.mark.parametrize("name", ("sum", "mean", "std", "var", "min", "max", "median")) def test_rolling_reduce(ds, center, min_periods, window, name): if min_periods is not None and window < min_periods: min_periods = window - if name == 'std' and window == 1: - pytest.skip('std with window == 1 is unstable in bottleneck') + if name == "std" and window == 1: + pytest.skip("std with window == 1 is unstable in bottleneck") - rolling_obj = ds.rolling(time=window, center=center, - min_periods=min_periods) + rolling_obj = ds.rolling(time=window, center=center, min_periods=min_periods) # add nan prefix to numpy methods to get similar behavior as bottleneck - actual = rolling_obj.reduce(getattr(np, 'nan%s' % name)) + actual = rolling_obj.reduce(getattr(np, "nan%s" % name)) expected = getattr(rolling_obj, name)() assert_allclose(actual, expected) assert ds.dims == actual.dims @@ -5091,189 +5466,236 @@ def test_rolling_reduce(ds, center, min_periods, window, name): def 
test_raise_no_warning_for_nan_in_binary_ops(): with pytest.warns(None) as record: - Dataset(data_vars={'x': ('y', [1, 2, np.NaN])}) > 0 + Dataset(data_vars={"x": ("y", [1, 2, np.NaN])}) > 0 assert len(record) == 0 -@pytest.mark.parametrize('dask', [True, False]) -@pytest.mark.parametrize('edge_order', [1, 2]) +@pytest.mark.parametrize("dask", [True, False]) +@pytest.mark.parametrize("edge_order", [1, 2]) def test_differentiate(dask, edge_order): rs = np.random.RandomState(42) coord = [0.2, 0.35, 0.4, 0.6, 0.7, 0.75, 0.76, 0.8] - da = xr.DataArray(rs.randn(8, 6), dims=['x', 'y'], - coords={'x': coord, - 'z': 3, 'x2d': (('x', 'y'), rs.randn(8, 6))}) + da = xr.DataArray( + rs.randn(8, 6), + dims=["x", "y"], + coords={"x": coord, "z": 3, "x2d": (("x", "y"), rs.randn(8, 6))}, + ) if dask and has_dask: - da = da.chunk({'x': 4}) + da = da.chunk({"x": 4}) - ds = xr.Dataset({'var': da}) + ds = xr.Dataset({"var": da}) # along x - actual = da.differentiate('x', edge_order) + actual = da.differentiate("x", edge_order) expected_x = xr.DataArray( - npcompat.gradient(da, da['x'], axis=0, edge_order=edge_order), - dims=da.dims, coords=da.coords) + npcompat.gradient(da, da["x"], axis=0, edge_order=edge_order), + dims=da.dims, + coords=da.coords, + ) assert_equal(expected_x, actual) - assert_equal(ds['var'].differentiate('x', edge_order=edge_order), - ds.differentiate('x', edge_order=edge_order)['var']) + assert_equal( + ds["var"].differentiate("x", edge_order=edge_order), + ds.differentiate("x", edge_order=edge_order)["var"], + ) # coordinate should not change - assert_equal(da['x'], actual['x']) + assert_equal(da["x"], actual["x"]) # along y - actual = da.differentiate('y', edge_order) + actual = da.differentiate("y", edge_order) expected_y = xr.DataArray( - npcompat.gradient(da, da['y'], axis=1, edge_order=edge_order), - dims=da.dims, coords=da.coords) + npcompat.gradient(da, da["y"], axis=1, edge_order=edge_order), + dims=da.dims, + coords=da.coords, + ) assert_equal(expected_y, actual) - assert_equal(actual, ds.differentiate('y', edge_order=edge_order)['var']) - assert_equal(ds['var'].differentiate('y', edge_order=edge_order), - ds.differentiate('y', edge_order=edge_order)['var']) + assert_equal(actual, ds.differentiate("y", edge_order=edge_order)["var"]) + assert_equal( + ds["var"].differentiate("y", edge_order=edge_order), + ds.differentiate("y", edge_order=edge_order)["var"], + ) with pytest.raises(ValueError): - da.differentiate('x2d') + da.differentiate("x2d") -@pytest.mark.parametrize('dask', [True, False]) +@pytest.mark.parametrize("dask", [True, False]) def test_differentiate_datetime(dask): rs = np.random.RandomState(42) coord = np.array( - ['2004-07-13', '2006-01-13', '2010-08-13', '2010-09-13', - '2010-10-11', '2010-12-13', '2011-02-13', '2012-08-13'], - dtype='datetime64') + [ + "2004-07-13", + "2006-01-13", + "2010-08-13", + "2010-09-13", + "2010-10-11", + "2010-12-13", + "2011-02-13", + "2012-08-13", + ], + dtype="datetime64", + ) - da = xr.DataArray(rs.randn(8, 6), dims=['x', 'y'], - coords={'x': coord, - 'z': 3, 'x2d': (('x', 'y'), rs.randn(8, 6))}) + da = xr.DataArray( + rs.randn(8, 6), + dims=["x", "y"], + coords={"x": coord, "z": 3, "x2d": (("x", "y"), rs.randn(8, 6))}, + ) if dask and has_dask: - da = da.chunk({'x': 4}) + da = da.chunk({"x": 4}) # along x - actual = da.differentiate('x', edge_order=1, datetime_unit='D') + actual = da.differentiate("x", edge_order=1, datetime_unit="D") expected_x = xr.DataArray( npcompat.gradient( - da, 
da['x'].variable._to_numeric(datetime_unit='D'), - axis=0, edge_order=1), dims=da.dims, coords=da.coords) + da, da["x"].variable._to_numeric(datetime_unit="D"), axis=0, edge_order=1 + ), + dims=da.dims, + coords=da.coords, + ) assert_equal(expected_x, actual) - actual2 = da.differentiate('x', edge_order=1, datetime_unit='h') + actual2 = da.differentiate("x", edge_order=1, datetime_unit="h") assert np.allclose(actual, actual2 * 24) # for datetime variable - actual = da['x'].differentiate('x', edge_order=1, datetime_unit='D') + actual = da["x"].differentiate("x", edge_order=1, datetime_unit="D") assert np.allclose(actual, 1.0) # with different date unit - da = xr.DataArray(coord.astype('datetime64[ms]'), dims=['x'], - coords={'x': coord}) - actual = da.differentiate('x', edge_order=1) + da = xr.DataArray(coord.astype("datetime64[ms]"), dims=["x"], coords={"x": coord}) + actual = da.differentiate("x", edge_order=1) assert np.allclose(actual, 1.0) -@pytest.mark.skipif(not has_cftime, reason='Test requires cftime.') -@pytest.mark.parametrize('dask', [True, False]) +@pytest.mark.skipif(not has_cftime, reason="Test requires cftime.") +@pytest.mark.parametrize("dask", [True, False]) def test_differentiate_cftime(dask): rs = np.random.RandomState(42) - coord = xr.cftime_range('2000', periods=8, freq='2M') + coord = xr.cftime_range("2000", periods=8, freq="2M") da = xr.DataArray( rs.randn(8, 6), - coords={'time': coord, 'z': 3, 't2d': (('time', 'y'), rs.randn(8, 6))}, - dims=['time', 'y']) + coords={"time": coord, "z": 3, "t2d": (("time", "y"), rs.randn(8, 6))}, + dims=["time", "y"], + ) if dask and has_dask: - da = da.chunk({'time': 4}) + da = da.chunk({"time": 4}) - actual = da.differentiate('time', edge_order=1, datetime_unit='D') + actual = da.differentiate("time", edge_order=1, datetime_unit="D") expected_data = npcompat.gradient( - da, da['time'].variable._to_numeric(datetime_unit='D'), - axis=0, edge_order=1) + da, da["time"].variable._to_numeric(datetime_unit="D"), axis=0, edge_order=1 + ) expected = xr.DataArray(expected_data, coords=da.coords, dims=da.dims) assert_equal(expected, actual) - actual2 = da.differentiate('time', edge_order=1, datetime_unit='h') + actual2 = da.differentiate("time", edge_order=1, datetime_unit="h") assert_allclose(actual, actual2 * 24) # Test the differentiation of datetimes themselves - actual = da['time'].differentiate('time', edge_order=1, datetime_unit='D') - assert_allclose(actual, xr.ones_like(da['time']).astype(float)) + actual = da["time"].differentiate("time", edge_order=1, datetime_unit="D") + assert_allclose(actual, xr.ones_like(da["time"]).astype(float)) -@pytest.mark.parametrize('dask', [True, False]) +@pytest.mark.parametrize("dask", [True, False]) def test_integrate(dask): rs = np.random.RandomState(42) coord = [0.2, 0.35, 0.4, 0.6, 0.7, 0.75, 0.76, 0.8] - da = xr.DataArray(rs.randn(8, 6), dims=['x', 'y'], - coords={'x': coord, 'x2': (('x', ), rs.randn(8)), - 'z': 3, 'x2d': (('x', 'y'), rs.randn(8, 6))}) + da = xr.DataArray( + rs.randn(8, 6), + dims=["x", "y"], + coords={ + "x": coord, + "x2": (("x",), rs.randn(8)), + "z": 3, + "x2d": (("x", "y"), rs.randn(8, 6)), + }, + ) if dask and has_dask: - da = da.chunk({'x': 4}) + da = da.chunk({"x": 4}) - ds = xr.Dataset({'var': da}) + ds = xr.Dataset({"var": da}) # along x - actual = da.integrate('x') + actual = da.integrate("x") # coordinate that contains x should be dropped. 
expected_x = xr.DataArray( - np.trapz(da, da['x'], axis=0), dims=['y'], - coords={k: v for k, v in da.coords.items() if 'x' not in v.dims}) + np.trapz(da, da["x"], axis=0), + dims=["y"], + coords={k: v for k, v in da.coords.items() if "x" not in v.dims}, + ) assert_allclose(expected_x, actual.compute()) - assert_equal(ds['var'].integrate('x'), ds.integrate('x')['var']) + assert_equal(ds["var"].integrate("x"), ds.integrate("x")["var"]) # make sure result is also a dask array (if the source is dask array) assert isinstance(actual.data, type(da.data)) # along y - actual = da.integrate('y') + actual = da.integrate("y") expected_y = xr.DataArray( - np.trapz(da, da['y'], axis=1), dims=['x'], - coords={k: v for k, v in da.coords.items() if 'y' not in v.dims}) + np.trapz(da, da["y"], axis=1), + dims=["x"], + coords={k: v for k, v in da.coords.items() if "y" not in v.dims}, + ) assert_allclose(expected_y, actual.compute()) - assert_equal(actual, ds.integrate('y')['var']) - assert_equal(ds['var'].integrate('y'), ds.integrate('y')['var']) + assert_equal(actual, ds.integrate("y")["var"]) + assert_equal(ds["var"].integrate("y"), ds.integrate("y")["var"]) # along x and y - actual = da.integrate(('y', 'x')) + actual = da.integrate(("y", "x")) assert actual.ndim == 0 with pytest.raises(ValueError): - da.integrate('x2d') + da.integrate("x2d") -@pytest.mark.parametrize('dask', [True, False]) -@pytest.mark.parametrize('which_datetime', ['np', 'cftime']) +@pytest.mark.parametrize("dask", [True, False]) +@pytest.mark.parametrize("which_datetime", ["np", "cftime"]) def test_trapz_datetime(dask, which_datetime): rs = np.random.RandomState(42) - if which_datetime == 'np': + if which_datetime == "np": coord = np.array( - ['2004-07-13', '2006-01-13', '2010-08-13', '2010-09-13', - '2010-10-11', '2010-12-13', '2011-02-13', '2012-08-13'], - dtype='datetime64') + [ + "2004-07-13", + "2006-01-13", + "2010-08-13", + "2010-09-13", + "2010-10-11", + "2010-12-13", + "2011-02-13", + "2012-08-13", + ], + dtype="datetime64", + ) else: if not has_cftime: - pytest.skip('Test requires cftime.') - coord = xr.cftime_range('2000', periods=8, freq='2D') + pytest.skip("Test requires cftime.") + coord = xr.cftime_range("2000", periods=8, freq="2D") da = xr.DataArray( rs.randn(8, 6), - coords={'time': coord, 'z': 3, 't2d': (('time', 'y'), rs.randn(8, 6))}, - dims=['time', 'y']) + coords={"time": coord, "z": 3, "t2d": (("time", "y"), rs.randn(8, 6))}, + dims=["time", "y"], + ) if dask and has_dask: - da = da.chunk({'time': 4}) + da = da.chunk({"time": 4}) - actual = da.integrate('time', datetime_unit='D') + actual = da.integrate("time", datetime_unit="D") expected_data = np.trapz( - da, duck_array_ops.datetime_to_numeric(da['time'], datetime_unit='D'), - axis=0) + da, duck_array_ops.datetime_to_numeric(da["time"], datetime_unit="D"), axis=0 + ) expected = xr.DataArray( - expected_data, dims=['y'], - coords={k: v for k, v in da.coords.items() if 'time' not in v.dims}) + expected_data, + dims=["y"], + coords={k: v for k, v in da.coords.items() if "time" not in v.dims}, + ) assert_allclose(expected, actual.compute()) # make sure result is also a dask array (if the source is dask array) assert isinstance(actual.data, type(da.data)) - actual2 = da.integrate('time', datetime_unit='h') + actual2 = da.integrate("time", datetime_unit="h") assert_allclose(actual, actual2 / 24.0) diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 98c53ef2b12..0929efc56f2 100644 --- a/xarray/tests/test_distributed.py +++ 
b/xarray/tests/test_distributed.py @@ -4,9 +4,8 @@ import pytest -dask = pytest.importorskip('dask', minversion='0.18') # isort:skip -distributed = pytest.importorskip( - 'distributed', minversion='1.21') # isort:skip +dask = pytest.importorskip("dask", minversion="0.18") # isort:skip +distributed = pytest.importorskip("distributed", minversion="1.21") # isort:skip from dask.distributed import Client, Lock from distributed.utils_test import cluster, gen_cluster @@ -15,122 +14,142 @@ import xarray as xr from xarray.backends.locks import HDF5_LOCK, CombinedLock -from xarray.tests.test_backends import (ON_WINDOWS, create_tmp_file, - create_tmp_geotiff, - open_example_dataset) +from xarray.tests.test_backends import ( + ON_WINDOWS, + create_tmp_file, + create_tmp_geotiff, + open_example_dataset, +) from xarray.tests.test_dataset import create_test_data from . import ( - assert_allclose, has_h5netcdf, has_netCDF4, requires_rasterio, has_scipy, - requires_zarr, requires_cfgrib) + assert_allclose, + has_h5netcdf, + has_netCDF4, + requires_rasterio, + has_scipy, + requires_zarr, + requires_cfgrib, +) # this is to stop isort throwing errors. May have been easier to just use # `isort:skip` in retrospect -da = pytest.importorskip('dask.array') +da = pytest.importorskip("dask.array") loop = loop # loop is an imported fixture, which flake8 has issues ack-ing @pytest.fixture def tmp_netcdf_filename(tmpdir): - return str(tmpdir.join('testfile.nc')) + return str(tmpdir.join("testfile.nc")) ENGINES = [] if has_scipy: - ENGINES.append('scipy') + ENGINES.append("scipy") if has_netCDF4: - ENGINES.append('netcdf4') + ENGINES.append("netcdf4") if has_h5netcdf: - ENGINES.append('h5netcdf') - -NC_FORMATS = {'netcdf4': ['NETCDF3_CLASSIC', 'NETCDF3_64BIT_OFFSET', - 'NETCDF3_64BIT_DATA', 'NETCDF4_CLASSIC', 'NETCDF4'], - 'scipy': ['NETCDF3_CLASSIC', 'NETCDF3_64BIT'], - 'h5netcdf': ['NETCDF4']} + ENGINES.append("h5netcdf") + +NC_FORMATS = { + "netcdf4": [ + "NETCDF3_CLASSIC", + "NETCDF3_64BIT_OFFSET", + "NETCDF3_64BIT_DATA", + "NETCDF4_CLASSIC", + "NETCDF4", + ], + "scipy": ["NETCDF3_CLASSIC", "NETCDF3_64BIT"], + "h5netcdf": ["NETCDF4"], +} ENGINES_AND_FORMATS = [ - ('netcdf4', 'NETCDF3_CLASSIC'), - ('netcdf4', 'NETCDF4_CLASSIC'), - ('netcdf4', 'NETCDF4'), - ('h5netcdf', 'NETCDF4'), - ('scipy', 'NETCDF3_64BIT'), + ("netcdf4", "NETCDF3_CLASSIC"), + ("netcdf4", "NETCDF4_CLASSIC"), + ("netcdf4", "NETCDF4"), + ("h5netcdf", "NETCDF4"), + ("scipy", "NETCDF3_64BIT"), ] -@pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS) # noqa +@pytest.mark.parametrize("engine,nc_format", ENGINES_AND_FORMATS) # noqa def test_dask_distributed_netcdf_roundtrip( - loop, tmp_netcdf_filename, engine, nc_format): + loop, tmp_netcdf_filename, engine, nc_format +): if engine not in ENGINES: - pytest.skip('engine not available') + pytest.skip("engine not available") - chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} + chunks = {"dim1": 4, "dim2": 3, "dim3": 6} with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop): + with Client(s["address"], loop=loop): original = create_test_data().chunk(chunks) - if engine == 'scipy': + if engine == "scipy": with pytest.raises(NotImplementedError): - original.to_netcdf(tmp_netcdf_filename, - engine=engine, format=nc_format) + original.to_netcdf( + tmp_netcdf_filename, engine=engine, format=nc_format + ) return - original.to_netcdf(tmp_netcdf_filename, - engine=engine, format=nc_format) + original.to_netcdf(tmp_netcdf_filename, engine=engine, format=nc_format) - with 
xr.open_dataset(tmp_netcdf_filename, - chunks=chunks, engine=engine) as restored: + with xr.open_dataset( + tmp_netcdf_filename, chunks=chunks, engine=engine + ) as restored: assert isinstance(restored.var1.data, da.Array) computed = restored.compute() assert_allclose(original, computed) -@pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS) # noqa +@pytest.mark.parametrize("engine,nc_format", ENGINES_AND_FORMATS) # noqa def test_dask_distributed_read_netcdf_integration_test( - loop, tmp_netcdf_filename, engine, nc_format): + loop, tmp_netcdf_filename, engine, nc_format +): if engine not in ENGINES: - pytest.skip('engine not available') + pytest.skip("engine not available") - chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} + chunks = {"dim1": 4, "dim2": 3, "dim3": 6} with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop): + with Client(s["address"], loop=loop): original = create_test_data() - original.to_netcdf(tmp_netcdf_filename, - engine=engine, format=nc_format) + original.to_netcdf(tmp_netcdf_filename, engine=engine, format=nc_format) - with xr.open_dataset(tmp_netcdf_filename, - chunks=chunks, - engine=engine) as restored: + with xr.open_dataset( + tmp_netcdf_filename, chunks=chunks, engine=engine + ) as restored: assert isinstance(restored.var1.data, da.Array) computed = restored.compute() assert_allclose(original, computed) @requires_zarr # noqar -@pytest.mark.parametrize('consolidated', [True, False]) -@pytest.mark.parametrize('compute', [True, False]) +@pytest.mark.parametrize("consolidated", [True, False]) +@pytest.mark.parametrize("compute", [True, False]) def test_dask_distributed_zarr_integration_test(loop, consolidated, compute): if consolidated: - pytest.importorskip('zarr', minversion="2.2.1.dev2") + pytest.importorskip("zarr", minversion="2.2.1.dev2") write_kwargs = dict(consolidated=True) read_kwargs = dict(consolidated=True) else: write_kwargs = read_kwargs = {} - chunks = {'dim1': 4, 'dim2': 3, 'dim3': 5} + chunks = {"dim1": 4, "dim2": 3, "dim3": 5} with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop): + with Client(s["address"], loop=loop): original = create_test_data().chunk(chunks) - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS, - suffix='.zarrc') as filename: - maybe_futures = original.to_zarr(filename, compute=compute, - **write_kwargs) + with create_tmp_file( + allow_cleanup_failure=ON_WINDOWS, suffix=".zarrc" + ) as filename: + maybe_futures = original.to_zarr( + filename, compute=compute, **write_kwargs + ) if not compute: maybe_futures.compute() with xr.open_zarr(filename, **read_kwargs) as restored: @@ -143,8 +162,8 @@ def test_dask_distributed_zarr_integration_test(loop, consolidated, compute): def test_dask_distributed_rasterio_integration_test(loop): with create_tmp_geotiff() as (tmp_file, expected): with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop): - da_tiff = xr.open_rasterio(tmp_file, chunks={'band': 1}) + with Client(s["address"], loop=loop): + da_tiff = xr.open_rasterio(tmp_file, chunks={"band": 1}) assert isinstance(da_tiff.data, da.Array) actual = da_tiff.compute() assert_allclose(actual, expected) @@ -153,24 +172,25 @@ def test_dask_distributed_rasterio_integration_test(loop): @requires_cfgrib # noqa def test_dask_distributed_cfgrib_integration_test(loop): with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop): - with open_example_dataset('example.grib', - engine='cfgrib', - chunks={'time': 1}) as ds: - with open_example_dataset('example.grib', - engine='cfgrib') as 
expected: - assert isinstance(ds['t'].data, da.Array) + with Client(s["address"], loop=loop): + with open_example_dataset( + "example.grib", engine="cfgrib", chunks={"time": 1} + ) as ds: + with open_example_dataset("example.grib", engine="cfgrib") as expected: + assert isinstance(ds["t"].data, da.Array) actual = ds.compute() assert_allclose(actual, expected) -@pytest.mark.skipif(distributed.__version__ <= '1.19.3', - reason='Need recent distributed version to clean up get') +@pytest.mark.skipif( + distributed.__version__ <= "1.19.3", + reason="Need recent distributed version to clean up get", +) @gen_cluster(client=True, timeout=None) def test_async(c, s, a, b): x = create_test_data() assert not dask.is_dask_collection(x) - y = x.chunk({'dim2': 4}) + 10 + y = x.chunk({"dim2": 4}) + 10 assert dask.is_dask_collection(y) assert dask.is_dask_collection(y.var1) assert dask.is_dask_collection(y.var2) @@ -205,9 +225,13 @@ def f(x, lock=None): return x + 1 # note, the creation of Lock needs to be done inside a cluster - for lock in [HDF5_LOCK, Lock(), Lock('filename.nc'), - CombinedLock([HDF5_LOCK]), - CombinedLock([HDF5_LOCK, Lock('filename.nc')])]: + for lock in [ + HDF5_LOCK, + Lock(), + Lock("filename.nc"), + CombinedLock([HDF5_LOCK]), + CombinedLock([HDF5_LOCK, Lock("filename.nc")]), + ]: futures = c.map(f, list(range(10)), lock=lock) yield c.gather(futures) diff --git a/xarray/tests/test_dtypes.py b/xarray/tests/test_dtypes.py index 260486df275..1f3aee84979 100644 --- a/xarray/tests/test_dtypes.py +++ b/xarray/tests/test_dtypes.py @@ -4,15 +4,18 @@ from xarray.core import dtypes -@pytest.mark.parametrize("args, expected", [ - ([np.bool], np.bool), - ([np.bool, np.string_], np.object_), - ([np.float32, np.float64], np.float64), - ([np.float32, np.string_], np.object_), - ([np.unicode_, np.int64], np.object_), - ([np.unicode_, np.unicode_], np.unicode_), - ([np.bytes_, np.unicode_], np.object_), -]) +@pytest.mark.parametrize( + "args, expected", + [ + ([np.bool], np.bool), + ([np.bool, np.string_], np.object_), + ([np.float32, np.float64], np.float64), + ([np.float32, np.string_], np.object_), + ([np.unicode_, np.int64], np.object_), + ([np.unicode_, np.unicode_], np.unicode_), + ([np.bytes_, np.unicode_], np.object_), + ], +) def test_result_type(args, expected): actual = dtypes.result_type(*args) assert actual == expected @@ -25,8 +28,8 @@ def test_result_type_scalar(): def test_result_type_dask_array(): # verify it works without evaluating dask arrays - da = pytest.importorskip('dask.array') - dask = pytest.importorskip('dask') + da = pytest.importorskip("dask.array") + dask = pytest.importorskip("dask") def error(): raise RuntimeError @@ -44,39 +47,42 @@ def error(): assert actual == np.float64 -@pytest.mark.parametrize('obj', [1.0, np.inf, 'ab', 1.0 + 1.0j, True]) +@pytest.mark.parametrize("obj", [1.0, np.inf, "ab", 1.0 + 1.0j, True]) def test_inf(obj): assert dtypes.INF > obj assert dtypes.NINF < obj -@pytest.mark.parametrize("kind, expected", [ - ('a', (np.dtype('O'), 'nan')), # dtype('S') - ('b', (np.float32, 'nan')), # dtype('int8') - ('B', (np.float32, 'nan')), # dtype('uint8') - ('c', (np.dtype('O'), 'nan')), # dtype('S1') - ('D', (np.complex128, '(nan+nanj)')), # dtype('complex128') - ('d', (np.float64, 'nan')), # dtype('float64') - ('e', (np.float16, 'nan')), # dtype('float16') - ('F', (np.complex64, '(nan+nanj)')), # dtype('complex64') - ('f', (np.float32, 'nan')), # dtype('float32') - ('h', (np.float32, 'nan')), # dtype('int16') - ('H', (np.float32, 'nan')), # dtype('uint16') - 
('i', (np.float64, 'nan')), # dtype('int32') - ('I', (np.float64, 'nan')), # dtype('uint32') - ('l', (np.float64, 'nan')), # dtype('int64') - ('L', (np.float64, 'nan')), # dtype('uint64') - ('m', (np.timedelta64, 'NaT')), # dtype('= LooseVersion('0.22.0'): + if LooseVersion(pd.__version__) >= LooseVersion("0.22.0"): # min_count is only implenented in pandas > 0.22 - expected = series_reduce(da, func, skipna=True, dim=aggdim, - min_count=min_count) + expected = series_reduce(da, func, skipna=True, dim=aggdim, min_count=min_count) assert_allclose(actual, expected) assert_dask_array(actual, dask) -@pytest.mark.parametrize('func', ['sum', 'prod']) +@pytest.mark.parametrize("func", ["sum", "prod"]) def test_min_count_dataset(func): da = construct_dataarray(2, dtype=float, contains_nan=True, dask=False) - ds = Dataset({'var1': da}, coords={'scalar': 0}) - actual = getattr(ds, func)(dim='x', skipna=True, min_count=3)['var1'] - expected = getattr(ds['var1'], func)(dim='x', skipna=True, min_count=3) + ds = Dataset({"var1": da}, coords={"scalar": 0}) + actual = getattr(ds, func)(dim="x", skipna=True, min_count=3)["var1"] + expected = getattr(ds["var1"], func)(dim="x", skipna=True, min_count=3) assert_allclose(actual, expected) -@pytest.mark.parametrize('dtype', [float, int, np.float32, np.bool_]) -@pytest.mark.parametrize('dask', [False, True]) -@pytest.mark.parametrize('func', ['sum', 'prod']) +@pytest.mark.parametrize("dtype", [float, int, np.float32, np.bool_]) +@pytest.mark.parametrize("dask", [False, True]) +@pytest.mark.parametrize("func", ["sum", "prod"]) def test_multiple_dims(dtype, dask, func): if dask and not has_dask: - pytest.skip('requires dask') + pytest.skip("requires dask") da = construct_dataarray(3, dtype, contains_nan=True, dask=dask) - actual = getattr(da, func)(('x', 'y')) - expected = getattr(getattr(da, func)('x'), func)('y') + actual = getattr(da, func)(("x", "y")) + expected = getattr(getattr(da, func)("x"), func)("y") assert_allclose(actual, expected) def test_docs(): # with min_count actual = DataArray.sum.__doc__ - expected = dedent("""\ + expected = dedent( + """\ Reduce this DataArray's data by applying `sum` along some dimension(s). Parameters @@ -579,12 +599,14 @@ def test_docs(): reduced : DataArray New DataArray object with `sum` applied to its data and the indicated dimension(s) removed. - """) + """ + ) assert actual == expected # without min_count actual = DataArray.std.__doc__ - expected = dedent("""\ + expected = dedent( + """\ Reduce this DataArray's data by applying `std` along some dimension(s). Parameters @@ -613,44 +635,41 @@ def test_docs(): reduced : DataArray New DataArray object with `std` applied to its data and the indicated dimension(s) removed. 
- """) + """ + ) assert actual == expected def test_datetime_to_numeric_datetime64(): - times = pd.date_range('2000', periods=5, freq='7D').values - result = duck_array_ops.datetime_to_numeric(times, datetime_unit='h') + times = pd.date_range("2000", periods=5, freq="7D").values + result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h") expected = 24 * np.arange(0, 35, 7) np.testing.assert_array_equal(result, expected) offset = times[1] - result = duck_array_ops.datetime_to_numeric( - times, offset=offset, datetime_unit='h') + result = duck_array_ops.datetime_to_numeric(times, offset=offset, datetime_unit="h") expected = 24 * np.arange(-7, 28, 7) np.testing.assert_array_equal(result, expected) dtype = np.float32 - result = duck_array_ops.datetime_to_numeric( - times, datetime_unit='h', dtype=dtype) + result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h", dtype=dtype) expected = 24 * np.arange(0, 35, 7).astype(dtype) np.testing.assert_array_equal(result, expected) @requires_cftime def test_datetime_to_numeric_cftime(): - times = cftime_range('2000', periods=5, freq='7D').values - result = duck_array_ops.datetime_to_numeric(times, datetime_unit='h') + times = cftime_range("2000", periods=5, freq="7D").values + result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h") expected = 24 * np.arange(0, 35, 7) np.testing.assert_array_equal(result, expected) offset = times[1] - result = duck_array_ops.datetime_to_numeric( - times, offset=offset, datetime_unit='h') + result = duck_array_ops.datetime_to_numeric(times, offset=offset, datetime_unit="h") expected = 24 * np.arange(-7, 28, 7) np.testing.assert_array_equal(result, expected) dtype = np.float32 - result = duck_array_ops.datetime_to_numeric( - times, datetime_unit='h', dtype=dtype) + result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h", dtype=dtype) expected = 24 * np.arange(0, 35, 7).astype(dtype) np.testing.assert_array_equal(result, expected) diff --git a/xarray/tests/test_extensions.py b/xarray/tests/test_extensions.py index e67e7a0f6a0..5af0f6d8a42 100644 --- a/xarray/tests/test_extensions.py +++ b/xarray/tests/test_extensions.py @@ -7,8 +7,8 @@ from . 
import raises_regex -@xr.register_dataset_accessor('example_accessor') -@xr.register_dataarray_accessor('example_accessor') +@xr.register_dataset_accessor("example_accessor") +@xr.register_dataarray_accessor("example_accessor") class ExampleAccessor: """For the pickling tests below.""" @@ -18,9 +18,8 @@ def __init__(self, xarray_obj): class TestAccessor: def test_register(self): - - @xr.register_dataset_accessor('demo') - @xr.register_dataarray_accessor('demo') + @xr.register_dataset_accessor("demo") + @xr.register_dataarray_accessor("demo") class DemoAccessor: """Demo accessor.""" @@ -29,13 +28,13 @@ def __init__(self, xarray_obj): @property def foo(self): - return 'bar' + return "bar" ds = xr.Dataset() - assert ds.demo.foo == 'bar' + assert ds.demo.foo == "bar" da = xr.DataArray(0) - assert da.demo.foo == 'bar' + assert da.demo.foo == "bar" # accessor is cached assert ds.demo is ds.demo @@ -48,15 +47,16 @@ def foo(self): # ensure we can remove it del xr.Dataset.demo - assert not hasattr(xr.Dataset, 'demo') + assert not hasattr(xr.Dataset, "demo") + + with pytest.warns(Warning, match="overriding a preexisting attribute"): - with pytest.warns(Warning, match='overriding a preexisting attribute'): - @xr.register_dataarray_accessor('demo') + @xr.register_dataarray_accessor("demo") class Foo: pass # it didn't get registered again - assert not hasattr(xr.Dataset, 'demo') + assert not hasattr(xr.Dataset, "demo") def test_pickle_dataset(self): ds = xr.Dataset() @@ -65,10 +65,10 @@ def test_pickle_dataset(self): # state save on the accessor is restored assert ds.example_accessor is ds.example_accessor - ds.example_accessor.value = 'foo' + ds.example_accessor.value = "foo" ds_restored = pickle.loads(pickle.dumps(ds)) assert ds.identical(ds_restored) - assert ds_restored.example_accessor.value == 'foo' + assert ds_restored.example_accessor.value == "foo" def test_pickle_dataarray(self): array = xr.Dataset() @@ -79,10 +79,10 @@ def test_pickle_dataarray(self): def test_broken_accessor(self): # regression test for GH933 - @xr.register_dataset_accessor('stupid_accessor') + @xr.register_dataset_accessor("stupid_accessor") class BrokenAccessor: def __init__(self, xarray_obj): - raise AttributeError('broken') + raise AttributeError("broken") - with raises_regex(RuntimeError, 'error initializing'): + with raises_regex(RuntimeError, "error initializing"): xr.Dataset().stupid_accessor diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index be3e368e02b..02b13fd5e0e 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -12,33 +12,35 @@ class TestFormatting: - def test_get_indexer_at_least_n_items(self): cases = [ ((20,), (slice(10),), (slice(-10, None),)), - ((3, 20,), (0, slice(10)), (-1, slice(-10, None))), - ((2, 10,), (0, slice(10)), (-1, slice(-10, None))), - ((2, 5,), (slice(2), slice(None)), - (slice(-2, None), slice(None))), - ((1, 2, 5,), (0, slice(2), slice(None)), - (-1, slice(-2, None), slice(None))), - ((2, 3, 5,), (0, slice(2), slice(None)), - (-1, slice(-2, None), slice(None))), - ((1, 10, 1,), (0, slice(10), slice(None)), - (-1, slice(-10, None), slice(None))), - ((2, 5, 1,), (slice(2), slice(None), slice(None)), - (slice(-2, None), slice(None), slice(None))), - ((2, 5, 3,), (0, slice(4), slice(None)), - (-1, slice(-4, None), slice(None))), - ((2, 3, 3,), (slice(2), slice(None), slice(None)), - (slice(-2, None), slice(None), slice(None))), + ((3, 20), (0, slice(10)), (-1, slice(-10, None))), + ((2, 10), (0, slice(10)), (-1, slice(-10, 
None))), + ((2, 5), (slice(2), slice(None)), (slice(-2, None), slice(None))), + ((1, 2, 5), (0, slice(2), slice(None)), (-1, slice(-2, None), slice(None))), + ((2, 3, 5), (0, slice(2), slice(None)), (-1, slice(-2, None), slice(None))), + ( + (1, 10, 1), + (0, slice(10), slice(None)), + (-1, slice(-10, None), slice(None)), + ), + ( + (2, 5, 1), + (slice(2), slice(None), slice(None)), + (slice(-2, None), slice(None), slice(None)), + ), + ((2, 5, 3), (0, slice(4), slice(None)), (-1, slice(-4, None), slice(None))), + ( + (2, 3, 3), + (slice(2), slice(None), slice(None)), + (slice(-2, None), slice(None), slice(None)), + ), ] for shape, start_expected, end_expected in cases: - actual = formatting._get_indexer_at_least_n_items(shape, 10, - from_end=False) + actual = formatting._get_indexer_at_least_n_items(shape, 10, from_end=False) assert start_expected == actual - actual = formatting._get_indexer_at_least_n_items(shape, 10, - from_end=True) + actual = formatting._get_indexer_at_least_n_items(shape, 10, from_end=True) assert end_expected == actual def test_first_n_items(self): @@ -48,7 +50,7 @@ def test_first_n_items(self): expected = array.flat[:n] assert (expected == actual).all() - with raises_regex(ValueError, 'at least one item'): + with raises_regex(ValueError, "at least one item"): formatting.first_n_items(array, 0) def test_last_n_items(self): @@ -58,7 +60,7 @@ def test_last_n_items(self): expected = array.flat[-n:] assert (expected == actual).all() - with raises_regex(ValueError, 'at least one item'): + with raises_regex(ValueError, "at least one item"): formatting.first_n_items(array, 0) def test_last_item(self): @@ -73,17 +75,17 @@ def test_last_item(self): def test_format_item(self): cases = [ - (pd.Timestamp('2000-01-01T12'), '2000-01-01T12:00:00'), - (pd.Timestamp('2000-01-01'), '2000-01-01'), - (pd.Timestamp('NaT'), 'NaT'), - (pd.Timedelta('10 days 1 hour'), '10 days 01:00:00'), - (pd.Timedelta('-3 days'), '-3 days +00:00:00'), - (pd.Timedelta('3 hours'), '0 days 03:00:00'), - (pd.Timedelta('NaT'), 'NaT'), - ('foo', "'foo'"), - (b'foo', "b'foo'"), - (1, '1'), - (1.0, '1.0'), + (pd.Timestamp("2000-01-01T12"), "2000-01-01T12:00:00"), + (pd.Timestamp("2000-01-01"), "2000-01-01"), + (pd.Timestamp("NaT"), "NaT"), + (pd.Timedelta("10 days 1 hour"), "10 days 01:00:00"), + (pd.Timedelta("-3 days"), "-3 days +00:00:00"), + (pd.Timedelta("3 hours"), "0 days 03:00:00"), + (pd.Timedelta("NaT"), "NaT"), + ("foo", "'foo'"), + (b"foo", "b'foo'"), + (1, "1"), + (1.0, "1.0"), ] for item, expected in cases: actual = formatting.format_item(item) @@ -91,122 +93,134 @@ def test_format_item(self): def test_format_items(self): cases = [ - (np.arange(4) * np.timedelta64(1, 'D'), - '0 days 1 days 2 days 3 days'), - (np.arange(4) * np.timedelta64(3, 'h'), - '00:00:00 03:00:00 06:00:00 09:00:00'), - (np.arange(4) * np.timedelta64(500, 'ms'), - '00:00:00 00:00:00.500000 00:00:01 00:00:01.500000'), - (pd.to_timedelta(['NaT', '0s', '1s', 'NaT']), - 'NaT 00:00:00 00:00:01 NaT'), - (pd.to_timedelta(['1 day 1 hour', '1 day', '0 hours']), - '1 days 01:00:00 1 days 00:00:00 0 days 00:00:00'), - ([1, 2, 3], '1 2 3'), + (np.arange(4) * np.timedelta64(1, "D"), "0 days 1 days 2 days 3 days"), + ( + np.arange(4) * np.timedelta64(3, "h"), + "00:00:00 03:00:00 06:00:00 09:00:00", + ), + ( + np.arange(4) * np.timedelta64(500, "ms"), + "00:00:00 00:00:00.500000 00:00:01 00:00:01.500000", + ), + (pd.to_timedelta(["NaT", "0s", "1s", "NaT"]), "NaT 00:00:00 00:00:01 NaT"), + ( + pd.to_timedelta(["1 day 1 hour", "1 day", "0 
hours"]), + "1 days 01:00:00 1 days 00:00:00 0 days 00:00:00", + ), + ([1, 2, 3], "1 2 3"), ] for item, expected in cases: - actual = ' '.join(formatting.format_items(item)) + actual = " ".join(formatting.format_items(item)) assert expected == actual def test_format_array_flat(self): actual = formatting.format_array_flat(np.arange(100), 2) - expected = '0 ... 99' + expected = "0 ... 99" assert expected == actual actual = formatting.format_array_flat(np.arange(100), 9) - expected = '0 ... 99' + expected = "0 ... 99" assert expected == actual actual = formatting.format_array_flat(np.arange(100), 10) - expected = '0 1 ... 99' + expected = "0 1 ... 99" assert expected == actual actual = formatting.format_array_flat(np.arange(100), 13) - expected = '0 1 ... 98 99' + expected = "0 1 ... 98 99" assert expected == actual actual = formatting.format_array_flat(np.arange(100), 15) - expected = '0 1 2 ... 98 99' + expected = "0 1 2 ... 98 99" assert expected == actual actual = formatting.format_array_flat(np.arange(100.0), 11) - expected = '0.0 ... 99.0' + expected = "0.0 ... 99.0" assert expected == actual actual = formatting.format_array_flat(np.arange(100.0), 1) - expected = '0.0 ... 99.0' + expected = "0.0 ... 99.0" assert expected == actual actual = formatting.format_array_flat(np.arange(3), 5) - expected = '0 1 2' + expected = "0 1 2" assert expected == actual actual = formatting.format_array_flat(np.arange(4.0), 11) - expected = '0.0 ... 3.0' + expected = "0.0 ... 3.0" assert expected == actual actual = formatting.format_array_flat(np.arange(0), 0) - expected = '' + expected = "" assert expected == actual actual = formatting.format_array_flat(np.arange(1), 0) - expected = '0' + expected = "0" assert expected == actual actual = formatting.format_array_flat(np.arange(2), 0) - expected = '0 1' + expected = "0 1" assert expected == actual actual = formatting.format_array_flat(np.arange(4), 0) - expected = '0 ... 3' + expected = "0 ... 3" assert expected == actual def test_pretty_print(self): - assert formatting.pretty_print('abcdefghij', 8) == 'abcde...' - assert formatting.pretty_print('ß', 1) == 'ß' + assert formatting.pretty_print("abcdefghij", 8) == "abcde..." 
+ assert formatting.pretty_print("ß", 1) == "ß" def test_maybe_truncate(self): - assert formatting.maybe_truncate('ß', 10) == 'ß' + assert formatting.maybe_truncate("ß", 10) == "ß" def test_format_timestamp_out_of_bounds(self): from datetime import datetime + date = datetime(1300, 12, 1) - expected = '1300-12-01' + expected = "1300-12-01" result = formatting.format_timestamp(date) assert result == expected date = datetime(2300, 12, 1) - expected = '2300-12-01' + expected = "2300-12-01" result = formatting.format_timestamp(date) assert result == expected def test_attribute_repr(self): - short = formatting.summarize_attr('key', 'Short string') - long = formatting.summarize_attr('key', 100 * 'Very long string ') - newlines = formatting.summarize_attr('key', '\n\n\n') - tabs = formatting.summarize_attr('key', '\t\t\t') - assert short == ' key: Short string' + short = formatting.summarize_attr("key", "Short string") + long = formatting.summarize_attr("key", 100 * "Very long string ") + newlines = formatting.summarize_attr("key", "\n\n\n") + tabs = formatting.summarize_attr("key", "\t\t\t") + assert short == " key: Short string" assert len(long) <= 80 - assert long.endswith('...') - assert '\n' not in newlines - assert '\t' not in tabs + assert long.endswith("...") + assert "\n" not in newlines + assert "\t" not in tabs def test_diff_array_repr(self): da_a = xr.DataArray( - np.array([[1, 2, 3], [4, 5, 6]], dtype='int64'), - dims=('x', 'y'), - coords={'x': np.array(['a', 'b'], dtype='U1'), - 'y': np.array([1, 2, 3], dtype='int64')}, - attrs={'units': 'm', 'description': 'desc'}) + np.array([[1, 2, 3], [4, 5, 6]], dtype="int64"), + dims=("x", "y"), + coords={ + "x": np.array(["a", "b"], dtype="U1"), + "y": np.array([1, 2, 3], dtype="int64"), + }, + attrs={"units": "m", "description": "desc"}, + ) da_b = xr.DataArray( - np.array([1, 2], dtype='int64'), - dims='x', - coords={'x': np.array(['a', 'c'], dtype='U1'), - 'label': ('x', np.array([1, 2], dtype='int64'))}, - attrs={'units': 'kg'}) - - byteorder = '<' if sys.byteorder == 'little' else '>' - expected = dedent("""\ + np.array([1, 2], dtype="int64"), + dims="x", + coords={ + "x": np.array(["a", "c"], dtype="U1"), + "label": ("x", np.array([1, 2], dtype="int64")), + }, + attrs={"units": "kg"}, + ) + + byteorder = "<" if sys.byteorder == "little" else ">" + expected = dedent( + """\ Left and right DataArray objects are not identical Differing dimensions: (x: 2, y: 3) != (x: 2) @@ -227,21 +241,24 @@ def test_diff_array_repr(self): L units: m R units: kg Attributes only on the left object: - description: desc""" % (byteorder, byteorder)) + description: desc""" + % (byteorder, byteorder) + ) - actual = formatting.diff_array_repr(da_a, da_b, 'identical') + actual = formatting.diff_array_repr(da_a, da_b, "identical") try: assert actual == expected except AssertionError: # depending on platform, dtype may not be shown in numpy array repr assert actual == expected.replace(", dtype=int64", "") - va = xr.Variable('x', np.array([1, 2, 3], dtype='int64'), - {'title': 'test Variable'}) - vb = xr.Variable(('x', 'y'), - np.array([[1, 2, 3], [4, 5, 6]], dtype='int64')) + va = xr.Variable( + "x", np.array([1, 2, 3], dtype="int64"), {"title": "test Variable"} + ) + vb = xr.Variable(("x", "y"), np.array([[1, 2, 3], [4, 5, 6]], dtype="int64")) - expected = dedent("""\ + expected = dedent( + """\ Left and right Variable objects are not equal Differing dimensions: (x: 3) != (x: 2, y: 3) @@ -250,9 +267,10 @@ def test_diff_array_repr(self): array([1, 2, 3], dtype=int64) 
R array([[1, 2, 3], - [4, 5, 6]], dtype=int64)""") + [4, 5, 6]], dtype=int64)""" + ) - actual = formatting.diff_array_repr(va, vb, 'equals') + actual = formatting.diff_array_repr(va, vb, "equals") try: assert actual == expected except AssertionError: @@ -261,26 +279,28 @@ def test_diff_array_repr(self): def test_diff_dataset_repr(self): ds_a = xr.Dataset( data_vars={ - 'var1': (('x', 'y'), - np.array([[1, 2, 3], [4, 5, 6]], dtype='int64')), - 'var2': ('x', np.array([3, 4], dtype='int64')) + "var1": (("x", "y"), np.array([[1, 2, 3], [4, 5, 6]], dtype="int64")), + "var2": ("x", np.array([3, 4], dtype="int64")), }, - coords={'x': np.array(['a', 'b'], dtype='U1'), - 'y': np.array([1, 2, 3], dtype='int64')}, - attrs={'units': 'm', 'description': 'desc'} + coords={ + "x": np.array(["a", "b"], dtype="U1"), + "y": np.array([1, 2, 3], dtype="int64"), + }, + attrs={"units": "m", "description": "desc"}, ) ds_b = xr.Dataset( - data_vars={'var1': ('x', np.array([1, 2], dtype='int64'))}, + data_vars={"var1": ("x", np.array([1, 2], dtype="int64"))}, coords={ - 'x': ('x', np.array(['a', 'c'], dtype='U1'), {'source': 0}), - 'label': ('x', np.array([1, 2], dtype='int64')) + "x": ("x", np.array(["a", "c"], dtype="U1"), {"source": 0}), + "label": ("x", np.array([1, 2], dtype="int64")), }, - attrs={'units': 'kg'} + attrs={"units": "kg"}, ) - byteorder = '<' if sys.byteorder == 'little' else '>' - expected = dedent("""\ + byteorder = "<" if sys.byteorder == "little" else ">" + expected = dedent( + """\ Left and right Dataset objects are not identical Differing dimensions: (x: 2, y: 3) != (x: 2) @@ -301,19 +321,23 @@ def test_diff_dataset_repr(self): L units: m R units: kg Attributes only on the left object: - description: desc""" % (byteorder, byteorder)) + description: desc""" + % (byteorder, byteorder) + ) - actual = formatting.diff_dataset_repr(ds_a, ds_b, 'identical') + actual = formatting.diff_dataset_repr(ds_a, ds_b, "identical") assert actual == expected def test_array_repr(self): - ds = xr.Dataset(coords={'foo': [1, 2, 3], 'bar': [1, 2, 3]}) - ds[(1, 2)] = xr.DataArray([0], dims='test') + ds = xr.Dataset(coords={"foo": [1, 2, 3], "bar": [1, 2, 3]}) + ds[(1, 2)] = xr.DataArray([0], dims="test") actual = formatting.array_repr(ds[(1, 2)]) - expected = dedent("""\ + expected = dedent( + """\ array([0]) - Dimensions without coordinates: test""") + Dimensions without coordinates: test""" + ) assert actual == expected @@ -337,5 +361,5 @@ def test_short_array_repr(): # for default numpy repr: 167, 140, 254, 248 # for short_array_repr: 1, 7, 24, 19 for array in cases: - num_lines = formatting.short_array_repr(array).count('\n') + 1 + num_lines = formatting.short_array_repr(array).count("\n") + 1 assert num_lines < 30 diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 0f1adf3d45f..c1adf8b4fa4 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -12,8 +12,7 @@ def test_consolidate_slices(): assert _consolidate_slices([slice(3), slice(3, 5)]) == [slice(5)] assert _consolidate_slices([slice(2, 3), slice(3, 6)]) == [slice(2, 6)] - assert (_consolidate_slices([slice(2, 3, 1), slice(3, 6, 1)]) == - [slice(2, 6, 1)]) + assert _consolidate_slices([slice(2, 3, 1), slice(3, 6, 1)]) == [slice(2, 6, 1)] slices = [slice(2, 3), slice(5, 6)] assert _consolidate_slices(slices) == slices @@ -24,152 +23,169 @@ def test_consolidate_slices(): def test_multi_index_groupby_apply(): # regression test for GH873 - ds = xr.Dataset({'foo': (('x', 'y'), np.random.randn(3, 4))}, - {'x': 
['a', 'b', 'c'], 'y': [1, 2, 3, 4]}) + ds = xr.Dataset( + {"foo": (("x", "y"), np.random.randn(3, 4))}, + {"x": ["a", "b", "c"], "y": [1, 2, 3, 4]}, + ) doubled = 2 * ds - group_doubled = (ds.stack(space=['x', 'y']) - .groupby('space') - .apply(lambda x: 2 * x) - .unstack('space')) + group_doubled = ( + ds.stack(space=["x", "y"]) + .groupby("space") + .apply(lambda x: 2 * x) + .unstack("space") + ) assert doubled.equals(group_doubled) def test_multi_index_groupby_sum(): # regression test for GH873 - ds = xr.Dataset({'foo': (('x', 'y', 'z'), np.ones((3, 4, 2)))}, - {'x': ['a', 'b', 'c'], 'y': [1, 2, 3, 4]}) - expected = ds.sum('z') - actual = (ds.stack(space=['x', 'y']) - .groupby('space') - .sum('z') - .unstack('space')) + ds = xr.Dataset( + {"foo": (("x", "y", "z"), np.ones((3, 4, 2)))}, + {"x": ["a", "b", "c"], "y": [1, 2, 3, 4]}, + ) + expected = ds.sum("z") + actual = ds.stack(space=["x", "y"]).groupby("space").sum("z").unstack("space") assert expected.equals(actual) def test_groupby_da_datetime(): # test groupby with a DataArray of dtype datetime for GH1132 # create test data - times = pd.date_range('2000-01-01', periods=4) - foo = xr.DataArray([1, 2, 3, 4], coords=dict(time=times), dims='time') + times = pd.date_range("2000-01-01", periods=4) + foo = xr.DataArray([1, 2, 3, 4], coords=dict(time=times), dims="time") # create test index dd = times.to_pydatetime() reference_dates = [dd[0], dd[2]] labels = reference_dates[0:1] * 2 + reference_dates[1:2] * 2 - ind = xr.DataArray(labels, coords=dict(time=times), dims='time', - name='reference_date') + ind = xr.DataArray( + labels, coords=dict(time=times), dims="time", name="reference_date" + ) g = foo.groupby(ind) - actual = g.sum(dim='time') - expected = xr.DataArray([3, 7], - coords=dict(reference_date=reference_dates), - dims='reference_date') + actual = g.sum(dim="time") + expected = xr.DataArray( + [3, 7], coords=dict(reference_date=reference_dates), dims="reference_date" + ) assert actual.equals(expected) def test_groupby_duplicate_coordinate_labels(): # fix for http://stackoverflow.com/questions/38065129 - array = xr.DataArray([1, 2, 3], [('x', [1, 1, 2])]) - expected = xr.DataArray([3, 3], [('x', [1, 2])]) - actual = array.groupby('x').sum() + array = xr.DataArray([1, 2, 3], [("x", [1, 1, 2])]) + expected = xr.DataArray([3, 3], [("x", [1, 2])]) + actual = array.groupby("x").sum() assert expected.equals(actual) def test_groupby_input_mutation(): # regression test for GH2153 - array = xr.DataArray([1, 2, 3], [('x', [2, 2, 1])]) + array = xr.DataArray([1, 2, 3], [("x", [2, 2, 1])]) array_copy = array.copy() - expected = xr.DataArray([3, 3], [('x', [1, 2])]) - actual = array.groupby('x').sum() + expected = xr.DataArray([3, 3], [("x", [1, 2])]) + actual = array.groupby("x").sum() assert_identical(expected, actual) assert_identical(array, array_copy) # should not modify inputs def test_da_groupby_apply_func_args(): - def func(arg1, arg2, arg3=0): return arg1 + arg2 + arg3 - array = xr.DataArray([1, 1, 1], [('x', [1, 2, 3])]) - expected = xr.DataArray([3, 3, 3], [('x', [1, 2, 3])]) - actual = array.groupby('x').apply(func, args=(1,), arg3=1) + array = xr.DataArray([1, 1, 1], [("x", [1, 2, 3])]) + expected = xr.DataArray([3, 3, 3], [("x", [1, 2, 3])]) + actual = array.groupby("x").apply(func, args=(1,), arg3=1) assert_identical(expected, actual) def test_ds_groupby_apply_func_args(): - def func(arg1, arg2, arg3=0): return arg1 + arg2 + arg3 - dataset = xr.Dataset({'foo': ('x', [1, 1, 1])}, {'x': [1, 2, 3]}) - expected = xr.Dataset({'foo': 
('x', [3, 3, 3])}, {'x': [1, 2, 3]}) - actual = dataset.groupby('x').apply(func, args=(1,), arg3=1) + dataset = xr.Dataset({"foo": ("x", [1, 1, 1])}, {"x": [1, 2, 3]}) + expected = xr.Dataset({"foo": ("x", [3, 3, 3])}, {"x": [1, 2, 3]}) + actual = dataset.groupby("x").apply(func, args=(1,), arg3=1) assert_identical(expected, actual) def test_da_groupby_empty(): - empty_array = xr.DataArray([], dims='dim') + empty_array = xr.DataArray([], dims="dim") with pytest.raises(ValueError): - empty_array.groupby('dim') + empty_array.groupby("dim") def test_da_groupby_quantile(): - array = xr.DataArray([1, 2, 3, 4, 5, 6], - [('x', [1, 1, 1, 2, 2, 2])]) + array = xr.DataArray([1, 2, 3, 4, 5, 6], [("x", [1, 1, 1, 2, 2, 2])]) # Scalar quantile - expected = xr.DataArray([2, 5], [('x', [1, 2])]) - actual = array.groupby('x').quantile(.5) + expected = xr.DataArray([2, 5], [("x", [1, 2])]) + actual = array.groupby("x").quantile(0.5) assert_identical(expected, actual) # Vector quantile - expected = xr.DataArray([[1, 3], [4, 6]], - [('x', [1, 2]), ('quantile', [0, 1])]) - actual = array.groupby('x').quantile([0, 1]) + expected = xr.DataArray([[1, 3], [4, 6]], [("x", [1, 2]), ("quantile", [0, 1])]) + actual = array.groupby("x").quantile([0, 1]) assert_identical(expected, actual) # Multiple dimensions - array = xr.DataArray([[1, 11, 26], [2, 12, 22], [3, 13, 23], - [4, 16, 24], [5, 15, 25]], - [('x', [1, 1, 1, 2, 2],), - ('y', [0, 0, 1])]) - - actual_x = array.groupby('x').quantile(0) - expected_x = xr.DataArray([1, 4], - [('x', [1, 2]), ]) + array = xr.DataArray( + [[1, 11, 26], [2, 12, 22], [3, 13, 23], [4, 16, 24], [5, 15, 25]], + [("x", [1, 1, 1, 2, 2]), ("y", [0, 0, 1])], + ) + + actual_x = array.groupby("x").quantile(0) + expected_x = xr.DataArray([1, 4], [("x", [1, 2])]) assert_identical(expected_x, actual_x) - actual_y = array.groupby('y').quantile(0) - expected_y = xr.DataArray([1, 22], - [('y', [0, 1]), ]) + actual_y = array.groupby("y").quantile(0) + expected_y = xr.DataArray([1, 22], [("y", [0, 1])]) assert_identical(expected_y, actual_y) - actual_xx = array.groupby('x').quantile(0, dim='x') - expected_xx = xr.DataArray([[1, 11, 22], [4, 15, 24]], - [('x', [1, 2]), ('y', [0, 0, 1])]) + actual_xx = array.groupby("x").quantile(0, dim="x") + expected_xx = xr.DataArray( + [[1, 11, 22], [4, 15, 24]], [("x", [1, 2]), ("y", [0, 0, 1])] + ) assert_identical(expected_xx, actual_xx) - actual_yy = array.groupby('y').quantile(0, dim='y') - expected_yy = xr.DataArray([[1, 26], [2, 22], [3, 23], [4, 24], [5, 25]], - [('x', [1, 1, 1, 2, 2]), ('y', [0, 1])]) + actual_yy = array.groupby("y").quantile(0, dim="y") + expected_yy = xr.DataArray( + [[1, 26], [2, 22], [3, 23], [4, 24], [5, 25]], + [("x", [1, 1, 1, 2, 2]), ("y", [0, 1])], + ) assert_identical(expected_yy, actual_yy) - times = pd.date_range('2000-01-01', periods=365) + times = pd.date_range("2000-01-01", periods=365) x = [0, 1] - foo = xr.DataArray(np.reshape(np.arange(365 * 2), (365, 2)), - coords=dict(time=times, x=x), dims=('time', 'x')) + foo = xr.DataArray( + np.reshape(np.arange(365 * 2), (365, 2)), + coords=dict(time=times, x=x), + dims=("time", "x"), + ) g = foo.groupby(foo.time.dt.month) actual = g.quantile(0) - expected = xr.DataArray([0., 62., 120., 182., 242., 304., - 364., 426., 488., 548., 610., 670.], - [('month', np.arange(1, 13))]) + expected = xr.DataArray( + [ + 0.0, + 62.0, + 120.0, + 182.0, + 242.0, + 304.0, + 364.0, + 426.0, + 488.0, + 548.0, + 610.0, + 670.0, + ], + [("month", np.arange(1, 13))], + ) assert_identical(expected, actual) 
- actual = g.quantile(0, dim='time')[:2] - expected = xr.DataArray([[0., 1], [62., 63]], - [('month', [1, 2]), ('x', [0, 1])]) + actual = g.quantile(0, dim="time")[:2] + expected = xr.DataArray([[0.0, 1], [62.0, 63]], [("month", [1, 2]), ("x", [0, 1])]) assert_identical(expected, actual) diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 64eee80d4eb..f37f8d98ca8 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -22,21 +22,29 @@ def test_expanded_indexer(self): x = np.random.randn(10, 11, 12, 13, 14) y = np.arange(5) I = ReturnItem() # noqa - for i in [I[:], I[...], I[0, :, 10], I[..., 10], I[:5, ..., 0], - I[..., 0, :], I[y], I[y, y], I[..., y, y], - I[..., 0, 1, 2, 3, 4]]: + for i in [ + I[:], + I[...], + I[0, :, 10], + I[..., 10], + I[:5, ..., 0], + I[..., 0, :], + I[y], + I[y, y], + I[..., y, y], + I[..., 0, 1, 2, 3, 4], + ]: j = indexing.expanded_indexer(i, x.ndim) assert_array_equal(x[i], x[j]) - assert_array_equal(self.set_to_zero(x, i), - self.set_to_zero(x, j)) - with raises_regex(IndexError, 'too many indices'): + assert_array_equal(self.set_to_zero(x, i), self.set_to_zero(x, j)) + with raises_regex(IndexError, "too many indices"): indexing.expanded_indexer(I[1, 2, 3], 2) def test_asarray_tuplesafe(self): - res = indexing._asarray_tuplesafe(('a', 1)) + res = indexing._asarray_tuplesafe(("a", 1)) assert isinstance(res, np.ndarray) assert res.ndim == 0 - assert res.item() == ('a', 1) + assert res.item() == ("a", 1) res = indexing._asarray_tuplesafe([(0,), (1,)]) assert res.shape == (2,) @@ -46,112 +54,131 @@ def test_asarray_tuplesafe(self): def test_stacked_multiindex_min_max(self): data = np.random.randn(3, 23, 4) da = DataArray( - data, name="value", + data, + name="value", dims=["replicate", "rsample", "exp"], coords=dict( - replicate=[0, 1, 2], - exp=["a", "b", "c", "d"], - rsample=list(range(23)) + replicate=[0, 1, 2], exp=["a", "b", "c", "d"], rsample=list(range(23)) ), ) da2 = da.stack(sample=("replicate", "rsample")) s = da2.sample - assert_array_equal(da2.loc['a', s.max()], data[2, 22, 0]) - assert_array_equal(da2.loc['b', s.min()], data[0, 0, 1]) + assert_array_equal(da2.loc["a", s.max()], data[2, 22, 0]) + assert_array_equal(da2.loc["b", s.min()], data[0, 0, 1]) def test_convert_label_indexer(self): # TODO: add tests that aren't just for edge cases index = pd.Index([1, 2, 3]) - with raises_regex(KeyError, 'not all values found'): + with raises_regex(KeyError, "not all values found"): indexing.convert_label_indexer(index, [0]) with pytest.raises(KeyError): indexing.convert_label_indexer(index, 0) - with raises_regex(ValueError, 'does not have a MultiIndex'): - indexing.convert_label_indexer(index, {'one': 0}) + with raises_regex(ValueError, "does not have a MultiIndex"): + indexing.convert_label_indexer(index, {"one": 0}) - mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2]], - names=('one', 'two')) - with raises_regex(KeyError, 'not all values found'): + mindex = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) + with raises_regex(KeyError, "not all values found"): indexing.convert_label_indexer(mindex, [0]) with pytest.raises(KeyError): indexing.convert_label_indexer(mindex, 0) with pytest.raises(ValueError): - indexing.convert_label_indexer(index, {'three': 0}) + indexing.convert_label_indexer(index, {"three": 0}) with pytest.raises((KeyError, IndexError)): # pandas 0.21 changed this from KeyError to IndexError - indexing.convert_label_indexer( - mindex, (slice(None), 1, 
'no_level')) + indexing.convert_label_indexer(mindex, (slice(None), 1, "no_level")) def test_convert_unsorted_datetime_index_raises(self): - index = pd.to_datetime(['2001', '2000', '2002']) + index = pd.to_datetime(["2001", "2000", "2002"]) with pytest.raises(KeyError): # pandas will try to convert this into an array indexer. We should # raise instead, so we can be sure the result of indexing with a # slice is always a view. - indexing.convert_label_indexer(index, slice('2001', '2002')) + indexing.convert_label_indexer(index, slice("2001", "2002")) def test_get_dim_indexers(self): - mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2]], - names=('one', 'two')) - mdata = DataArray(range(4), [('x', mindex)]) + mindex = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) + mdata = DataArray(range(4), [("x", mindex)]) - dim_indexers = indexing.get_dim_indexers(mdata, {'one': 'a', 'two': 1}) - assert dim_indexers == {'x': {'one': 'a', 'two': 1}} + dim_indexers = indexing.get_dim_indexers(mdata, {"one": "a", "two": 1}) + assert dim_indexers == {"x": {"one": "a", "two": 1}} - with raises_regex(ValueError, 'cannot combine'): - indexing.get_dim_indexers(mdata, {'x': 'a', 'two': 1}) + with raises_regex(ValueError, "cannot combine"): + indexing.get_dim_indexers(mdata, {"x": "a", "two": 1}) - with raises_regex(ValueError, 'do not exist'): - indexing.get_dim_indexers(mdata, {'y': 'a'}) + with raises_regex(ValueError, "do not exist"): + indexing.get_dim_indexers(mdata, {"y": "a"}) - with raises_regex(ValueError, 'do not exist'): - indexing.get_dim_indexers(mdata, {'four': 1}) + with raises_regex(ValueError, "do not exist"): + indexing.get_dim_indexers(mdata, {"four": 1}) def test_remap_label_indexers(self): def test_indexer(data, x, expected_pos, expected_idx=None): - pos, idx = indexing.remap_label_indexers(data, {'x': x}) - assert_array_equal(pos.get('x'), expected_pos) - assert_array_equal(idx.get('x'), expected_idx) + pos, idx = indexing.remap_label_indexers(data, {"x": x}) + assert_array_equal(pos.get("x"), expected_pos) + assert_array_equal(idx.get("x"), expected_idx) - data = Dataset({'x': ('x', [1, 2, 3])}) - mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2], [-1, -2]], - names=('one', 'two', 'three')) - mdata = DataArray(range(8), [('x', mindex)]) + data = Dataset({"x": ("x", [1, 2, 3])}) + mindex = pd.MultiIndex.from_product( + [["a", "b"], [1, 2], [-1, -2]], names=("one", "two", "three") + ) + mdata = DataArray(range(8), [("x", mindex)]) test_indexer(data, 1, 0) test_indexer(data, np.int32(1), 0) test_indexer(data, Variable([], 1), 0) - test_indexer(mdata, ('a', 1, -1), 0) - test_indexer(mdata, ('a', 1), - [True, True, False, False, False, False, False, False], - [-1, -2]) - test_indexer(mdata, 'a', slice(0, 4, None), - pd.MultiIndex.from_product([[1, 2], [-1, -2]])) - test_indexer(mdata, ('a',), - [True, True, True, True, False, False, False, False], - pd.MultiIndex.from_product([[1, 2], [-1, -2]])) - test_indexer(mdata, [('a', 1, -1), ('b', 2, -2)], [0, 7]) - test_indexer(mdata, slice('a', 'b'), slice(0, 8, None)) - test_indexer(mdata, slice(('a', 1), ('b', 1)), slice(0, 6, None)) - test_indexer(mdata, {'one': 'a', 'two': 1, 'three': -1}, 0) - test_indexer(mdata, {'one': 'a', 'two': 1}, - [True, True, False, False, False, False, False, False], - [-1, -2]) - test_indexer(mdata, {'one': 'a', 'three': -1}, - [True, False, True, False, False, False, False, False], - [1, 2]) - test_indexer(mdata, {'one': 'a'}, - [True, True, True, True, False, False, False, False], - 
pd.MultiIndex.from_product([[1, 2], [-1, -2]])) + test_indexer(mdata, ("a", 1, -1), 0) + test_indexer( + mdata, + ("a", 1), + [True, True, False, False, False, False, False, False], + [-1, -2], + ) + test_indexer( + mdata, + "a", + slice(0, 4, None), + pd.MultiIndex.from_product([[1, 2], [-1, -2]]), + ) + test_indexer( + mdata, + ("a",), + [True, True, True, True, False, False, False, False], + pd.MultiIndex.from_product([[1, 2], [-1, -2]]), + ) + test_indexer(mdata, [("a", 1, -1), ("b", 2, -2)], [0, 7]) + test_indexer(mdata, slice("a", "b"), slice(0, 8, None)) + test_indexer(mdata, slice(("a", 1), ("b", 1)), slice(0, 6, None)) + test_indexer(mdata, {"one": "a", "two": 1, "three": -1}, 0) + test_indexer( + mdata, + {"one": "a", "two": 1}, + [True, True, False, False, False, False, False, False], + [-1, -2], + ) + test_indexer( + mdata, + {"one": "a", "three": -1}, + [True, False, True, False, False, False, False, False], + [1, 2], + ) + test_indexer( + mdata, + {"one": "a"}, + [True, True, True, True, False, False, False, False], + pd.MultiIndex.from_product([[1, 2], [-1, -2]]), + ) def test_read_only_view(self): from collections import OrderedDict - arr = DataArray(np.random.rand(3, 3), - coords={'x': np.arange(3), 'y': np.arange(3)}, - dims=('x', 'y')) # Create a 2D DataArray - arr = arr.expand_dims(OrderedDict([('z', 3)]), -1) # New dimension 'z' - arr['z'] = np.arange(3) # New coords to dimension 'z' - with pytest.raises(ValueError, match='Do you want to .copy()'): + + arr = DataArray( + np.random.rand(3, 3), + coords={"x": np.arange(3), "y": np.arange(3)}, + dims=("x", "y"), + ) # Create a 2D DataArray + arr = arr.expand_dims(OrderedDict([("z", 3)]), -1) # New dimension 'z' + arr["z"] = np.arange(3) # New coords to dimension 'z' + with pytest.raises(ValueError, match="Do you want to .copy()"): arr.loc[0, 0, 0] = 999 @@ -161,9 +188,24 @@ def test_slice_slice(self): for size in [100, 99]: # We test even/odd size cases x = np.arange(size) - slices = [I[:3], I[:4], I[2:4], I[:1], I[:-1], I[5:-1], I[-5:-1], - I[::-1], I[5::-1], I[:3:-1], I[:30:-1], I[10:4:], I[::4], - I[4:4:4], I[:4:-4], I[::-2]] + slices = [ + I[:3], + I[:4], + I[2:4], + I[:1], + I[:-1], + I[5:-1], + I[-5:-1], + I[::-1], + I[5::-1], + I[:3:-1], + I[:30:-1], + I[10:4:], + I[::4], + I[4:4:4], + I[:4:-4], + I[::-2], + ] for i in slices: for j in slices: expected = x[i][j] @@ -174,40 +216,49 @@ def test_slice_slice(self): def test_lazily_indexed_array(self): original = np.random.rand(10, 20, 30) x = indexing.NumpyIndexingAdapter(original) - v = Variable(['i', 'j', 'k'], original) + v = Variable(["i", "j", "k"], original) lazy = indexing.LazilyOuterIndexedArray(x) - v_lazy = Variable(['i', 'j', 'k'], lazy) + v_lazy = Variable(["i", "j", "k"], lazy) I = ReturnItem() # noqa: E741 # allow ambiguous name # test orthogonally applied indexers indexers = [I[:], 0, -2, I[:3], [0, 1, 2, 3], [0], np.arange(10) < 5] for i in indexers: for j in indexers: for k in indexers: - if isinstance(j, np.ndarray) and j.dtype.kind == 'b': + if isinstance(j, np.ndarray) and j.dtype.kind == "b": j = np.arange(20) < 5 - if isinstance(k, np.ndarray) and k.dtype.kind == 'b': + if isinstance(k, np.ndarray) and k.dtype.kind == "b": k = np.arange(30) < 5 expected = np.asarray(v[i, j, k]) - for actual in [v_lazy[i, j, k], - v_lazy[:, j, k][i], - v_lazy[:, :, k][:, j][i]]: + for actual in [ + v_lazy[i, j, k], + v_lazy[:, j, k][i], + v_lazy[:, :, k][:, j][i], + ]: assert expected.shape == actual.shape assert_array_equal(expected, actual) - assert 
isinstance(actual._data, - indexing.LazilyOuterIndexedArray) + assert isinstance( + actual._data, indexing.LazilyOuterIndexedArray + ) # make sure actual.key is appropriate type - if all(isinstance(k, (int, slice, )) - for k in v_lazy._data.key.tuple): - assert isinstance(v_lazy._data.key, - indexing.BasicIndexer) + if all( + isinstance(k, (int, slice)) for k in v_lazy._data.key.tuple + ): + assert isinstance(v_lazy._data.key, indexing.BasicIndexer) else: - assert isinstance(v_lazy._data.key, - indexing.OuterIndexer) + assert isinstance(v_lazy._data.key, indexing.OuterIndexer) # test sequentially applied indexers - indexers = [(3, 2), (I[:], 0), (I[:2], -1), (I[:4], [0]), ([4, 5], 0), - ([0, 1, 2], [0, 1]), ([0, 3, 5], I[:2])] + indexers = [ + (3, 2), + (I[:], 0), + (I[:2], -1), + (I[:4], [0]), + ([4, 5], 0), + ([0, 1, 2], [0, 1]), + ([0, 3, 5], I[:2]), + ] for i, j in indexers: expected = v[i][j] actual = v_lazy[i][j] @@ -221,19 +272,22 @@ def test_lazily_indexed_array(self): transposed = actual.transpose(*order) assert_array_equal(expected.transpose(*order), transposed) assert isinstance( - actual._data, (indexing.LazilyVectorizedIndexedArray, - indexing.LazilyOuterIndexedArray)) + actual._data, + ( + indexing.LazilyVectorizedIndexedArray, + indexing.LazilyOuterIndexedArray, + ), + ) assert isinstance(actual._data, indexing.LazilyOuterIndexedArray) - assert isinstance(actual._data.array, - indexing.NumpyIndexingAdapter) + assert isinstance(actual._data.array, indexing.NumpyIndexingAdapter) def test_vectorized_lazily_indexed_array(self): original = np.random.rand(10, 20, 30) x = indexing.NumpyIndexingAdapter(original) - v_eager = Variable(['i', 'j', 'k'], x) + v_eager = Variable(["i", "j", "k"], x) lazy = indexing.LazilyOuterIndexedArray(x) - v_lazy = Variable(['i', 'j', 'k'], lazy) + v_lazy = Variable(["i", "j", "k"], lazy) I = ReturnItem() # noqa: E741 # allow ambiguous name def check_indexing(v_eager, v_lazy, indexers): @@ -241,32 +295,39 @@ def check_indexing(v_eager, v_lazy, indexers): actual = v_lazy[indexer] expected = v_eager[indexer] assert expected.shape == actual.shape - assert isinstance(actual._data, - (indexing.LazilyVectorizedIndexedArray, - indexing.LazilyOuterIndexedArray)) + assert isinstance( + actual._data, + ( + indexing.LazilyVectorizedIndexedArray, + indexing.LazilyOuterIndexedArray, + ), + ) assert_array_equal(expected, actual) v_eager = expected v_lazy = actual # test orthogonal indexing - indexers = [(I[:], 0, 1), (Variable('i', [0, 1]), )] + indexers = [(I[:], 0, 1), (Variable("i", [0, 1]),)] check_indexing(v_eager, v_lazy, indexers) # vectorized indexing indexers = [ - (Variable('i', [0, 1]), Variable('i', [0, 1]), slice(None)), - (slice(1, 3, 2), 0)] + (Variable("i", [0, 1]), Variable("i", [0, 1]), slice(None)), + (slice(1, 3, 2), 0), + ] check_indexing(v_eager, v_lazy, indexers) indexers = [ (slice(None, None, 2), 0, slice(None, 10)), - (Variable('i', [3, 2, 4, 3]), Variable('i', [3, 2, 1, 0])), - (Variable(['i', 'j'], [[0, 1], [1, 2]]), )] + (Variable("i", [3, 2, 4, 3]), Variable("i", [3, 2, 1, 0])), + (Variable(["i", "j"], [[0, 1], [1, 2]]),), + ] check_indexing(v_eager, v_lazy, indexers) indexers = [ - (Variable('i', [3, 2, 4, 3]), Variable('i', [3, 2, 1, 0])), - (Variable(['i', 'j'], [[0, 1], [1, 2]]), )] + (Variable("i", [3, 2, 4, 3]), Variable("i", [3, 2, 1, 0])), + (Variable(["i", "j"], [[0, 1], [1, 2]]),), + ] check_indexing(v_eager, v_lazy, indexers) @@ -290,8 +351,8 @@ def test_sub_array(self): def test_index_scalar(self): # regression test for 
GH1374 - x = indexing.CopyOnWriteArray(np.array(['foo', 'bar'])) - assert np.array(x[B[0]][B[()]]) == 'foo' + x = indexing.CopyOnWriteArray(np.array(["foo", "bar"])) + assert np.array(x[B[0]][B[()]]) == "foo" class TestMemoryCachedArray: @@ -318,8 +379,8 @@ def test_setitem(self): def test_index_scalar(self): # regression test for GH1374 - x = indexing.MemoryCachedArray(np.array(['foo', 'bar'])) - assert np.array(x[B[0]][B[()]]) == 'foo' + x = indexing.MemoryCachedArray(np.array(["foo", "bar"])) + assert np.array(x[B[0]][B[()]]) == "foo" def test_base_explicit_indexer(): @@ -331,12 +392,13 @@ class Subclass(indexing.ExplicitIndexer): value = Subclass((1, 2, 3)) assert value.tuple == (1, 2, 3) - assert repr(value) == 'Subclass((1, 2, 3))' + assert repr(value) == "Subclass((1, 2, 3))" -@pytest.mark.parametrize('indexer_cls', [indexing.BasicIndexer, - indexing.OuterIndexer, - indexing.VectorizedIndexer]) +@pytest.mark.parametrize( + "indexer_cls", + [indexing.BasicIndexer, indexing.OuterIndexer, indexing.VectorizedIndexer], +) def test_invalid_for_all(indexer_cls): with pytest.raises(TypeError): indexer_cls(None) @@ -345,17 +407,17 @@ def test_invalid_for_all(indexer_cls): with pytest.raises(TypeError): indexer_cls((None,)) with pytest.raises(TypeError): - indexer_cls(('foo',)) + indexer_cls(("foo",)) with pytest.raises(TypeError): indexer_cls((1.0,)) with pytest.raises(TypeError): - indexer_cls((slice('foo'),)) + indexer_cls((slice("foo"),)) with pytest.raises(TypeError): - indexer_cls((np.array(['foo']),)) + indexer_cls((np.array(["foo"]),)) def check_integer(indexer_cls): - value = indexer_cls((1, np.uint64(2),)).tuple + value = indexer_cls((1, np.uint64(2))).tuple assert all(isinstance(v, int) for v in value) assert value == (1, 2) @@ -402,97 +464,104 @@ def test_vectorized_indexer(): check_slice(indexing.VectorizedIndexer) check_array1d(indexing.VectorizedIndexer) check_array2d(indexing.VectorizedIndexer) - with raises_regex(ValueError, 'numbers of dimensions'): - indexing.VectorizedIndexer((np.array(1, dtype=np.int64), - np.arange(5, dtype=np.int64))) + with raises_regex(ValueError, "numbers of dimensions"): + indexing.VectorizedIndexer( + (np.array(1, dtype=np.int64), np.arange(5, dtype=np.int64)) + ) class Test_vectorized_indexer: @pytest.fixture(autouse=True) def setup(self): self.data = indexing.NumpyIndexingAdapter(np.random.randn(10, 12, 13)) - self.indexers = [np.array([[0, 3, 2], ]), - np.array([[0, 3, 3], [4, 6, 7]]), - slice(2, -2, 2), slice(2, -2, 3), slice(None)] + self.indexers = [ + np.array([[0, 3, 2]]), + np.array([[0, 3, 3], [4, 6, 7]]), + slice(2, -2, 2), + slice(2, -2, 3), + slice(None), + ] def test_arrayize_vectorized_indexer(self): for i, j, k in itertools.product(self.indexers, repeat=3): vindex = indexing.VectorizedIndexer((i, j, k)) vindex_array = indexing._arrayize_vectorized_indexer( - vindex, self.data.shape) - np.testing.assert_array_equal( - self.data[vindex], self.data[vindex_array],) + vindex, self.data.shape + ) + np.testing.assert_array_equal(self.data[vindex], self.data[vindex_array]) actual = indexing._arrayize_vectorized_indexer( - indexing.VectorizedIndexer((slice(None),)), shape=(5,)) + indexing.VectorizedIndexer((slice(None),)), shape=(5,) + ) np.testing.assert_array_equal(actual.tuple, [np.arange(5)]) actual = indexing._arrayize_vectorized_indexer( - indexing.VectorizedIndexer((np.arange(5),) * 3), shape=(8, 10, 12)) + indexing.VectorizedIndexer((np.arange(5),) * 3), shape=(8, 10, 12) + ) expected = np.stack([np.arange(5)] * 3) 
np.testing.assert_array_equal(np.stack(actual.tuple), expected) actual = indexing._arrayize_vectorized_indexer( - indexing.VectorizedIndexer((np.arange(5), slice(None))), - shape=(8, 10)) + indexing.VectorizedIndexer((np.arange(5), slice(None))), shape=(8, 10) + ) a, b = actual.tuple np.testing.assert_array_equal(a, np.arange(5)[:, np.newaxis]) np.testing.assert_array_equal(b, np.arange(10)[np.newaxis, :]) actual = indexing._arrayize_vectorized_indexer( - indexing.VectorizedIndexer((slice(None), np.arange(5))), - shape=(8, 10)) + indexing.VectorizedIndexer((slice(None), np.arange(5))), shape=(8, 10) + ) a, b = actual.tuple np.testing.assert_array_equal(a, np.arange(8)[np.newaxis, :]) np.testing.assert_array_equal(b, np.arange(5)[:, np.newaxis]) def get_indexers(shape, mode): - if mode == 'vectorized': + if mode == "vectorized": indexed_shape = (3, 4) - indexer = tuple(np.random.randint(0, s, size=indexed_shape) - for s in shape) + indexer = tuple(np.random.randint(0, s, size=indexed_shape) for s in shape) return indexing.VectorizedIndexer(indexer) - elif mode == 'outer': + elif mode == "outer": indexer = tuple(np.random.randint(0, s, s + 2) for s in shape) return indexing.OuterIndexer(indexer) - elif mode == 'outer_scalar': + elif mode == "outer_scalar": indexer = (np.random.randint(0, 3, 4), 0, slice(None, None, 2)) - return indexing.OuterIndexer(indexer[:len(shape)]) + return indexing.OuterIndexer(indexer[: len(shape)]) - elif mode == 'outer_scalar2': + elif mode == "outer_scalar2": indexer = (np.random.randint(0, 3, 4), -2, slice(None, None, 2)) - return indexing.OuterIndexer(indexer[:len(shape)]) + return indexing.OuterIndexer(indexer[: len(shape)]) - elif mode == 'outer1vec': + elif mode == "outer1vec": indexer = [slice(2, -3) for s in shape] indexer[1] = np.random.randint(0, shape[1], shape[1] + 2) return indexing.OuterIndexer(tuple(indexer)) - elif mode == 'basic': # basic indexer + elif mode == "basic": # basic indexer indexer = [slice(2, -3) for s in shape] indexer[0] = 3 return indexing.BasicIndexer(tuple(indexer)) - elif mode == 'basic1': # basic indexer - return indexing.BasicIndexer((3, )) + elif mode == "basic1": # basic indexer + return indexing.BasicIndexer((3,)) - elif mode == 'basic2': # basic indexer + elif mode == "basic2": # basic indexer indexer = [0, 2, 4] - return indexing.BasicIndexer(tuple(indexer[:len(shape)])) + return indexing.BasicIndexer(tuple(indexer[: len(shape)])) - elif mode == 'basic3': # basic indexer + elif mode == "basic3": # basic indexer indexer = [slice(None) for s in shape] indexer[0] = slice(-2, 2, -2) indexer[1] = slice(1, -1, 2) - return indexing.BasicIndexer(tuple(indexer[:len(shape)])) + return indexing.BasicIndexer(tuple(indexer[: len(shape)])) -@pytest.mark.parametrize('size', [100, 99]) -@pytest.mark.parametrize('sl', [slice(1, -1, 1), slice(None, -1, 2), - slice(-1, 1, -1), slice(-1, 1, -2)]) +@pytest.mark.parametrize("size", [100, 99]) +@pytest.mark.parametrize( + "sl", [slice(1, -1, 1), slice(None, -1, 2), slice(-1, 1, -1), slice(-1, 1, -2)] +) def test_decompose_slice(size, sl): x = np.arange(size) slice1, slice2 = indexing._decompose_slice(sl, size) @@ -501,22 +570,35 @@ def test_decompose_slice(size, sl): assert_array_equal(expected, actual) -@pytest.mark.parametrize('shape', [(10, 5, 8), (10, 3)]) -@pytest.mark.parametrize('indexer_mode', - ['vectorized', 'outer', 'outer_scalar', - 'outer_scalar2', 'outer1vec', - 'basic', 'basic1', 'basic2', 'basic3']) -@pytest.mark.parametrize('indexing_support', - [indexing.IndexingSupport.BASIC, 
- indexing.IndexingSupport.OUTER, - indexing.IndexingSupport.OUTER_1VECTOR, - indexing.IndexingSupport.VECTORIZED]) +@pytest.mark.parametrize("shape", [(10, 5, 8), (10, 3)]) +@pytest.mark.parametrize( + "indexer_mode", + [ + "vectorized", + "outer", + "outer_scalar", + "outer_scalar2", + "outer1vec", + "basic", + "basic1", + "basic2", + "basic3", + ], +) +@pytest.mark.parametrize( + "indexing_support", + [ + indexing.IndexingSupport.BASIC, + indexing.IndexingSupport.OUTER, + indexing.IndexingSupport.OUTER_1VECTOR, + indexing.IndexingSupport.VECTORIZED, + ], +) def test_decompose_indexers(shape, indexer_mode, indexing_support): data = np.random.randn(*shape) indexer = get_indexers(shape, indexer_mode) - backend_ind, np_ind = indexing.decompose_indexer( - indexer, shape, indexing_support) + backend_ind, np_ind = indexing.decompose_indexer(indexer, shape, indexing_support) expected = indexing.NumpyIndexingAdapter(data)[indexer] array = indexing.NumpyIndexingAdapter(data)[backend_ind] @@ -533,7 +615,8 @@ def test_decompose_indexers(shape, indexer_mode, indexing_support): def test_implicit_indexing_adapter(): array = np.arange(10, dtype=np.int64) implicit = indexing.ImplicitToExplicitIndexingAdapter( - indexing.NumpyIndexingAdapter(array), indexing.BasicIndexer) + indexing.NumpyIndexingAdapter(array), indexing.BasicIndexer + ) np.testing.assert_array_equal(array, np.asarray(implicit)) np.testing.assert_array_equal(array, implicit[:]) @@ -541,38 +624,44 @@ def test_implicit_indexing_adapter(): def test_implicit_indexing_adapter_copy_on_write(): array = np.arange(10, dtype=np.int64) implicit = indexing.ImplicitToExplicitIndexingAdapter( - indexing.CopyOnWriteArray(array)) + indexing.CopyOnWriteArray(array) + ) assert isinstance(implicit[:], indexing.ImplicitToExplicitIndexingAdapter) def test_outer_indexer_consistency_with_broadcast_indexes_vectorized(): def nonzero(x): - if isinstance(x, np.ndarray) and x.dtype.kind == 'b': + if isinstance(x, np.ndarray) and x.dtype.kind == "b": x = x.nonzero()[0] return x original = np.random.rand(10, 20, 30) - v = Variable(['i', 'j', 'k'], original) + v = Variable(["i", "j", "k"], original) I = ReturnItem() # noqa: E741 # allow ambiguous name # test orthogonally applied indexers - indexers = [I[:], 0, -2, I[:3], np.array([0, 1, 2, 3]), np.array([0]), - np.arange(10) < 5] + indexers = [ + I[:], + 0, + -2, + I[:3], + np.array([0, 1, 2, 3]), + np.array([0]), + np.arange(10) < 5, + ] for i, j, k in itertools.product(indexers, repeat=3): - if isinstance(j, np.ndarray) and j.dtype.kind == 'b': # match size + if isinstance(j, np.ndarray) and j.dtype.kind == "b": # match size j = np.arange(20) < 4 - if isinstance(k, np.ndarray) and k.dtype.kind == 'b': + if isinstance(k, np.ndarray) and k.dtype.kind == "b": k = np.arange(30) < 8 _, expected, new_order = v._broadcast_indexes_vectorized((i, j, k)) expected_data = nputils.NumpyVIndexAdapter(v.data)[expected.tuple] if new_order: old_order = range(len(new_order)) - expected_data = np.moveaxis(expected_data, old_order, - new_order) + expected_data = np.moveaxis(expected_data, old_order, new_order) - outer_index = indexing.OuterIndexer((nonzero(i), nonzero(j), - nonzero(k))) + outer_index = indexing.OuterIndexer((nonzero(i), nonzero(j), nonzero(k))) actual = indexing._outer_to_numpy_indexer(outer_index, v.shape) actual_data = v.data[actual] np.testing.assert_array_equal(actual_data, expected_data) @@ -584,21 +673,21 @@ def test_create_mask_outer_indexer(): actual = indexing.create_mask(indexer, (5,)) 
np.testing.assert_array_equal(expected, actual) - indexer = indexing.OuterIndexer((1, slice(2), np.array([0, -1, 2]),)) + indexer = indexing.OuterIndexer((1, slice(2), np.array([0, -1, 2]))) expected = np.array(2 * [[False, True, False]]) - actual = indexing.create_mask(indexer, (5, 5, 5,)) + actual = indexing.create_mask(indexer, (5, 5, 5)) np.testing.assert_array_equal(expected, actual) def test_create_mask_vectorized_indexer(): - indexer = indexing.VectorizedIndexer( - (np.array([0, -1, 2]), np.array([0, 1, -1]))) + indexer = indexing.VectorizedIndexer((np.array([0, -1, 2]), np.array([0, 1, -1]))) expected = np.array([False, True, True]) actual = indexing.create_mask(indexer, (5,)) np.testing.assert_array_equal(expected, actual) indexer = indexing.VectorizedIndexer( - (np.array([0, -1, 2]), slice(None), np.array([0, 1, -1]))) + (np.array([0, -1, 2]), slice(None), np.array([0, 1, -1])) + ) expected = np.array([[False, True, True]] * 2).T actual = indexing.create_mask(indexer, (5, 2)) np.testing.assert_array_equal(expected, actual) @@ -615,17 +704,17 @@ def test_create_mask_basic_indexer(): def test_create_mask_dask(): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") - indexer = indexing.OuterIndexer((1, slice(2), np.array([0, -1, 2]),)) + indexer = indexing.OuterIndexer((1, slice(2), np.array([0, -1, 2]))) expected = np.array(2 * [[False, True, False]]) - actual = indexing.create_mask(indexer, (5, 5, 5,), - chunks_hint=((1, 1), (2, 1))) + actual = indexing.create_mask(indexer, (5, 5, 5), chunks_hint=((1, 1), (2, 1))) assert actual.chunks == ((1, 1), (2, 1)) np.testing.assert_array_equal(expected, actual) indexer = indexing.VectorizedIndexer( - (np.array([0, -1, 2]), slice(None), np.array([0, 1, -1]))) + (np.array([0, -1, 2]), slice(None), np.array([0, 1, -1])) + ) expected = np.array([[False, True, True]] * 2).T actual = indexing.create_mask(indexer, (5, 2), chunks_hint=((3,), (2,))) assert isinstance(actual, da.Array) @@ -636,19 +725,22 @@ def test_create_mask_dask(): def test_create_mask_error(): - with raises_regex(TypeError, 'unexpected key type'): + with raises_regex(TypeError, "unexpected key type"): indexing.create_mask((1, 2), (3, 4)) -@pytest.mark.parametrize('indices, expected', [ - (np.arange(5), np.arange(5)), - (np.array([0, -1, -1]), np.array([0, 0, 0])), - (np.array([-1, 1, -1]), np.array([1, 1, 1])), - (np.array([-1, -1, 2]), np.array([2, 2, 2])), - (np.array([-1]), np.array([0])), - (np.array([0, -1, 1, -1, -1]), np.array([0, 0, 1, 1, 1])), - (np.array([0, -1, -1, -1, 1]), np.array([0, 0, 0, 0, 1])), -]) +@pytest.mark.parametrize( + "indices, expected", + [ + (np.arange(5), np.arange(5)), + (np.array([0, -1, -1]), np.array([0, 0, 0])), + (np.array([-1, 1, -1]), np.array([1, 1, 1])), + (np.array([-1, -1, 2]), np.array([2, 2, 2])), + (np.array([-1]), np.array([0])), + (np.array([0, -1, 1, -1, -1]), np.array([0, 0, 1, 1, 1])), + (np.array([0, -1, -1, -1, 1]), np.array([0, 0, 0, 0, 1])), + ], +) def test_posify_mask_subindexer(indices, expected): actual = indexing._posify_mask_subindexer(indices) np.testing.assert_array_equal(expected, actual) diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 252f8bcacd4..b9dc9a71acc 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -3,8 +3,7 @@ import pytest import xarray as xr -from xarray.tests import ( - assert_allclose, assert_equal, requires_cftime, requires_scipy) +from xarray.tests import assert_allclose, assert_equal, requires_cftime, 
requires_scipy from ..coding.cftimeindex import _parse_array_of_cftime_strings from . import has_dask, has_scipy @@ -20,79 +19,85 @@ def get_example_data(case): x = np.linspace(0, 1, 100) y = np.linspace(0, 0.1, 30) data = xr.DataArray( - np.sin(x[:, np.newaxis]) * np.cos(y), dims=['x', 'y'], - coords={'x': x, 'y': y, 'x2': ('x', x**2)}) + np.sin(x[:, np.newaxis]) * np.cos(y), + dims=["x", "y"], + coords={"x": x, "y": y, "x2": ("x", x ** 2)}, + ) if case == 0: return data elif case == 1: - return data.chunk({'y': 3}) + return data.chunk({"y": 3}) elif case == 2: - return data.chunk({'x': 25, 'y': 3}) + return data.chunk({"x": 25, "y": 3}) elif case == 3: x = np.linspace(0, 1, 100) y = np.linspace(0, 0.1, 30) z = np.linspace(0.1, 0.2, 10) return xr.DataArray( - np.sin(x[:, np.newaxis, np.newaxis]) * np.cos( - y[:, np.newaxis]) * z, - dims=['x', 'y', 'z'], - coords={'x': x, 'y': y, 'x2': ('x', x**2), 'z': z}) + np.sin(x[:, np.newaxis, np.newaxis]) * np.cos(y[:, np.newaxis]) * z, + dims=["x", "y", "z"], + coords={"x": x, "y": y, "x2": ("x", x ** 2), "z": z}, + ) elif case == 4: - return get_example_data(3).chunk({'z': 5}) + return get_example_data(3).chunk({"z": 5}) def test_keywargs(): if not has_scipy: - pytest.skip('scipy is not installed.') + pytest.skip("scipy is not installed.") da = get_example_data(0) - assert_equal(da.interp(x=[0.5, 0.8]), da.interp({'x': [0.5, 0.8]})) + assert_equal(da.interp(x=[0.5, 0.8]), da.interp({"x": [0.5, 0.8]})) -@pytest.mark.parametrize('method', ['linear', 'cubic']) -@pytest.mark.parametrize('dim', ['x', 'y']) -@pytest.mark.parametrize('case', [0, 1]) +@pytest.mark.parametrize("method", ["linear", "cubic"]) +@pytest.mark.parametrize("dim", ["x", "y"]) +@pytest.mark.parametrize("case", [0, 1]) def test_interpolate_1d(method, dim, case): if not has_scipy: - pytest.skip('scipy is not installed.') + pytest.skip("scipy is not installed.") if not has_dask and case in [1]: - pytest.skip('dask is not installed in the environment.') + pytest.skip("dask is not installed in the environment.") da = get_example_data(case) xdest = np.linspace(0.0, 0.9, 80) - if dim == 'y' and case == 1: + if dim == "y" and case == 1: with pytest.raises(NotImplementedError): actual = da.interp(method=method, **{dim: xdest}) - pytest.skip('interpolation along chunked dimension is ' - 'not yet supported') + pytest.skip("interpolation along chunked dimension is " "not yet supported") actual = da.interp(method=method, **{dim: xdest}) # scipy interpolation for the reference def func(obj, new_x): return scipy.interpolate.interp1d( - da[dim], obj.data, axis=obj.get_axis_num(dim), bounds_error=False, - fill_value=np.nan, kind=method)(new_x) - - if dim == 'x': - coords = {'x': xdest, 'y': da['y'], 'x2': ('x', func(da['x2'], xdest))} + da[dim], + obj.data, + axis=obj.get_axis_num(dim), + bounds_error=False, + fill_value=np.nan, + kind=method, + )(new_x) + + if dim == "x": + coords = {"x": xdest, "y": da["y"], "x2": ("x", func(da["x2"], xdest))} else: # y - coords = {'x': da['x'], 'y': xdest, 'x2': da['x2']} + coords = {"x": da["x"], "y": xdest, "x2": da["x2"]} - expected = xr.DataArray(func(da, xdest), dims=['x', 'y'], coords=coords) + expected = xr.DataArray(func(da, xdest), dims=["x", "y"], coords=coords) assert_allclose(actual, expected) -@pytest.mark.parametrize('method', ['cubic', 'zero']) +@pytest.mark.parametrize("method", ["cubic", "zero"]) def test_interpolate_1d_methods(method): if not has_scipy: - pytest.skip('scipy is not installed.') + pytest.skip("scipy is not installed.") da = 
get_example_data(0) - dim = 'x' + dim = "x" xdest = np.linspace(0.0, 0.9, 80) actual = da.interp(method=method, **{dim: xdest}) @@ -100,121 +105,153 @@ def test_interpolate_1d_methods(method): # scipy interpolation for the reference def func(obj, new_x): return scipy.interpolate.interp1d( - da[dim], obj.data, axis=obj.get_axis_num(dim), bounds_error=False, - fill_value=np.nan, kind=method)(new_x) - - coords = {'x': xdest, 'y': da['y'], 'x2': ('x', func(da['x2'], xdest))} - expected = xr.DataArray(func(da, xdest), dims=['x', 'y'], coords=coords) + da[dim], + obj.data, + axis=obj.get_axis_num(dim), + bounds_error=False, + fill_value=np.nan, + kind=method, + )(new_x) + + coords = {"x": xdest, "y": da["y"], "x2": ("x", func(da["x2"], xdest))} + expected = xr.DataArray(func(da, xdest), dims=["x", "y"], coords=coords) assert_allclose(actual, expected) -@pytest.mark.parametrize('use_dask', [False, True]) +@pytest.mark.parametrize("use_dask", [False, True]) def test_interpolate_vectorize(use_dask): if not has_scipy: - pytest.skip('scipy is not installed.') + pytest.skip("scipy is not installed.") if not has_dask and use_dask: - pytest.skip('dask is not installed in the environment.') + pytest.skip("dask is not installed in the environment.") # scipy interpolation for the reference def func(obj, dim, new_x): - shape = [s for i, s in enumerate(obj.shape) - if i != obj.get_axis_num(dim)] + shape = [s for i, s in enumerate(obj.shape) if i != obj.get_axis_num(dim)] for s in new_x.shape[::-1]: shape.insert(obj.get_axis_num(dim), s) return scipy.interpolate.interp1d( - da[dim], obj.data, axis=obj.get_axis_num(dim), - bounds_error=False, fill_value=np.nan)(new_x).reshape(shape) + da[dim], + obj.data, + axis=obj.get_axis_num(dim), + bounds_error=False, + fill_value=np.nan, + )(new_x).reshape(shape) da = get_example_data(0) if use_dask: - da = da.chunk({'y': 5}) + da = da.chunk({"y": 5}) # xdest is 1d but has different dimension - xdest = xr.DataArray(np.linspace(0.1, 0.9, 30), dims='z', - coords={'z': np.random.randn(30), - 'z2': ('z', np.random.randn(30))}) + xdest = xr.DataArray( + np.linspace(0.1, 0.9, 30), + dims="z", + coords={"z": np.random.randn(30), "z2": ("z", np.random.randn(30))}, + ) - actual = da.interp(x=xdest, method='linear') + actual = da.interp(x=xdest, method="linear") - expected = xr.DataArray(func(da, 'x', xdest), dims=['z', 'y'], - coords={'z': xdest['z'], 'z2': xdest['z2'], - 'y': da['y'], - 'x': ('z', xdest.values), - 'x2': ('z', func(da['x2'], 'x', xdest))}) - assert_allclose(actual, - expected.transpose('z', 'y', transpose_coords=True)) + expected = xr.DataArray( + func(da, "x", xdest), + dims=["z", "y"], + coords={ + "z": xdest["z"], + "z2": xdest["z2"], + "y": da["y"], + "x": ("z", xdest.values), + "x2": ("z", func(da["x2"], "x", xdest)), + }, + ) + assert_allclose(actual, expected.transpose("z", "y", transpose_coords=True)) # xdest is 2d - xdest = xr.DataArray(np.linspace(0.1, 0.9, 30).reshape(6, 5), - dims=['z', 'w'], - coords={'z': np.random.randn(6), - 'w': np.random.randn(5), - 'z2': ('z', np.random.randn(6))}) - - actual = da.interp(x=xdest, method='linear') + xdest = xr.DataArray( + np.linspace(0.1, 0.9, 30).reshape(6, 5), + dims=["z", "w"], + coords={ + "z": np.random.randn(6), + "w": np.random.randn(5), + "z2": ("z", np.random.randn(6)), + }, + ) + + actual = da.interp(x=xdest, method="linear") expected = xr.DataArray( - func(da, 'x', xdest), - dims=['z', 'w', 'y'], - coords={'z': xdest['z'], 'w': xdest['w'], 'z2': xdest['z2'], - 'y': da['y'], 'x': (('z', 'w'), xdest), 
- 'x2': (('z', 'w'), func(da['x2'], 'x', xdest))}) - assert_allclose(actual, - expected.transpose('z', 'w', 'y', transpose_coords=True)) - - -@pytest.mark.parametrize('case', [3, 4]) + func(da, "x", xdest), + dims=["z", "w", "y"], + coords={ + "z": xdest["z"], + "w": xdest["w"], + "z2": xdest["z2"], + "y": da["y"], + "x": (("z", "w"), xdest), + "x2": (("z", "w"), func(da["x2"], "x", xdest)), + }, + ) + assert_allclose(actual, expected.transpose("z", "w", "y", transpose_coords=True)) + + +@pytest.mark.parametrize("case", [3, 4]) def test_interpolate_nd(case): if not has_scipy: - pytest.skip('scipy is not installed.') + pytest.skip("scipy is not installed.") if not has_dask and case == 4: - pytest.skip('dask is not installed in the environment.') + pytest.skip("dask is not installed in the environment.") da = get_example_data(case) # grid -> grid xdest = np.linspace(0.1, 1.0, 11) ydest = np.linspace(0.0, 0.2, 10) - actual = da.interp(x=xdest, y=ydest, method='linear') + actual = da.interp(x=xdest, y=ydest, method="linear") # linear interpolation is separateable - expected = da.interp(x=xdest, method='linear') - expected = expected.interp(y=ydest, method='linear') - assert_allclose(actual.transpose('x', 'y', 'z'), - expected.transpose('x', 'y', 'z')) + expected = da.interp(x=xdest, method="linear") + expected = expected.interp(y=ydest, method="linear") + assert_allclose(actual.transpose("x", "y", "z"), expected.transpose("x", "y", "z")) # grid -> 1d-sample - xdest = xr.DataArray(np.linspace(0.1, 1.0, 11), dims='y') - ydest = xr.DataArray(np.linspace(0.0, 0.2, 11), dims='y') - actual = da.interp(x=xdest, y=ydest, method='linear') + xdest = xr.DataArray(np.linspace(0.1, 1.0, 11), dims="y") + ydest = xr.DataArray(np.linspace(0.0, 0.2, 11), dims="y") + actual = da.interp(x=xdest, y=ydest, method="linear") # linear interpolation is separateable expected_data = scipy.interpolate.RegularGridInterpolator( - (da['x'], da['y']), da.transpose('x', 'y', 'z').values, - method='linear', bounds_error=False, - fill_value=np.nan)(np.stack([xdest, ydest], axis=-1)) + (da["x"], da["y"]), + da.transpose("x", "y", "z").values, + method="linear", + bounds_error=False, + fill_value=np.nan, + )(np.stack([xdest, ydest], axis=-1)) expected = xr.DataArray( - expected_data, dims=['y', 'z'], - coords={'z': da['z'], 'y': ydest, 'x': ('y', xdest.values), - 'x2': da['x2'].interp(x=xdest)}) - assert_allclose(actual.transpose('y', 'z'), expected) + expected_data, + dims=["y", "z"], + coords={ + "z": da["z"], + "y": ydest, + "x": ("y", xdest.values), + "x2": da["x2"].interp(x=xdest), + }, + ) + assert_allclose(actual.transpose("y", "z"), expected) # reversed order - actual = da.interp(y=ydest, x=xdest, method='linear') - assert_allclose(actual.transpose('y', 'z'), expected) + actual = da.interp(y=ydest, x=xdest, method="linear") + assert_allclose(actual.transpose("y", "z"), expected) -@pytest.mark.parametrize('method', ['linear']) -@pytest.mark.parametrize('case', [0, 1]) +@pytest.mark.parametrize("method", ["linear"]) +@pytest.mark.parametrize("case", [0, 1]) def test_interpolate_scalar(method, case): if not has_scipy: - pytest.skip('scipy is not installed.') + pytest.skip("scipy is not installed.") if not has_dask and case in [1]: - pytest.skip('dask is not installed in the environment.') + pytest.skip("dask is not installed in the environment.") da = get_example_data(case) xdest = 0.4 @@ -224,22 +261,26 @@ def test_interpolate_scalar(method, case): # scipy interpolation for the reference def func(obj, new_x): return 
scipy.interpolate.interp1d( - da['x'], obj.data, axis=obj.get_axis_num('x'), bounds_error=False, - fill_value=np.nan)(new_x) - - coords = {'x': xdest, 'y': da['y'], 'x2': func(da['x2'], xdest)} - expected = xr.DataArray(func(da, xdest), dims=['y'], coords=coords) + da["x"], + obj.data, + axis=obj.get_axis_num("x"), + bounds_error=False, + fill_value=np.nan, + )(new_x) + + coords = {"x": xdest, "y": da["y"], "x2": func(da["x2"], xdest)} + expected = xr.DataArray(func(da, xdest), dims=["y"], coords=coords) assert_allclose(actual, expected) -@pytest.mark.parametrize('method', ['linear']) -@pytest.mark.parametrize('case', [3, 4]) +@pytest.mark.parametrize("method", ["linear"]) +@pytest.mark.parametrize("case", [3, 4]) def test_interpolate_nd_scalar(method, case): if not has_scipy: - pytest.skip('scipy is not installed.') + pytest.skip("scipy is not installed.") if not has_dask and case in [4]: - pytest.skip('dask is not installed in the environment.') + pytest.skip("dask is not installed in the environment.") da = get_example_data(case) xdest = 0.4 @@ -248,25 +289,27 @@ def test_interpolate_nd_scalar(method, case): actual = da.interp(x=xdest, y=ydest, method=method) # scipy interpolation for the reference expected_data = scipy.interpolate.RegularGridInterpolator( - (da['x'], da['y']), da.transpose('x', 'y', 'z').values, - method='linear', bounds_error=False, - fill_value=np.nan)(np.stack([xdest, ydest], axis=-1)) - - coords = {'x': xdest, 'y': ydest, 'x2': da['x2'].interp(x=xdest), - 'z': da['z']} - expected = xr.DataArray(expected_data[0], dims=['z'], coords=coords) + (da["x"], da["y"]), + da.transpose("x", "y", "z").values, + method="linear", + bounds_error=False, + fill_value=np.nan, + )(np.stack([xdest, ydest], axis=-1)) + + coords = {"x": xdest, "y": ydest, "x2": da["x2"].interp(x=xdest), "z": da["z"]} + expected = xr.DataArray(expected_data[0], dims=["z"], coords=coords) assert_allclose(actual, expected) -@pytest.mark.parametrize('use_dask', [True, False]) +@pytest.mark.parametrize("use_dask", [True, False]) def test_nans(use_dask): if not has_scipy: - pytest.skip('scipy is not installed.') + pytest.skip("scipy is not installed.") - da = xr.DataArray([0, 1, np.nan, 2], dims='x', coords={'x': range(4)}) + da = xr.DataArray([0, 1, np.nan, 2], dims="x", coords={"x": range(4)}) if not has_dask and use_dask: - pytest.skip('dask is not installed in the environment.') + pytest.skip("dask is not installed in the environment.") da = da.chunk() actual = da.interp(x=[0.5, 1.5]) @@ -274,18 +317,18 @@ def test_nans(use_dask): assert actual.count() > 0 -@pytest.mark.parametrize('use_dask', [True, False]) +@pytest.mark.parametrize("use_dask", [True, False]) def test_errors(use_dask): if not has_scipy: - pytest.skip('scipy is not installed.') + pytest.skip("scipy is not installed.") # akima and spline are unavailable - da = xr.DataArray([0, 1, np.nan, 2], dims='x', coords={'x': range(4)}) + da = xr.DataArray([0, 1, np.nan, 2], dims="x", coords={"x": range(4)}) if not has_dask and use_dask: - pytest.skip('dask is not installed in the environment.') + pytest.skip("dask is not installed in the environment.") da = da.chunk() - for method in ['akima', 'spline']: + for method in ["akima", "spline"]: with pytest.raises(ValueError): da.interp(x=[0.5, 1.5], method=method) @@ -295,34 +338,36 @@ def test_errors(use_dask): else: da = get_example_data(0) - result = da.interp(x=[-1, 1, 3], kwargs={'fill_value': 0.0}) + result = da.interp(x=[-1, 1, 3], kwargs={"fill_value": 0.0}) assert not 
np.isnan(result.values).any() result = da.interp(x=[-1, 1, 3]) assert np.isnan(result.values).any() # invalid method with pytest.raises(ValueError): - da.interp(x=[2, 0], method='boo') + da.interp(x=[2, 0], method="boo") with pytest.raises(ValueError): - da.interp(x=[2, 0], y=2, method='cubic') + da.interp(x=[2, 0], y=2, method="cubic") with pytest.raises(ValueError): - da.interp(y=[2, 0], method='boo') + da.interp(y=[2, 0], method="boo") # object-type DataArray cannot be interpolated - da = xr.DataArray(['a', 'b', 'c'], dims='x', coords={'x': [0, 1, 2]}) + da = xr.DataArray(["a", "b", "c"], dims="x", coords={"x": [0, 1, 2]}) with pytest.raises(TypeError): da.interp(x=0) @requires_scipy def test_dtype(): - ds = xr.Dataset({'var1': ('x', [0, 1, 2]), 'var2': ('x', ['a', 'b', 'c'])}, - coords={'x': [0.1, 0.2, 0.3], 'z': ('x', ['a', 'b', 'c'])}) + ds = xr.Dataset( + {"var1": ("x", [0, 1, 2]), "var2": ("x", ["a", "b", "c"])}, + coords={"x": [0.1, 0.2, 0.3], "z": ("x", ["a", "b", "c"])}, + ) actual = ds.interp(x=[0.15, 0.25]) - assert 'var1' in actual - assert 'var2' not in actual + assert "var1" in actual + assert "var2" not in actual # object array should be dropped - assert 'z' not in actual.coords + assert "z" not in actual.coords @requires_scipy @@ -332,20 +377,21 @@ def test_sorted(): y = np.random.randn(30) z = np.linspace(0.1, 0.2, 10) * 3.0 da = xr.DataArray( - np.cos(x[:, np.newaxis, np.newaxis]) * np.cos( - y[:, np.newaxis]) * z, - dims=['x', 'y', 'z'], - coords={'x': x, 'y': y, 'x2': ('x', x**2), 'z': z}) + np.cos(x[:, np.newaxis, np.newaxis]) * np.cos(y[:, np.newaxis]) * z, + dims=["x", "y", "z"], + coords={"x": x, "y": y, "x2": ("x", x ** 2), "z": z}, + ) x_new = np.linspace(0, 1, 30) y_new = np.linspace(0, 1, 20) - da_sorted = da.sortby('x') - assert_allclose(da.interp(x=x_new), - da_sorted.interp(x=x_new, assume_sorted=True)) - da_sorted = da.sortby(['x', 'y']) - assert_allclose(da.interp(x=x_new, y=y_new), - da_sorted.interp(x=x_new, y=y_new, assume_sorted=True)) + da_sorted = da.sortby("x") + assert_allclose(da.interp(x=x_new), da_sorted.interp(x=x_new, assume_sorted=True)) + da_sorted = da.sortby(["x", "y"]) + assert_allclose( + da.interp(x=x_new, y=y_new), + da_sorted.interp(x=x_new, y=y_new, assume_sorted=True), + ) with pytest.raises(ValueError): da.interp(x=[0, 1, 2], assume_sorted=True) @@ -353,111 +399,111 @@ def test_sorted(): @requires_scipy def test_dimension_wo_coords(): - da = xr.DataArray(np.arange(12).reshape(3, 4), dims=['x', 'y'], - coords={'y': [0, 1, 2, 3]}) + da = xr.DataArray( + np.arange(12).reshape(3, 4), dims=["x", "y"], coords={"y": [0, 1, 2, 3]} + ) da_w_coord = da.copy() - da_w_coord['x'] = np.arange(3) + da_w_coord["x"] = np.arange(3) - assert_equal(da.interp(x=[0.1, 0.2, 0.3]), - da_w_coord.interp(x=[0.1, 0.2, 0.3])) - assert_equal(da.interp(x=[0.1, 0.2, 0.3], y=[0.5]), - da_w_coord.interp(x=[0.1, 0.2, 0.3], y=[0.5])) + assert_equal(da.interp(x=[0.1, 0.2, 0.3]), da_w_coord.interp(x=[0.1, 0.2, 0.3])) + assert_equal( + da.interp(x=[0.1, 0.2, 0.3], y=[0.5]), + da_w_coord.interp(x=[0.1, 0.2, 0.3], y=[0.5]), + ) @requires_scipy def test_dataset(): ds = create_test_data() - ds.attrs['foo'] = 'var' - ds['var1'].attrs['buz'] = 'var2' - new_dim2 = xr.DataArray([0.11, 0.21, 0.31], dims='z') + ds.attrs["foo"] = "var" + ds["var1"].attrs["buz"] = "var2" + new_dim2 = xr.DataArray([0.11, 0.21, 0.31], dims="z") interpolated = ds.interp(dim2=new_dim2) - assert_allclose(interpolated['var1'], ds['var1'].interp(dim2=new_dim2)) - assert 
interpolated['var3'].equals(ds['var3']) + assert_allclose(interpolated["var1"], ds["var1"].interp(dim2=new_dim2)) + assert interpolated["var3"].equals(ds["var3"]) # make sure modifying interpolated does not affect the original dataset - interpolated['var1'][:, 1] = 1.0 - interpolated['var2'][:, 1] = 1.0 - interpolated['var3'][:, 1] = 1.0 + interpolated["var1"][:, 1] = 1.0 + interpolated["var2"][:, 1] = 1.0 + interpolated["var3"][:, 1] = 1.0 - assert not interpolated['var1'].equals(ds['var1']) - assert not interpolated['var2'].equals(ds['var2']) - assert not interpolated['var3'].equals(ds['var3']) + assert not interpolated["var1"].equals(ds["var1"]) + assert not interpolated["var2"].equals(ds["var2"]) + assert not interpolated["var3"].equals(ds["var3"]) # attrs should be kept - assert interpolated.attrs['foo'] == 'var' - assert interpolated['var1'].attrs['buz'] == 'var2' + assert interpolated.attrs["foo"] == "var" + assert interpolated["var1"].attrs["buz"] == "var2" -@pytest.mark.parametrize('case', [0, 3]) +@pytest.mark.parametrize("case", [0, 3]) def test_interpolate_dimorder(case): """ Make sure the resultant dimension order is consistent with .sel() """ if not has_scipy: - pytest.skip('scipy is not installed.') + pytest.skip("scipy is not installed.") da = get_example_data(case) - new_x = xr.DataArray([0, 1, 2], dims='x') - assert da.interp(x=new_x).dims == da.sel(x=new_x, method='nearest').dims + new_x = xr.DataArray([0, 1, 2], dims="x") + assert da.interp(x=new_x).dims == da.sel(x=new_x, method="nearest").dims - new_y = xr.DataArray([0, 1, 2], dims='y') + new_y = xr.DataArray([0, 1, 2], dims="y") actual = da.interp(x=new_x, y=new_y).dims - expected = da.sel(x=new_x, y=new_y, method='nearest').dims + expected = da.sel(x=new_x, y=new_y, method="nearest").dims assert actual == expected # reversed order actual = da.interp(y=new_y, x=new_x).dims - expected = da.sel(y=new_y, x=new_x, method='nearest').dims + expected = da.sel(y=new_y, x=new_x, method="nearest").dims assert actual == expected - new_x = xr.DataArray([0, 1, 2], dims='a') - assert da.interp(x=new_x).dims == da.sel(x=new_x, method='nearest').dims - assert da.interp(y=new_x).dims == da.sel(y=new_x, method='nearest').dims - new_y = xr.DataArray([0, 1, 2], dims='a') + new_x = xr.DataArray([0, 1, 2], dims="a") + assert da.interp(x=new_x).dims == da.sel(x=new_x, method="nearest").dims + assert da.interp(y=new_x).dims == da.sel(y=new_x, method="nearest").dims + new_y = xr.DataArray([0, 1, 2], dims="a") actual = da.interp(x=new_x, y=new_y).dims - expected = da.sel(x=new_x, y=new_y, method='nearest').dims + expected = da.sel(x=new_x, y=new_y, method="nearest").dims assert actual == expected - new_x = xr.DataArray([[0], [1], [2]], dims=['a', 'b']) - assert da.interp(x=new_x).dims == da.sel(x=new_x, method='nearest').dims - assert da.interp(y=new_x).dims == da.sel(y=new_x, method='nearest').dims + new_x = xr.DataArray([[0], [1], [2]], dims=["a", "b"]) + assert da.interp(x=new_x).dims == da.sel(x=new_x, method="nearest").dims + assert da.interp(y=new_x).dims == da.sel(y=new_x, method="nearest").dims if case == 3: - new_x = xr.DataArray([[0], [1], [2]], dims=['a', 'b']) - new_z = xr.DataArray([[0], [1], [2]], dims=['a', 'b']) + new_x = xr.DataArray([[0], [1], [2]], dims=["a", "b"]) + new_z = xr.DataArray([[0], [1], [2]], dims=["a", "b"]) actual = da.interp(x=new_x, z=new_z).dims - expected = da.sel(x=new_x, z=new_z, method='nearest').dims + expected = da.sel(x=new_x, z=new_z, method="nearest").dims assert actual == expected actual = 
da.interp(z=new_z, x=new_x).dims - expected = da.sel(z=new_z, x=new_x, method='nearest').dims + expected = da.sel(z=new_z, x=new_x, method="nearest").dims assert actual == expected actual = da.interp(x=0.5, z=new_z).dims - expected = da.sel(x=0.5, z=new_z, method='nearest').dims + expected = da.sel(x=0.5, z=new_z, method="nearest").dims assert actual == expected @requires_scipy def test_interp_like(): ds = create_test_data() - ds.attrs['foo'] = 'var' - ds['var1'].attrs['buz'] = 'var2' + ds.attrs["foo"] = "var" + ds["var1"].attrs["buz"] = "var2" - other = xr.DataArray(np.random.randn(3), dims=['dim2'], - coords={'dim2': [0, 1, 2]}) + other = xr.DataArray(np.random.randn(3), dims=["dim2"], coords={"dim2": [0, 1, 2]}) interpolated = ds.interp_like(other) - assert_allclose(interpolated['var1'], - ds['var1'].interp(dim2=other['dim2'])) - assert_allclose(interpolated['var1'], - ds['var1'].interp_like(other)) - assert interpolated['var3'].equals(ds['var3']) + assert_allclose(interpolated["var1"], ds["var1"].interp(dim2=other["dim2"])) + assert_allclose(interpolated["var1"], ds["var1"].interp_like(other)) + assert interpolated["var3"].equals(ds["var3"]) # attrs should be kept - assert interpolated.attrs['foo'] == 'var' - assert interpolated['var1'].attrs['buz'] == 'var2' + assert interpolated.attrs["foo"] == "var" + assert interpolated["var1"].attrs["buz"] == "var2" - other = xr.DataArray(np.random.randn(3), dims=['dim3'], - coords={'dim3': ['a', 'b', 'c']}) + other = xr.DataArray( + np.random.randn(3), dims=["dim3"], coords={"dim3": ["a", "b", "c"]} + ) actual = ds.interp_like(other) expected = ds.reindex_like(other) @@ -465,45 +511,60 @@ def test_interp_like(): @requires_scipy -@pytest.mark.parametrize('x_new, expected', [ - (pd.date_range('2000-01-02', periods=3), [1, 2, 3]), - (np.array([np.datetime64('2000-01-01T12:00'), - np.datetime64('2000-01-02T12:00')]), [0.5, 1.5]), - (['2000-01-01T12:00', '2000-01-02T12:00'], [0.5, 1.5]), - (['2000-01-01T12:00'], 0.5), - pytest.param('2000-01-01T12:00', 0.5, marks=pytest.mark.xfail) -]) +@pytest.mark.parametrize( + "x_new, expected", + [ + (pd.date_range("2000-01-02", periods=3), [1, 2, 3]), + ( + np.array( + [np.datetime64("2000-01-01T12:00"), np.datetime64("2000-01-02T12:00")] + ), + [0.5, 1.5], + ), + (["2000-01-01T12:00", "2000-01-02T12:00"], [0.5, 1.5]), + (["2000-01-01T12:00"], 0.5), + pytest.param("2000-01-01T12:00", 0.5, marks=pytest.mark.xfail), + ], +) def test_datetime(x_new, expected): - da = xr.DataArray(np.arange(24), dims='time', - coords={'time': pd.date_range('2000-01-01', periods=24)}) + da = xr.DataArray( + np.arange(24), + dims="time", + coords={"time": pd.date_range("2000-01-01", periods=24)}, + ) actual = da.interp(time=x_new) - expected_da = xr.DataArray(np.atleast_1d(expected), dims=['time'], - coords={'time': (np.atleast_1d(x_new) - .astype('datetime64[ns]'))}) + expected_da = xr.DataArray( + np.atleast_1d(expected), + dims=["time"], + coords={"time": (np.atleast_1d(x_new).astype("datetime64[ns]"))}, + ) assert_allclose(actual, expected_da) @requires_scipy def test_datetime_single_string(): - da = xr.DataArray(np.arange(24), dims='time', - coords={'time': pd.date_range('2000-01-01', periods=24)}) - actual = da.interp(time='2000-01-01T12:00') + da = xr.DataArray( + np.arange(24), + dims="time", + coords={"time": pd.date_range("2000-01-01", periods=24)}, + ) + actual = da.interp(time="2000-01-01T12:00") expected = xr.DataArray(0.5) - assert_allclose(actual.drop('time'), expected) + assert_allclose(actual.drop("time"), expected) 
@requires_cftime @requires_scipy def test_cftime(): - times = xr.cftime_range('2000', periods=24, freq='D') - da = xr.DataArray(np.arange(24), coords=[times], dims='time') + times = xr.cftime_range("2000", periods=24, freq="D") + da = xr.DataArray(np.arange(24), coords=[times], dims="time") - times_new = xr.cftime_range('2000-01-01T12:00:00', periods=3, freq='D') + times_new = xr.cftime_range("2000-01-01T12:00:00", periods=3, freq="D") actual = da.interp(time=times_new) - expected = xr.DataArray([0.5, 1.5, 2.5], coords=[times_new], dims=['time']) + expected = xr.DataArray([0.5, 1.5, 2.5], coords=[times_new], dims=["time"]) assert_allclose(actual, expected) @@ -511,11 +572,12 @@ def test_cftime(): @requires_cftime @requires_scipy def test_cftime_type_error(): - times = xr.cftime_range('2000', periods=24, freq='D') - da = xr.DataArray(np.arange(24), coords=[times], dims='time') + times = xr.cftime_range("2000", periods=24, freq="D") + da = xr.DataArray(np.arange(24), coords=[times], dims="time") - times_new = xr.cftime_range('2000-01-01T12:00:00', periods=3, freq='D', - calendar='noleap') + times_new = xr.cftime_range( + "2000-01-01T12:00:00", periods=3, freq="D", calendar="noleap" + ) with pytest.raises(TypeError): da.interp(time=times_new) @@ -525,17 +587,18 @@ def test_cftime_type_error(): def test_cftime_list_of_strings(): from cftime import DatetimeProlepticGregorian - times = xr.cftime_range('2000', periods=24, freq='D', - calendar='proleptic_gregorian') - da = xr.DataArray(np.arange(24), coords=[times], dims='time') + times = xr.cftime_range( + "2000", periods=24, freq="D", calendar="proleptic_gregorian" + ) + da = xr.DataArray(np.arange(24), coords=[times], dims="time") - times_new = ['2000-01-01T12:00', '2000-01-02T12:00', '2000-01-03T12:00'] + times_new = ["2000-01-01T12:00", "2000-01-02T12:00", "2000-01-03T12:00"] actual = da.interp(time=times_new) times_new_array = _parse_array_of_cftime_strings( - np.array(times_new), DatetimeProlepticGregorian) - expected = xr.DataArray([0.5, 1.5, 2.5], coords=[times_new_array], - dims=['time']) + np.array(times_new), DatetimeProlepticGregorian + ) + expected = xr.DataArray([0.5, 1.5, 2.5], coords=[times_new_array], dims=["time"]) assert_allclose(actual, expected) @@ -545,24 +608,29 @@ def test_cftime_list_of_strings(): def test_cftime_single_string(): from cftime import DatetimeProlepticGregorian - times = xr.cftime_range('2000', periods=24, freq='D', - calendar='proleptic_gregorian') - da = xr.DataArray(np.arange(24), coords=[times], dims='time') + times = xr.cftime_range( + "2000", periods=24, freq="D", calendar="proleptic_gregorian" + ) + da = xr.DataArray(np.arange(24), coords=[times], dims="time") - times_new = '2000-01-01T12:00' + times_new = "2000-01-01T12:00" actual = da.interp(time=times_new) times_new_array = _parse_array_of_cftime_strings( - np.array(times_new), DatetimeProlepticGregorian) - expected = xr.DataArray(0.5, coords={'time': times_new_array}) + np.array(times_new), DatetimeProlepticGregorian + ) + expected = xr.DataArray(0.5, coords={"time": times_new_array}) assert_allclose(actual, expected) @requires_scipy def test_datetime_to_non_datetime_error(): - da = xr.DataArray(np.arange(24), dims='time', - coords={'time': pd.date_range('2000-01-01', periods=24)}) + da = xr.DataArray( + np.arange(24), + dims="time", + coords={"time": pd.date_range("2000-01-01", periods=24)}, + ) with pytest.raises(TypeError): da.interp(time=0.5) @@ -570,8 +638,8 @@ def test_datetime_to_non_datetime_error(): @requires_cftime @requires_scipy def 
test_cftime_to_non_cftime_error(): - times = xr.cftime_range('2000', periods=24, freq='D') - da = xr.DataArray(np.arange(24), coords=[times], dims='time') + times = xr.cftime_range("2000", periods=24, freq="D") + da = xr.DataArray(np.arange(24), coords=[times], dims="time") with pytest.raises(TypeError): da.interp(time=0.5) @@ -581,10 +649,16 @@ def test_cftime_to_non_cftime_error(): def test_datetime_interp_noerror(): # GH:2667 a = xr.DataArray( - np.arange(21).reshape(3, 7), dims=['x', 'time'], - coords={'x': [1, 2, 3], - 'time': pd.date_range('01-01-2001', periods=7, freq='D')}) + np.arange(21).reshape(3, 7), + dims=["x", "time"], + coords={ + "x": [1, 2, 3], + "time": pd.date_range("01-01-2001", periods=7, freq="D"), + }, + ) xi = xr.DataArray( - np.linspace(1, 3, 50), dims=['time'], - coords={'time': pd.date_range('01-01-2001', periods=50, freq='H')}) + np.linspace(1, 3, 50), + dims=["time"], + coords={"time": pd.date_range("01-01-2001", periods=50, freq="H")}, + ) a.interp(x=xi, time=xi.time) # should not raise an error diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index 20e0fae8daf..ed1453ce95d 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -11,139 +11,134 @@ class TestMergeInternals: def test_broadcast_dimension_size(self): actual = merge.broadcast_dimension_size( - [xr.Variable('x', [1]), xr.Variable('y', [2, 1])]) - assert actual == {'x': 1, 'y': 2} + [xr.Variable("x", [1]), xr.Variable("y", [2, 1])] + ) + assert actual == {"x": 1, "y": 2} actual = merge.broadcast_dimension_size( - [xr.Variable(('x', 'y'), [[1, 2]]), xr.Variable('y', [2, 1])]) - assert actual == {'x': 1, 'y': 2} + [xr.Variable(("x", "y"), [[1, 2]]), xr.Variable("y", [2, 1])] + ) + assert actual == {"x": 1, "y": 2} with pytest.raises(ValueError): merge.broadcast_dimension_size( - [xr.Variable(('x', 'y'), [[1, 2]]), xr.Variable('y', [2])]) + [xr.Variable(("x", "y"), [[1, 2]]), xr.Variable("y", [2])] + ) class TestMergeFunction: def test_merge_arrays(self): data = create_test_data() actual = xr.merge([data.var1, data.var2]) - expected = data[['var1', 'var2']] + expected = data[["var1", "var2"]] assert actual.identical(expected) def test_merge_datasets(self): data = create_test_data() - actual = xr.merge([data[['var1']], data[['var2']]]) - expected = data[['var1', 'var2']] + actual = xr.merge([data[["var1"]], data[["var2"]]]) + expected = data[["var1", "var2"]] assert actual.identical(expected) actual = xr.merge([data, data]) assert actual.identical(data) def test_merge_dataarray_unnamed(self): - data = xr.DataArray([1, 2], dims='x') - with raises_regex( - ValueError, 'without providing an explicit name'): + data = xr.DataArray([1, 2], dims="x") + with raises_regex(ValueError, "without providing an explicit name"): xr.merge([data]) def test_merge_dicts_simple(self): - actual = xr.merge([{'foo': 0}, {'bar': 'one'}, {'baz': 3.5}]) - expected = xr.Dataset({'foo': 0, 'bar': 'one', 'baz': 3.5}) + actual = xr.merge([{"foo": 0}, {"bar": "one"}, {"baz": 3.5}]) + expected = xr.Dataset({"foo": 0, "bar": "one", "baz": 3.5}) assert actual.identical(expected) def test_merge_dicts_dims(self): - actual = xr.merge([{'y': ('x', [13])}, {'x': [12]}]) - expected = xr.Dataset({'x': [12], 'y': ('x', [13])}) + actual = xr.merge([{"y": ("x", [13])}, {"x": [12]}]) + expected = xr.Dataset({"x": [12], "y": ("x", [13])}) assert actual.identical(expected) def test_merge_error(self): - ds = xr.Dataset({'x': 0}) + ds = xr.Dataset({"x": 0}) with pytest.raises(xr.MergeError): xr.merge([ds, ds + 
1]) def test_merge_alignment_error(self): - ds = xr.Dataset(coords={'x': [1, 2]}) - other = xr.Dataset(coords={'x': [2, 3]}) - with raises_regex(ValueError, 'indexes .* not equal'): - xr.merge([ds, other], join='exact') + ds = xr.Dataset(coords={"x": [1, 2]}) + other = xr.Dataset(coords={"x": [2, 3]}) + with raises_regex(ValueError, "indexes .* not equal"): + xr.merge([ds, other], join="exact") def test_merge_wrong_input_error(self): with raises_regex(TypeError, "objects must be an iterable"): xr.merge([1]) - ds = xr.Dataset(coords={'x': [1, 2]}) + ds = xr.Dataset(coords={"x": [1, 2]}) with raises_regex(TypeError, "objects must be an iterable"): - xr.merge({'a': ds}) + xr.merge({"a": ds}) with raises_regex(TypeError, "objects must be an iterable"): xr.merge([ds, 1]) def test_merge_no_conflicts_single_var(self): - ds1 = xr.Dataset({'a': ('x', [1, 2]), 'x': [0, 1]}) - ds2 = xr.Dataset({'a': ('x', [2, 3]), 'x': [1, 2]}) - expected = xr.Dataset({'a': ('x', [1, 2, 3]), 'x': [0, 1, 2]}) - assert expected.identical(xr.merge([ds1, ds2], - compat='no_conflicts')) - assert expected.identical(xr.merge([ds2, ds1], - compat='no_conflicts')) - assert ds1.identical(xr.merge([ds1, ds2], - compat='no_conflicts', - join='left')) - assert ds2.identical(xr.merge([ds1, ds2], - compat='no_conflicts', - join='right')) - expected = xr.Dataset({'a': ('x', [2]), 'x': [1]}) - assert expected.identical(xr.merge([ds1, ds2], - compat='no_conflicts', - join='inner')) + ds1 = xr.Dataset({"a": ("x", [1, 2]), "x": [0, 1]}) + ds2 = xr.Dataset({"a": ("x", [2, 3]), "x": [1, 2]}) + expected = xr.Dataset({"a": ("x", [1, 2, 3]), "x": [0, 1, 2]}) + assert expected.identical(xr.merge([ds1, ds2], compat="no_conflicts")) + assert expected.identical(xr.merge([ds2, ds1], compat="no_conflicts")) + assert ds1.identical(xr.merge([ds1, ds2], compat="no_conflicts", join="left")) + assert ds2.identical(xr.merge([ds1, ds2], compat="no_conflicts", join="right")) + expected = xr.Dataset({"a": ("x", [2]), "x": [1]}) + assert expected.identical( + xr.merge([ds1, ds2], compat="no_conflicts", join="inner") + ) with pytest.raises(xr.MergeError): - ds3 = xr.Dataset({'a': ('x', [99, 3]), 'x': [1, 2]}) - xr.merge([ds1, ds3], compat='no_conflicts') + ds3 = xr.Dataset({"a": ("x", [99, 3]), "x": [1, 2]}) + xr.merge([ds1, ds3], compat="no_conflicts") with pytest.raises(xr.MergeError): - ds3 = xr.Dataset({'a': ('y', [2, 3]), 'y': [1, 2]}) - xr.merge([ds1, ds3], compat='no_conflicts') + ds3 = xr.Dataset({"a": ("y", [2, 3]), "y": [1, 2]}) + xr.merge([ds1, ds3], compat="no_conflicts") def test_merge_no_conflicts_multi_var(self): data = create_test_data() data1 = data.copy(deep=True) data2 = data.copy(deep=True) - expected = data[['var1', 'var2']] - actual = xr.merge([data1.var1, data2.var2], compat='no_conflicts') + expected = data[["var1", "var2"]] + actual = xr.merge([data1.var1, data2.var2], compat="no_conflicts") assert expected.identical(actual) - data1['var1'][:, :5] = np.nan - data2['var1'][:, 5:] = np.nan - data1['var2'][:4, :] = np.nan - data2['var2'][4:, :] = np.nan - del data2['var3'] + data1["var1"][:, :5] = np.nan + data2["var1"][:, 5:] = np.nan + data1["var2"][:4, :] = np.nan + data2["var2"][4:, :] = np.nan + del data2["var3"] - actual = xr.merge([data1, data2], compat='no_conflicts') + actual = xr.merge([data1, data2], compat="no_conflicts") assert data.equals(actual) def test_merge_no_conflicts_preserve_attrs(self): - data = xr.Dataset({'x': ([], 0, {'foo': 'bar'})}) + data = xr.Dataset({"x": ([], 0, {"foo": "bar"})}) actual = xr.merge([data, 
data]) assert data.identical(actual) def test_merge_no_conflicts_broadcast(self): - datasets = [xr.Dataset({'x': ('y', [0])}), xr.Dataset({'x': np.nan})] + datasets = [xr.Dataset({"x": ("y", [0])}), xr.Dataset({"x": np.nan})] actual = xr.merge(datasets) - expected = xr.Dataset({'x': ('y', [0])}) + expected = xr.Dataset({"x": ("y", [0])}) assert expected.identical(actual) - datasets = [xr.Dataset({'x': ('y', [np.nan])}), xr.Dataset({'x': 0})] + datasets = [xr.Dataset({"x": ("y", [np.nan])}), xr.Dataset({"x": 0})] actual = xr.merge(datasets) assert expected.identical(actual) class TestMergeMethod: - def test_merge(self): data = create_test_data() - ds1 = data[['var1']] - ds2 = data[['var3']] - expected = data[['var1', 'var3']] + ds1 = data[["var1"]] + ds2 = data[["var3"]] + expected = data[["var1", "var3"]] actual = ds1.merge(ds2) assert expected.identical(actual) @@ -158,17 +153,15 @@ def test_merge(self): assert data.identical(actual) with pytest.raises(ValueError): - ds1.merge(ds2.rename({'var3': 'var1'})) - with raises_regex( - ValueError, 'should be coordinates or not'): + ds1.merge(ds2.rename({"var3": "var1"})) + with raises_regex(ValueError, "should be coordinates or not"): data.reset_coords().merge(data) - with raises_regex( - ValueError, 'should be coordinates or not'): + with raises_regex(ValueError, "should be coordinates or not"): data.merge(data.reset_coords()) def test_merge_broadcast_equals(self): - ds1 = xr.Dataset({'x': 0}) - ds2 = xr.Dataset({'x': ('y', [0, 0])}) + ds1 = xr.Dataset({"x": 0}) + ds2 = xr.Dataset({"x": ("y", [0, 0])}) actual = ds1.merge(ds2) assert ds2.identical(actual) @@ -179,86 +172,82 @@ def test_merge_broadcast_equals(self): actual.update(ds2) assert ds2.identical(actual) - ds1 = xr.Dataset({'x': np.nan}) - ds2 = xr.Dataset({'x': ('y', [np.nan, np.nan])}) + ds1 = xr.Dataset({"x": np.nan}) + ds2 = xr.Dataset({"x": ("y", [np.nan, np.nan])}) actual = ds1.merge(ds2) assert ds2.identical(actual) def test_merge_compat(self): - ds1 = xr.Dataset({'x': 0}) - ds2 = xr.Dataset({'x': 1}) - for compat in ['broadcast_equals', 'equals', 'identical', - 'no_conflicts']: + ds1 = xr.Dataset({"x": 0}) + ds2 = xr.Dataset({"x": 1}) + for compat in ["broadcast_equals", "equals", "identical", "no_conflicts"]: with pytest.raises(xr.MergeError): ds1.merge(ds2, compat=compat) - ds2 = xr.Dataset({'x': [0, 0]}) - for compat in ['equals', 'identical']: - with raises_regex( - ValueError, 'should be coordinates or not'): + ds2 = xr.Dataset({"x": [0, 0]}) + for compat in ["equals", "identical"]: + with raises_regex(ValueError, "should be coordinates or not"): ds1.merge(ds2, compat=compat) - ds2 = xr.Dataset({'x': ((), 0, {'foo': 'bar'})}) + ds2 = xr.Dataset({"x": ((), 0, {"foo": "bar"})}) with pytest.raises(xr.MergeError): - ds1.merge(ds2, compat='identical') + ds1.merge(ds2, compat="identical") - with raises_regex(ValueError, 'compat=.* invalid'): - ds1.merge(ds2, compat='foobar') + with raises_regex(ValueError, "compat=.* invalid"): + ds1.merge(ds2, compat="foobar") def test_merge_auto_align(self): - ds1 = xr.Dataset({'a': ('x', [1, 2]), 'x': [0, 1]}) - ds2 = xr.Dataset({'b': ('x', [3, 4]), 'x': [1, 2]}) - expected = xr.Dataset({'a': ('x', [1, 2, np.nan]), - 'b': ('x', [np.nan, 3, 4])}, - {'x': [0, 1, 2]}) + ds1 = xr.Dataset({"a": ("x", [1, 2]), "x": [0, 1]}) + ds2 = xr.Dataset({"b": ("x", [3, 4]), "x": [1, 2]}) + expected = xr.Dataset( + {"a": ("x", [1, 2, np.nan]), "b": ("x", [np.nan, 3, 4])}, {"x": [0, 1, 2]} + ) assert expected.identical(ds1.merge(ds2)) assert 
expected.identical(ds2.merge(ds1)) expected = expected.isel(x=slice(2)) - assert expected.identical(ds1.merge(ds2, join='left')) - assert expected.identical(ds2.merge(ds1, join='right')) + assert expected.identical(ds1.merge(ds2, join="left")) + assert expected.identical(ds2.merge(ds1, join="right")) expected = expected.isel(x=slice(1, 2)) - assert expected.identical(ds1.merge(ds2, join='inner')) - assert expected.identical(ds2.merge(ds1, join='inner')) + assert expected.identical(ds1.merge(ds2, join="inner")) + assert expected.identical(ds2.merge(ds1, join="inner")) - @pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0]) + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) def test_merge_fill_value(self, fill_value): - ds1 = xr.Dataset({'a': ('x', [1, 2]), 'x': [0, 1]}) - ds2 = xr.Dataset({'b': ('x', [3, 4]), 'x': [1, 2]}) + ds1 = xr.Dataset({"a": ("x", [1, 2]), "x": [0, 1]}) + ds2 = xr.Dataset({"b": ("x", [3, 4]), "x": [1, 2]}) if fill_value == dtypes.NA: # if we supply the default, we expect the missing value for a # float array fill_value = np.nan - expected = xr.Dataset({'a': ('x', [1, 2, fill_value]), - 'b': ('x', [fill_value, 3, 4])}, - {'x': [0, 1, 2]}) + expected = xr.Dataset( + {"a": ("x", [1, 2, fill_value]), "b": ("x", [fill_value, 3, 4])}, + {"x": [0, 1, 2]}, + ) assert expected.identical(ds1.merge(ds2, fill_value=fill_value)) assert expected.identical(ds2.merge(ds1, fill_value=fill_value)) assert expected.identical(xr.merge([ds1, ds2], fill_value=fill_value)) def test_merge_no_conflicts(self): - ds1 = xr.Dataset({'a': ('x', [1, 2]), 'x': [0, 1]}) - ds2 = xr.Dataset({'a': ('x', [2, 3]), 'x': [1, 2]}) - expected = xr.Dataset({'a': ('x', [1, 2, 3]), 'x': [0, 1, 2]}) + ds1 = xr.Dataset({"a": ("x", [1, 2]), "x": [0, 1]}) + ds2 = xr.Dataset({"a": ("x", [2, 3]), "x": [1, 2]}) + expected = xr.Dataset({"a": ("x", [1, 2, 3]), "x": [0, 1, 2]}) - assert expected.identical(ds1.merge(ds2, compat='no_conflicts')) - assert expected.identical(ds2.merge(ds1, compat='no_conflicts')) + assert expected.identical(ds1.merge(ds2, compat="no_conflicts")) + assert expected.identical(ds2.merge(ds1, compat="no_conflicts")) - assert ds1.identical(ds1.merge(ds2, compat='no_conflicts', - join='left')) + assert ds1.identical(ds1.merge(ds2, compat="no_conflicts", join="left")) - assert ds2.identical(ds1.merge(ds2, compat='no_conflicts', - join='right')) + assert ds2.identical(ds1.merge(ds2, compat="no_conflicts", join="right")) - expected2 = xr.Dataset({'a': ('x', [2]), 'x': [1]}) - assert expected2.identical(ds1.merge(ds2, compat='no_conflicts', - join='inner')) + expected2 = xr.Dataset({"a": ("x", [2]), "x": [1]}) + assert expected2.identical(ds1.merge(ds2, compat="no_conflicts", join="inner")) with pytest.raises(xr.MergeError): - ds3 = xr.Dataset({'a': ('x', [99, 3]), 'x': [1, 2]}) - ds1.merge(ds3, compat='no_conflicts') + ds3 = xr.Dataset({"a": ("x", [99, 3]), "x": [1, 2]}) + ds1.merge(ds3, compat="no_conflicts") with pytest.raises(xr.MergeError): - ds3 = xr.Dataset({'a': ('y', [2, 3]), 'y': [1, 2]}) - ds1.merge(ds3, compat='no_conflicts') + ds3 = xr.Dataset({"a": ("y", [2, 3]), "y": [1, 2]}) + ds1.merge(ds3, compat="no_conflicts") diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index 2fd2b6c44b4..cfce5d6f645 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -5,32 +5,36 @@ import pytest import xarray as xr -from xarray.core.missing import ( - NumpyInterpolator, ScipyInterpolator, SplineInterpolator) +from xarray.core.missing import 
NumpyInterpolator, ScipyInterpolator, SplineInterpolator from xarray.core.pycompat import dask_array_type from xarray.tests import ( - assert_array_equal, assert_equal, raises_regex, requires_bottleneck, - requires_dask, requires_scipy) + assert_array_equal, + assert_equal, + raises_regex, + requires_bottleneck, + requires_dask, + requires_scipy, +) @pytest.fixture def da(): - return xr.DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], - dims='time') + return xr.DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time") @pytest.fixture def ds(): ds = xr.Dataset() - ds['var1'] = xr.DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], - dims='time') - ds['var2'] = xr.DataArray([10, np.nan, 11, 12, np.nan, 13, 14, 15, np.nan, - 16, 17], dims='x') + ds["var1"] = xr.DataArray( + [0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time" + ) + ds["var2"] = xr.DataArray( + [10, np.nan, 11, 12, np.nan, 13, 14, 15, np.nan, 16, 17], dims="x" + ) return ds -def make_interpolate_example_data(shape, frac_nan, seed=12345, - non_uniform=False): +def make_interpolate_example_data(shape, frac_nan, seed=12345, non_uniform=False): rs = np.random.RandomState(seed) vals = rs.normal(size=shape) if frac_nan == 1: @@ -53,13 +57,11 @@ def make_interpolate_example_data(shape, frac_nan, seed=12345, if non_uniform: # construct a datetime index that has irregular spacing - deltas = pd.TimedeltaIndex(unit='d', data=rs.normal(size=shape[0], - scale=10)) - coords = {'time': (pd.Timestamp('2000-01-01') + deltas).sort_values()} + deltas = pd.TimedeltaIndex(unit="d", data=rs.normal(size=shape[0], scale=10)) + coords = {"time": (pd.Timestamp("2000-01-01") + deltas).sort_values()} else: - coords = {'time': pd.date_range('2000-01-01', freq='D', - periods=shape[0])} - da = xr.DataArray(vals, dims=('time', 'x'), coords=coords) + coords = {"time": pd.date_range("2000-01-01", freq="D", periods=shape[0])} + da = xr.DataArray(vals, dims=("time", "x"), coords=coords) df = da.to_pandas() return da, df @@ -69,18 +71,17 @@ def make_interpolate_example_data(shape, frac_nan, seed=12345, def test_interpolate_pd_compat(): shapes = [(8, 8), (1, 20), (20, 1), (100, 100)] frac_nans = [0, 0.5, 1] - methods = ['linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'] + methods = ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"] - for (shape, frac_nan, method) in itertools.product(shapes, frac_nans, - methods): + for (shape, frac_nan, method) in itertools.product(shapes, frac_nans, methods): da, df = make_interpolate_example_data(shape, frac_nan) - for dim in ['time', 'x']: - actual = da.interpolate_na(method=method, dim=dim, - fill_value=np.nan) - expected = df.interpolate(method=method, axis=da.get_axis_num(dim), - fill_value=(np.nan, np.nan)) + for dim in ["time", "x"]: + actual = da.interpolate_na(method=method, dim=dim, fill_value=np.nan) + expected = df.interpolate( + method=method, axis=da.get_axis_num(dim), fill_value=(np.nan, np.nan) + ) # Note, Pandas does some odd things with the left/right fill_value # for the linear methods. This next line inforces the xarray # fill_value convention on the pandas output. 
Therefore, this test @@ -91,34 +92,33 @@ def test_interpolate_pd_compat(): @requires_scipy -@pytest.mark.parametrize('method', ['barycentric', 'krog', - 'pchip', 'spline', 'akima']) +@pytest.mark.parametrize("method", ["barycentric", "krog", "pchip", "spline", "akima"]) def test_scipy_methods_function(method): # Note: Pandas does some wacky things with these methods and the full # integration tests wont work. da, _ = make_interpolate_example_data((25, 25), 0.4, non_uniform=True) - actual = da.interpolate_na(method=method, dim='time') - assert (da.count('time') <= actual.count('time')).all() + actual = da.interpolate_na(method=method, dim="time") + assert (da.count("time") <= actual.count("time")).all() @requires_scipy def test_interpolate_pd_compat_non_uniform_index(): shapes = [(8, 8), (1, 20), (20, 1), (100, 100)] frac_nans = [0, 0.5, 1] - methods = ['time', 'index', 'values'] + methods = ["time", "index", "values"] - for (shape, frac_nan, method) in itertools.product(shapes, frac_nans, - methods): + for (shape, frac_nan, method) in itertools.product(shapes, frac_nans, methods): - da, df = make_interpolate_example_data(shape, frac_nan, - non_uniform=True) - for dim in ['time', 'x']: - if method == 'time' and dim != 'time': + da, df = make_interpolate_example_data(shape, frac_nan, non_uniform=True) + for dim in ["time", "x"]: + if method == "time" and dim != "time": continue - actual = da.interpolate_na(method='linear', dim=dim, - use_coordinate=True, fill_value=np.nan) - expected = df.interpolate(method=method, axis=da.get_axis_num(dim), - fill_value=np.nan) + actual = da.interpolate_na( + method="linear", dim=dim, use_coordinate=True, fill_value=np.nan + ) + expected = df.interpolate( + method=method, axis=da.get_axis_num(dim), fill_value=np.nan + ) # Note, Pandas does some odd things with the left/right fill_value # for the linear methods. 
This next line inforces the xarray @@ -135,80 +135,83 @@ def test_interpolate_pd_compat_polynomial(): frac_nans = [0, 0.5, 1] orders = [1, 2, 3] - for (shape, frac_nan, order) in itertools.product(shapes, frac_nans, - orders): + for (shape, frac_nan, order) in itertools.product(shapes, frac_nans, orders): da, df = make_interpolate_example_data(shape, frac_nan) - for dim in ['time', 'x']: - actual = da.interpolate_na(method='polynomial', order=order, - dim=dim, use_coordinate=False) - expected = df.interpolate(method='polynomial', order=order, - axis=da.get_axis_num(dim)) + for dim in ["time", "x"]: + actual = da.interpolate_na( + method="polynomial", order=order, dim=dim, use_coordinate=False + ) + expected = df.interpolate( + method="polynomial", order=order, axis=da.get_axis_num(dim) + ) np.testing.assert_allclose(actual.values, expected.values) @requires_scipy def test_interpolate_unsorted_index_raises(): vals = np.array([1, 2, 3], dtype=np.float64) - expected = xr.DataArray(vals, dims='x', coords={'x': [2, 1, 3]}) - with raises_regex(ValueError, 'Index must be monotonicly increasing'): - expected.interpolate_na(dim='x', method='index') + expected = xr.DataArray(vals, dims="x", coords={"x": [2, 1, 3]}) + with raises_regex(ValueError, "Index must be monotonicly increasing"): + expected.interpolate_na(dim="x", method="index") def test_interpolate_no_dim_raises(): - da = xr.DataArray(np.array([1, 2, np.nan, 5], dtype=np.float64), dims='x') - with raises_regex(NotImplementedError, 'dim is a required argument'): - da.interpolate_na(method='linear') + da = xr.DataArray(np.array([1, 2, np.nan, 5], dtype=np.float64), dims="x") + with raises_regex(NotImplementedError, "dim is a required argument"): + da.interpolate_na(method="linear") def test_interpolate_invalid_interpolator_raises(): - da = xr.DataArray(np.array([1, 2, np.nan, 5], dtype=np.float64), dims='x') - with raises_regex(ValueError, 'not a valid'): - da.interpolate_na(dim='x', method='foo') + da = xr.DataArray(np.array([1, 2, np.nan, 5], dtype=np.float64), dims="x") + with raises_regex(ValueError, "not a valid"): + da.interpolate_na(dim="x", method="foo") def test_interpolate_multiindex_raises(): data = np.random.randn(2, 3) data[1, 1] = np.nan - da = xr.DataArray(data, coords=[('x', ['a', 'b']), ('y', [0, 1, 2])]) - das = da.stack(z=('x', 'y')) - with raises_regex(TypeError, 'Index must be castable to float64'): - das.interpolate_na(dim='z') + da = xr.DataArray(data, coords=[("x", ["a", "b"]), ("y", [0, 1, 2])]) + das = da.stack(z=("x", "y")) + with raises_regex(TypeError, "Index must be castable to float64"): + das.interpolate_na(dim="z") def test_interpolate_2d_coord_raises(): - coords = {'x': xr.Variable(('a', 'b'), np.arange(6).reshape(2, 3)), - 'y': xr.Variable(('a', 'b'), np.arange(6).reshape(2, 3)) * 2} + coords = { + "x": xr.Variable(("a", "b"), np.arange(6).reshape(2, 3)), + "y": xr.Variable(("a", "b"), np.arange(6).reshape(2, 3)) * 2, + } data = np.random.randn(2, 3) data[1, 1] = np.nan - da = xr.DataArray(data, dims=('a', 'b'), coords=coords) - with raises_regex(ValueError, 'interpolation must be 1D'): - da.interpolate_na(dim='a', use_coordinate='x') + da = xr.DataArray(data, dims=("a", "b"), coords=coords) + with raises_regex(ValueError, "interpolation must be 1D"): + da.interpolate_na(dim="a", use_coordinate="x") @requires_scipy def test_interpolate_kwargs(): - da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims='x') - expected = xr.DataArray(np.array([4, 5, 6], dtype=np.float64), dims='x') - actual = 
da.interpolate_na(dim='x', fill_value='extrapolate') + da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") + expected = xr.DataArray(np.array([4, 5, 6], dtype=np.float64), dims="x") + actual = da.interpolate_na(dim="x", fill_value="extrapolate") assert_equal(actual, expected) - expected = xr.DataArray(np.array([4, 5, -999], dtype=np.float64), dims='x') - actual = da.interpolate_na(dim='x', fill_value=-999) + expected = xr.DataArray(np.array([4, 5, -999], dtype=np.float64), dims="x") + actual = da.interpolate_na(dim="x", fill_value=-999) assert_equal(actual, expected) def test_interpolate(): vals = np.array([1, 2, 3, 4, 5, 6], dtype=np.float64) - expected = xr.DataArray(vals, dims='x') + expected = xr.DataArray(vals, dims="x") mvals = vals.copy() mvals[2] = np.nan - missing = xr.DataArray(mvals, dims='x') + missing = xr.DataArray(mvals, dims="x") - actual = missing.interpolate_na(dim='x') + actual = missing.interpolate_na(dim="x") assert_equal(actual, expected) @@ -216,54 +219,59 @@ def test_interpolate(): def test_interpolate_nonans(): vals = np.array([1, 2, 3, 4, 5, 6], dtype=np.float64) - expected = xr.DataArray(vals, dims='x') - actual = expected.interpolate_na(dim='x') + expected = xr.DataArray(vals, dims="x") + actual = expected.interpolate_na(dim="x") assert_equal(actual, expected) @requires_scipy def test_interpolate_allnans(): vals = np.full(6, np.nan, dtype=np.float64) - expected = xr.DataArray(vals, dims='x') - actual = expected.interpolate_na(dim='x') + expected = xr.DataArray(vals, dims="x") + actual = expected.interpolate_na(dim="x") assert_equal(actual, expected) @requires_bottleneck def test_interpolate_limits(): - da = xr.DataArray(np.array([1, 2, np.nan, np.nan, np.nan, 6], - dtype=np.float64), dims='x') + da = xr.DataArray( + np.array([1, 2, np.nan, np.nan, np.nan, 6], dtype=np.float64), dims="x" + ) - actual = da.interpolate_na(dim='x', limit=None) + actual = da.interpolate_na(dim="x", limit=None) assert actual.isnull().sum() == 0 - actual = da.interpolate_na(dim='x', limit=2) - expected = xr.DataArray(np.array([1, 2, 3, 4, np.nan, 6], - dtype=np.float64), dims='x') + actual = da.interpolate_na(dim="x", limit=2) + expected = xr.DataArray( + np.array([1, 2, 3, 4, np.nan, 6], dtype=np.float64), dims="x" + ) assert_equal(actual, expected) @requires_scipy def test_interpolate_methods(): - for method in ['linear', 'nearest', 'zero', 'slinear', 'quadratic', - 'cubic']: + for method in ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]: kwargs = {} - da = xr.DataArray(np.array([0, 1, 2, np.nan, np.nan, np.nan, 6, 7, 8], - dtype=np.float64), dims='x') - actual = da.interpolate_na('x', method=method, **kwargs) + da = xr.DataArray( + np.array([0, 1, 2, np.nan, np.nan, np.nan, 6, 7, 8], dtype=np.float64), + dims="x", + ) + actual = da.interpolate_na("x", method=method, **kwargs) assert actual.isnull().sum() == 0 - actual = da.interpolate_na('x', method=method, limit=2, **kwargs) + actual = da.interpolate_na("x", method=method, limit=2, **kwargs) assert actual.isnull().sum() == 1 @requires_scipy def test_interpolators(): - for method, interpolator in [('linear', NumpyInterpolator), - ('linear', ScipyInterpolator), - ('spline', SplineInterpolator)]: + for method, interpolator in [ + ("linear", NumpyInterpolator), + ("linear", ScipyInterpolator), + ("spline", SplineInterpolator), + ]: xi = np.array([-1, 0, 1, 2, 5], dtype=np.float64) yi = np.array([-10, 0, 10, 20, 50], dtype=np.float64) x = np.array([3, 4], dtype=np.float64) @@ -274,40 +282,42 @@ def 
test_interpolators(): def test_interpolate_use_coordinate(): - xc = xr.Variable('x', [100, 200, 300, 400, 500, 600]) - da = xr.DataArray(np.array([1, 2, np.nan, np.nan, np.nan, 6], - dtype=np.float64), - dims='x', coords={'xc': xc}) + xc = xr.Variable("x", [100, 200, 300, 400, 500, 600]) + da = xr.DataArray( + np.array([1, 2, np.nan, np.nan, np.nan, 6], dtype=np.float64), + dims="x", + coords={"xc": xc}, + ) # use_coordinate == False is same as using the default index - actual = da.interpolate_na(dim='x', use_coordinate=False) - expected = da.interpolate_na(dim='x') + actual = da.interpolate_na(dim="x", use_coordinate=False) + expected = da.interpolate_na(dim="x") assert_equal(actual, expected) # possible to specify non index coordinate - actual = da.interpolate_na(dim='x', use_coordinate='xc') - expected = da.interpolate_na(dim='x') + actual = da.interpolate_na(dim="x", use_coordinate="xc") + expected = da.interpolate_na(dim="x") assert_equal(actual, expected) # possible to specify index coordinate by name - actual = da.interpolate_na(dim='x', use_coordinate='x') - expected = da.interpolate_na(dim='x') + actual = da.interpolate_na(dim="x", use_coordinate="x") + expected = da.interpolate_na(dim="x") assert_equal(actual, expected) @requires_dask def test_interpolate_dask(): da, _ = make_interpolate_example_data((40, 40), 0.5) - da = da.chunk({'x': 5}) - actual = da.interpolate_na('time') - expected = da.load().interpolate_na('time') + da = da.chunk({"x": 5}) + actual = da.interpolate_na("time") + expected = da.load().interpolate_na("time") assert isinstance(actual.data, dask_array_type) assert_equal(actual.compute(), expected) # with limit - da = da.chunk({'x': 5}) - actual = da.interpolate_na('time', limit=3) - expected = da.load().interpolate_na('time', limit=3) + da = da.chunk({"x": 5}) + actual = da.interpolate_na("time", limit=3) + expected = da.load().interpolate_na("time", limit=3) assert isinstance(actual.data, dask_array_type) assert_equal(actual, expected) @@ -315,16 +325,16 @@ def test_interpolate_dask(): @requires_dask def test_interpolate_dask_raises_for_invalid_chunk_dim(): da, _ = make_interpolate_example_data((40, 40), 0.5) - da = da.chunk({'time': 5}) + da = da.chunk({"time": 5}) with raises_regex(ValueError, "dask='parallelized' consists of multiple"): - da.interpolate_na('time') + da.interpolate_na("time") @requires_bottleneck def test_ffill(): - da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims='x') - expected = xr.DataArray(np.array([4, 5, 5], dtype=np.float64), dims='x') - actual = da.ffill('x') + da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") + expected = xr.DataArray(np.array([4, 5, 5], dtype=np.float64), dims="x") + actual = da.ffill("x") assert_equal(actual, expected) @@ -332,16 +342,16 @@ def test_ffill(): @requires_dask def test_ffill_dask(): da, _ = make_interpolate_example_data((40, 40), 0.5) - da = da.chunk({'x': 5}) - actual = da.ffill('time') - expected = da.load().ffill('time') + da = da.chunk({"x": 5}) + actual = da.ffill("time") + expected = da.load().ffill("time") assert isinstance(actual.data, dask_array_type) assert_equal(actual, expected) # with limit - da = da.chunk({'x': 5}) - actual = da.ffill('time', limit=3) - expected = da.load().ffill('time', limit=3) + da = da.chunk({"x": 5}) + actual = da.ffill("time", limit=3) + expected = da.load().ffill("time", limit=3) assert isinstance(actual.data, dask_array_type) assert_equal(actual, expected) @@ -350,16 +360,16 @@ def test_ffill_dask(): @requires_dask def 
test_bfill_dask(): da, _ = make_interpolate_example_data((40, 40), 0.5) - da = da.chunk({'x': 5}) - actual = da.bfill('time') - expected = da.load().bfill('time') + da = da.chunk({"x": 5}) + actual = da.bfill("time") + expected = da.load().bfill("time") assert isinstance(actual.data, dask_array_type) assert_equal(actual, expected) # with limit - da = da.chunk({'x': 5}) - actual = da.bfill('time', limit=3) - expected = da.load().bfill('time', limit=3) + da = da.chunk({"x": 5}) + actual = da.bfill("time", limit=3) + expected = da.load().bfill("time", limit=3) assert isinstance(actual.data, dask_array_type) assert_equal(actual, expected) @@ -368,12 +378,12 @@ def test_bfill_dask(): def test_ffill_bfill_nonans(): vals = np.array([1, 2, 3, 4, 5, 6], dtype=np.float64) - expected = xr.DataArray(vals, dims='x') + expected = xr.DataArray(vals, dims="x") - actual = expected.ffill(dim='x') + actual = expected.ffill(dim="x") assert_equal(actual, expected) - actual = expected.bfill(dim='x') + actual = expected.bfill(dim="x") assert_equal(actual, expected) @@ -381,50 +391,51 @@ def test_ffill_bfill_nonans(): def test_ffill_bfill_allnans(): vals = np.full(6, np.nan, dtype=np.float64) - expected = xr.DataArray(vals, dims='x') + expected = xr.DataArray(vals, dims="x") - actual = expected.ffill(dim='x') + actual = expected.ffill(dim="x") assert_equal(actual, expected) - actual = expected.bfill(dim='x') + actual = expected.bfill(dim="x") assert_equal(actual, expected) @requires_bottleneck def test_ffill_functions(da): - result = da.ffill('time') + result = da.ffill("time") assert result.isnull().sum() == 0 @requires_bottleneck def test_ffill_limit(): da = xr.DataArray( - [0, np.nan, np.nan, np.nan, np.nan, 3, 4, 5, np.nan, 6, 7], - dims='time') - result = da.ffill('time') - expected = xr.DataArray([0, 0, 0, 0, 0, 3, 4, 5, 5, 6, 7], dims='time') + [0, np.nan, np.nan, np.nan, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time" + ) + result = da.ffill("time") + expected = xr.DataArray([0, 0, 0, 0, 0, 3, 4, 5, 5, 6, 7], dims="time") assert_array_equal(result, expected) - result = da.ffill('time', limit=1) + result = da.ffill("time", limit=1) expected = xr.DataArray( - [0, 0, np.nan, np.nan, np.nan, 3, 4, 5, 5, 6, 7], dims='time') + [0, 0, np.nan, np.nan, np.nan, 3, 4, 5, 5, 6, 7], dims="time" + ) assert_array_equal(result, expected) def test_interpolate_dataset(ds): - actual = ds.interpolate_na(dim='time') + actual = ds.interpolate_na(dim="time") # no missing values in var1 - assert actual['var1'].count('time') == actual.dims['time'] + assert actual["var1"].count("time") == actual.dims["time"] # var2 should be the same as it was - assert_array_equal(actual['var2'], ds['var2']) + assert_array_equal(actual["var2"], ds["var2"]) @requires_bottleneck def test_ffill_dataset(ds): - ds.ffill(dim='time') + ds.ffill(dim="time") @requires_bottleneck def test_bfill_dataset(ds): - ds.ffill(dim='time') + ds.ffill(dim="time") diff --git a/xarray/tests/test_nputils.py b/xarray/tests/test_nputils.py index d3ad87d0d28..1002a9dd9e3 100644 --- a/xarray/tests/test_nputils.py +++ b/xarray/tests/test_nputils.py @@ -1,8 +1,7 @@ import numpy as np from numpy.testing import assert_array_equal -from xarray.core.nputils import ( - NumpyVIndexAdapter, _is_contiguous, rolling_window) +from xarray.core.nputils import NumpyVIndexAdapter, _is_contiguous, rolling_window def test_is_contiguous(): @@ -34,19 +33,14 @@ def test_vindex(): def test_rolling(): x = np.array([1, 2, 3, 4], dtype=float) - actual = rolling_window(x, axis=-1, window=3, center=True, 
- fill_value=np.nan) - expected = np.array([[np.nan, 1, 2], - [1, 2, 3], - [2, 3, 4], - [3, 4, np.nan]], dtype=float) + actual = rolling_window(x, axis=-1, window=3, center=True, fill_value=np.nan) + expected = np.array( + [[np.nan, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, np.nan]], dtype=float + ) assert_array_equal(actual, expected) actual = rolling_window(x, axis=-1, window=3, center=False, fill_value=0.0) - expected = np.array([[0, 0, 1], - [0, 1, 2], - [1, 2, 3], - [2, 3, 4]], dtype=float) + expected = np.array([[0, 0, 1], [0, 1, 2], [1, 2, 3], [2, 3, 4]], dtype=float) assert_array_equal(actual, expected) x = np.stack([x, x * 1.1]) diff --git a/xarray/tests/test_options.py b/xarray/tests/test_options.py index 34bcba58020..2aa77ecd6b3 100644 --- a/xarray/tests/test_options.py +++ b/xarray/tests/test_options.py @@ -23,17 +23,17 @@ def test_display_width(): def test_arithmetic_join(): with pytest.raises(ValueError): - xarray.set_options(arithmetic_join='invalid') - with xarray.set_options(arithmetic_join='exact'): - assert OPTIONS['arithmetic_join'] == 'exact' + xarray.set_options(arithmetic_join="invalid") + with xarray.set_options(arithmetic_join="exact"): + assert OPTIONS["arithmetic_join"] == "exact" def test_enable_cftimeindex(): with pytest.raises(ValueError): xarray.set_options(enable_cftimeindex=None) - with pytest.warns(FutureWarning, match='no-op'): + with pytest.warns(FutureWarning, match="no-op"): with xarray.set_options(enable_cftimeindex=True): - assert OPTIONS['enable_cftimeindex'] + assert OPTIONS["enable_cftimeindex"] def test_file_cache_maxsize(): @@ -47,37 +47,35 @@ def test_file_cache_maxsize(): def test_keep_attrs(): with pytest.raises(ValueError): - xarray.set_options(keep_attrs='invalid_str') + xarray.set_options(keep_attrs="invalid_str") with xarray.set_options(keep_attrs=True): - assert OPTIONS['keep_attrs'] + assert OPTIONS["keep_attrs"] with xarray.set_options(keep_attrs=False): - assert not OPTIONS['keep_attrs'] - with xarray.set_options(keep_attrs='default'): + assert not OPTIONS["keep_attrs"] + with xarray.set_options(keep_attrs="default"): assert _get_keep_attrs(default=True) assert not _get_keep_attrs(default=False) def test_nested_options(): - original = OPTIONS['display_width'] + original = OPTIONS["display_width"] with xarray.set_options(display_width=1): - assert OPTIONS['display_width'] == 1 + assert OPTIONS["display_width"] == 1 with xarray.set_options(display_width=2): - assert OPTIONS['display_width'] == 2 - assert OPTIONS['display_width'] == 1 - assert OPTIONS['display_width'] == original + assert OPTIONS["display_width"] == 2 + assert OPTIONS["display_width"] == 1 + assert OPTIONS["display_width"] == original def create_test_dataset_attrs(seed=0): ds = create_test_data(seed) - ds.attrs = {'attr1': 5, 'attr2': 'history', - 'attr3': {'nested': 'more_info'}} + ds.attrs = {"attr1": 5, "attr2": "history", "attr3": {"nested": "more_info"}} return ds -def create_test_dataarray_attrs(seed=0, var='var1'): +def create_test_dataarray_attrs(seed=0, var="var1"): da = create_test_data(seed)[var] - da.attrs = {'attr1': 5, 'attr2': 'history', - 'attr3': {'nested': 'more_info'}} + da.attrs = {"attr1": 5, "attr2": "history", "attr3": {"nested": "more_info"}} return da @@ -90,7 +88,7 @@ def test_dataset_attr_retention(self): # Test default behaviour result = ds.mean() assert result.attrs == {} - with xarray.set_options(keep_attrs='default'): + with xarray.set_options(keep_attrs="default"): result = ds.mean() assert result.attrs == {} @@ -110,7 +108,7 @@ def 
test_dataarray_attr_retention(self): # Test default behaviour result = da.mean() assert result.attrs == {} - with xarray.set_options(keep_attrs='default'): + with xarray.set_options(keep_attrs="default"): result = da.mean() assert result.attrs == {} @@ -123,44 +121,43 @@ def test_dataarray_attr_retention(self): assert result.attrs == {} def test_groupby_attr_retention(self): - da = xarray.DataArray([1, 2, 3], [('x', [1, 1, 2])]) - da.attrs = {'attr1': 5, 'attr2': 'history', - 'attr3': {'nested': 'more_info'}} + da = xarray.DataArray([1, 2, 3], [("x", [1, 1, 2])]) + da.attrs = {"attr1": 5, "attr2": "history", "attr3": {"nested": "more_info"}} original_attrs = da.attrs # Test default behaviour - result = da.groupby('x').sum(keep_attrs=True) + result = da.groupby("x").sum(keep_attrs=True) assert result.attrs == original_attrs - with xarray.set_options(keep_attrs='default'): - result = da.groupby('x').sum(keep_attrs=True) + with xarray.set_options(keep_attrs="default"): + result = da.groupby("x").sum(keep_attrs=True) assert result.attrs == original_attrs with xarray.set_options(keep_attrs=True): - result1 = da.groupby('x') + result1 = da.groupby("x") result = result1.sum() assert result.attrs == original_attrs with xarray.set_options(keep_attrs=False): - result = da.groupby('x').sum() + result = da.groupby("x").sum() assert result.attrs == {} def test_concat_attr_retention(self): ds1 = create_test_dataset_attrs() ds2 = create_test_dataset_attrs() - ds2.attrs = {'wrong': 'attributes'} + ds2.attrs = {"wrong": "attributes"} original_attrs = ds1.attrs # Test default behaviour of keeping the attrs of the first # dataset in the supplied list # global keep_attrs option current doesn't affect concat - result = concat([ds1, ds2], dim='dim1') + result = concat([ds1, ds2], dim="dim1") assert result.attrs == original_attrs @pytest.mark.xfail def test_merge_attr_retention(self): - da1 = create_test_dataarray_attrs(var='var1') - da2 = create_test_dataarray_attrs(var='var2') - da2.attrs = {'wrong': 'attributes'} + da1 = create_test_dataarray_attrs(var="var1") + da2 = create_test_dataarray_attrs(var="var2") + da2.attrs = {"wrong": "attributes"} original_attrs = da1.attrs # merge currently discards attrs, and the global keep_attrs diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 172b6025b74..36e7a38151d 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -12,13 +12,24 @@ from xarray.plot.dataset_plot import _infer_meta_data from xarray.plot.plot import _infer_interval_breaks from xarray.plot.utils import ( - _build_discrete_cmap, _color_palette, _determine_cmap_params, - import_seaborn, label_from_attrs) + _build_discrete_cmap, + _color_palette, + _determine_cmap_params, + import_seaborn, + label_from_attrs, +) from . 
import ( - assert_array_equal, assert_equal, has_nc_time_axis, raises_regex, - requires_cftime, requires_matplotlib, requires_matplotlib2, - requires_nc_time_axis, requires_seaborn) + assert_array_equal, + assert_equal, + has_nc_time_axis, + raises_regex, + requires_cftime, + requires_matplotlib, + requires_matplotlib2, + requires_nc_time_axis, + requires_seaborn, +) # import mpl and change the backend before other mpl imports try: @@ -29,11 +40,11 @@ @pytest.mark.flaky -@pytest.mark.skip(reason='maybe flaky') +@pytest.mark.skip(reason="maybe flaky") def text_in_fig(): - ''' + """ Return the set of all text in the figure - ''' + """ return {t.get_text() for t in plt.gcf().findobj(mpl.text.Text)} @@ -43,9 +54,9 @@ def find_possible_colorbars(): def substring_in_axes(substring, ax): - ''' + """ Return True if a substring is found anywhere in an axes - ''' + """ alltxt = {t.get_text() for t in ax.findobj(mpl.text.Text)} for txt in alltxt: if substring in txt: @@ -54,11 +65,11 @@ def substring_in_axes(substring, ax): def easy_array(shape, start=0, stop=1): - ''' + """ Make an array with desired shape using np.linspace shape is a tuple like (2, 3) - ''' + """ a = np.linspace(start, stop, num=np.prod(shape)) return a.reshape(shape) @@ -69,7 +80,7 @@ class PlotTestCase: def setup(self): yield # Remove all matplotlib figures - plt.close('all') + plt.close("all") def pass_in_axis(self, plotmethod): fig, axes = plt.subplots(ncols=2) @@ -95,125 +106,126 @@ def setup_array(self): def test_label_from_attrs(self): da = self.darray.copy() - assert '' == label_from_attrs(da) + assert "" == label_from_attrs(da) - da.name = 'a' - da.attrs['units'] = 'a_units' - da.attrs['long_name'] = 'a_long_name' - da.attrs['standard_name'] = 'a_standard_name' - assert 'a_long_name [a_units]' == label_from_attrs(da) + da.name = "a" + da.attrs["units"] = "a_units" + da.attrs["long_name"] = "a_long_name" + da.attrs["standard_name"] = "a_standard_name" + assert "a_long_name [a_units]" == label_from_attrs(da) - da.attrs.pop('long_name') - assert 'a_standard_name [a_units]' == label_from_attrs(da) - da.attrs.pop('units') - assert 'a_standard_name' == label_from_attrs(da) + da.attrs.pop("long_name") + assert "a_standard_name [a_units]" == label_from_attrs(da) + da.attrs.pop("units") + assert "a_standard_name" == label_from_attrs(da) - da.attrs['units'] = 'a_units' - da.attrs.pop('standard_name') - assert 'a [a_units]' == label_from_attrs(da) + da.attrs["units"] = "a_units" + da.attrs.pop("standard_name") + assert "a [a_units]" == label_from_attrs(da) - da.attrs.pop('units') - assert 'a' == label_from_attrs(da) + da.attrs.pop("units") + assert "a" == label_from_attrs(da) def test1d(self): self.darray[:, 0, 0].plot() - with raises_regex(ValueError, 'None'): - self.darray[:, 0, 0].plot(x='dim_1') + with raises_regex(ValueError, "None"): + self.darray[:, 0, 0].plot(x="dim_1") def test_1d_x_y_kw(self): z = np.arange(10) - da = DataArray(np.cos(z), dims=['z'], coords=[z], name='f') + da = DataArray(np.cos(z), dims=["z"], coords=[z], name="f") - xy = [[None, None], - [None, 'z'], - ['z', None]] + xy = [[None, None], [None, "z"], ["z", None]] f, ax = plt.subplots(3, 1) for aa, (x, y) in enumerate(xy): da.plot(x=x, y=y, ax=ax.flat[aa]) - with raises_regex(ValueError, 'cannot'): - da.plot(x='z', y='z') + with raises_regex(ValueError, "cannot"): + da.plot(x="z", y="z") - with raises_regex(ValueError, 'None'): - da.plot(x='f', y='z') + with raises_regex(ValueError, "None"): + da.plot(x="f", y="z") - with raises_regex(ValueError, 'None'): - 
da.plot(x='z', y='f') + with raises_regex(ValueError, "None"): + da.plot(x="z", y="f") # Test for bug in GH issue #2725 def test_infer_line_data(self): - current = DataArray(name='I', data=np.array([5, 8]), dims=['t'], - coords={'t': (['t'], np.array([0.1, 0.2])), - 'V': (['t'], np.array([100, 200]))}) + current = DataArray( + name="I", + data=np.array([5, 8]), + dims=["t"], + coords={ + "t": (["t"], np.array([0.1, 0.2])), + "V": (["t"], np.array([100, 200])), + }, + ) # Plot current against voltage - line = current.plot.line(x='V')[0] - assert_array_equal(line.get_xdata(), current.coords['V'].values) + line = current.plot.line(x="V")[0] + assert_array_equal(line.get_xdata(), current.coords["V"].values) # Plot current against time line = current.plot.line()[0] - assert_array_equal(line.get_xdata(), current.coords['t'].values) + assert_array_equal(line.get_xdata(), current.coords["t"].values) def test_2d_line(self): - with raises_regex(ValueError, 'hue'): + with raises_regex(ValueError, "hue"): self.darray[:, :, 0].plot.line() - self.darray[:, :, 0].plot.line(hue='dim_1') - self.darray[:, :, 0].plot.line(x='dim_1') - self.darray[:, :, 0].plot.line(y='dim_1') - self.darray[:, :, 0].plot.line(x='dim_0', hue='dim_1') - self.darray[:, :, 0].plot.line(y='dim_0', hue='dim_1') + self.darray[:, :, 0].plot.line(hue="dim_1") + self.darray[:, :, 0].plot.line(x="dim_1") + self.darray[:, :, 0].plot.line(y="dim_1") + self.darray[:, :, 0].plot.line(x="dim_0", hue="dim_1") + self.darray[:, :, 0].plot.line(y="dim_0", hue="dim_1") - with raises_regex(ValueError, 'cannot'): - self.darray[:, :, 0].plot.line(x='dim_1', y='dim_0', hue='dim_1') + with raises_regex(ValueError, "cannot"): + self.darray[:, :, 0].plot.line(x="dim_1", y="dim_0", hue="dim_1") def test_2d_line_accepts_legend_kw(self): - self.darray[:, :, 0].plot.line(x='dim_0', add_legend=False) + self.darray[:, :, 0].plot.line(x="dim_0", add_legend=False) assert not plt.gca().get_legend() plt.cla() - self.darray[:, :, 0].plot.line(x='dim_0', add_legend=True) + self.darray[:, :, 0].plot.line(x="dim_0", add_legend=True) assert plt.gca().get_legend() # check whether legend title is set - assert (plt.gca().get_legend().get_title().get_text() - == 'dim_1') + assert plt.gca().get_legend().get_title().get_text() == "dim_1" def test_2d_line_accepts_x_kw(self): - self.darray[:, :, 0].plot.line(x='dim_0') - assert plt.gca().get_xlabel() == 'dim_0' + self.darray[:, :, 0].plot.line(x="dim_0") + assert plt.gca().get_xlabel() == "dim_0" plt.cla() - self.darray[:, :, 0].plot.line(x='dim_1') - assert plt.gca().get_xlabel() == 'dim_1' + self.darray[:, :, 0].plot.line(x="dim_1") + assert plt.gca().get_xlabel() == "dim_1" def test_2d_line_accepts_hue_kw(self): - self.darray[:, :, 0].plot.line(hue='dim_0') - assert (plt.gca().get_legend().get_title().get_text() - == 'dim_0') + self.darray[:, :, 0].plot.line(hue="dim_0") + assert plt.gca().get_legend().get_title().get_text() == "dim_0" plt.cla() - self.darray[:, :, 0].plot.line(hue='dim_1') - assert (plt.gca().get_legend().get_title().get_text() - == 'dim_1') + self.darray[:, :, 0].plot.line(hue="dim_1") + assert plt.gca().get_legend().get_title().get_text() == "dim_1" def test_2d_coords_line_plot(self): - lon, lat = np.meshgrid(np.linspace(-20, 20, 5), - np.linspace(0, 30, 4)) + lon, lat = np.meshgrid(np.linspace(-20, 20, 5), np.linspace(0, 30, 4)) lon += lat / 10 lat += lon / 10 - da = xr.DataArray(np.arange(20).reshape(4, 5), dims=['y', 'x'], - coords={'lat': (('y', 'x'), lat), - 'lon': (('y', 'x'), lon)}) + da = 
xr.DataArray( + np.arange(20).reshape(4, 5), + dims=["y", "x"], + coords={"lat": (("y", "x"), lat), "lon": (("y", "x"), lon)}, + ) - hdl = da.plot.line(x='lon', hue='x') + hdl = da.plot.line(x="lon", hue="x") assert len(hdl) == 5 plt.clf() - hdl = da.plot.line(x='lon', hue='y') + hdl = da.plot.line(x="lon", hue="y") assert len(hdl) == 4 - with pytest.raises( - ValueError, match="For 2D inputs, hue must be a dimension"): - da.plot.line(x='lon', hue='lat') + with pytest.raises(ValueError, match="For 2D inputs, hue must be a dimension"): + da.plot.line(x="lon", hue="lat") def test_2d_before_squeeze(self): a = DataArray(easy_array((1, 5))) @@ -225,7 +237,7 @@ def test2d_uniform_calls_imshow(self): @pytest.mark.slow def test2d_nonuniform_calls_contourf(self): a = self.darray[:, :, 0] - a.coords['dim_1'] = [2, 1, 89] + a.coords["dim_1"] = [2, 1, 89] assert self.contourf_called(a.plot.contourf) def test2d_1d_2d_coordinates_contourf(self): @@ -233,13 +245,11 @@ def test2d_1d_2d_coordinates_contourf(self): depth = easy_array(sz) a = DataArray( easy_array(sz), - dims=['z', 'time'], - coords={ - 'depth': (['z', 'time'], depth), - 'time': np.linspace(0, 1, sz[1]) - }) + dims=["z", "time"], + coords={"depth": (["z", "time"], depth), "time": np.linspace(0, 1, sz[1])}, + ) - a.plot.contourf(x='time', y='depth') + a.plot.contourf(x="time", y="depth") def test3d(self): self.darray.plot() @@ -249,11 +259,13 @@ def test_can_pass_in_axis(self): def test__infer_interval_breaks(self): assert_array_equal([-0.5, 0.5, 1.5], _infer_interval_breaks([0, 1])) - assert_array_equal([-0.5, 0.5, 5.0, 9.5, 10.5], - _infer_interval_breaks([0, 1, 9, 10])) assert_array_equal( - pd.date_range('20000101', periods=4) - np.timedelta64(12, 'h'), - _infer_interval_breaks(pd.date_range('20000101', periods=3))) + [-0.5, 0.5, 5.0, 9.5, 10.5], _infer_interval_breaks([0, 1, 9, 10]) + ) + assert_array_equal( + pd.date_range("20000101", periods=4) - np.timedelta64(12, "h"), + _infer_interval_breaks(pd.date_range("20000101", periods=3)), + ) # make a bounded 2D array that we will center and re-infer xref, yref = np.meshgrid(np.arange(6), np.arange(5)) @@ -273,71 +285,81 @@ def test__infer_interval_breaks(self): def test_geo_data(self): # Regression test for gh2250 # Realistic coordinates taken from the example dataset - lat = np.array([[16.28, 18.48, 19.58, 19.54, 18.35], - [28.07, 30.52, 31.73, 31.68, 30.37], - [39.65, 42.27, 43.56, 43.51, 42.11], - [50.52, 53.22, 54.55, 54.50, 53.06]]) - lon = np.array([[-126.13, -113.69, -100.92, -88.04, -75.29], - [-129.27, -115.62, -101.54, -87.32, -73.26], - [-133.10, -118.00, -102.31, -86.42, -70.76], - [-137.85, -120.99, -103.28, -85.28, -67.62]]) + lat = np.array( + [ + [16.28, 18.48, 19.58, 19.54, 18.35], + [28.07, 30.52, 31.73, 31.68, 30.37], + [39.65, 42.27, 43.56, 43.51, 42.11], + [50.52, 53.22, 54.55, 54.50, 53.06], + ] + ) + lon = np.array( + [ + [-126.13, -113.69, -100.92, -88.04, -75.29], + [-129.27, -115.62, -101.54, -87.32, -73.26], + [-133.10, -118.00, -102.31, -86.42, -70.76], + [-137.85, -120.99, -103.28, -85.28, -67.62], + ] + ) data = np.sqrt(lon ** 2 + lat ** 2) - da = DataArray(data, dims=('y', 'x'), - coords={'lon': (('y', 'x'), lon), - 'lat': (('y', 'x'), lat)}) - da.plot(x='lon', y='lat') + da = DataArray( + data, + dims=("y", "x"), + coords={"lon": (("y", "x"), lon), "lat": (("y", "x"), lat)}, + ) + da.plot(x="lon", y="lat") ax = plt.gca() assert ax.has_data() - da.plot(x='lat', y='lon') + da.plot(x="lat", y="lon") ax = plt.gca() assert ax.has_data() def 
test_datetime_dimension(self): nrow = 3 ncol = 4 - time = pd.date_range('2000-01-01', periods=nrow) + time = pd.date_range("2000-01-01", periods=nrow) a = DataArray( - easy_array((nrow, ncol)), - coords=[('time', time), ('y', range(ncol))]) + easy_array((nrow, ncol)), coords=[("time", time), ("y", range(ncol))] + ) a.plot() ax = plt.gca() assert ax.has_data() @pytest.mark.slow - @pytest.mark.filterwarnings('ignore:tight_layout cannot') + @pytest.mark.filterwarnings("ignore:tight_layout cannot") def test_convenient_facetgrid(self): a = easy_array((10, 15, 4)) - d = DataArray(a, dims=['y', 'x', 'z']) - d.coords['z'] = list('abcd') - g = d.plot(x='x', y='y', col='z', col_wrap=2, cmap='cool') + d = DataArray(a, dims=["y", "x", "z"]) + d.coords["z"] = list("abcd") + g = d.plot(x="x", y="y", col="z", col_wrap=2, cmap="cool") assert_array_equal(g.axes.shape, [2, 2]) for ax in g.axes.flat: assert ax.has_data() - with raises_regex(ValueError, '[Ff]acet'): - d.plot(x='x', y='y', col='z', ax=plt.gca()) + with raises_regex(ValueError, "[Ff]acet"): + d.plot(x="x", y="y", col="z", ax=plt.gca()) - with raises_regex(ValueError, '[Ff]acet'): - d[0].plot(x='x', y='y', col='z', ax=plt.gca()) + with raises_regex(ValueError, "[Ff]acet"): + d[0].plot(x="x", y="y", col="z", ax=plt.gca()) @pytest.mark.slow @requires_matplotlib2 def test_subplot_kws(self): a = easy_array((10, 15, 4)) - d = DataArray(a, dims=['y', 'x', 'z']) - d.coords['z'] = list('abcd') + d = DataArray(a, dims=["y", "x", "z"]) + d.coords["z"] = list("abcd") g = d.plot( - x='x', - y='y', - col='z', + x="x", + y="y", + col="z", col_wrap=2, - cmap='cool', - subplot_kws=dict(facecolor='r')) + cmap="cool", + subplot_kws=dict(facecolor="r"), + ) for ax in g.axes.flat: # mpl V2 - assert ax.get_facecolor()[0:3] == \ - mpl.colors.to_rgb('r') + assert ax.get_facecolor()[0:3] == mpl.colors.to_rgb("r") @pytest.mark.slow def test_plot_size(self): @@ -353,78 +375,77 @@ def test_plot_size(self): self.darray.plot(size=5, aspect=2) assert tuple(plt.gcf().get_size_inches()) == (10, 5) - with raises_regex(ValueError, 'cannot provide both'): + with raises_regex(ValueError, "cannot provide both"): self.darray.plot(ax=plt.gca(), figsize=(3, 4)) - with raises_regex(ValueError, 'cannot provide both'): + with raises_regex(ValueError, "cannot provide both"): self.darray.plot(size=5, figsize=(3, 4)) - with raises_regex(ValueError, 'cannot provide both'): + with raises_regex(ValueError, "cannot provide both"): self.darray.plot(size=5, ax=plt.gca()) - with raises_regex(ValueError, 'cannot provide `aspect`'): + with raises_regex(ValueError, "cannot provide `aspect`"): self.darray.plot(aspect=1) @pytest.mark.slow - @pytest.mark.filterwarnings('ignore:tight_layout cannot') + @pytest.mark.filterwarnings("ignore:tight_layout cannot") def test_convenient_facetgrid_4d(self): a = easy_array((10, 15, 2, 3)) - d = DataArray(a, dims=['y', 'x', 'columns', 'rows']) - g = d.plot(x='x', y='y', col='columns', row='rows') + d = DataArray(a, dims=["y", "x", "columns", "rows"]) + g = d.plot(x="x", y="y", col="columns", row="rows") assert_array_equal(g.axes.shape, [3, 2]) for ax in g.axes.flat: assert ax.has_data() - with raises_regex(ValueError, '[Ff]acet'): - d.plot(x='x', y='y', col='columns', ax=plt.gca()) + with raises_regex(ValueError, "[Ff]acet"): + d.plot(x="x", y="y", col="columns", ax=plt.gca()) def test_coord_with_interval(self): bins = [-1, 0, 1, 2] - self.darray.groupby_bins('dim_0', bins).mean(xr.ALL_DIMS).plot() + self.darray.groupby_bins("dim_0", bins).mean(xr.ALL_DIMS).plot() 
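
For reference, a minimal sketch (not taken from the patch) of the faceted-plot call pattern that test_convenient_facetgrid above exercises; it assumes xarray with matplotlib installed, and the random data and names d/g are illustrative only:

# Sketch of DataArray.plot faceting over a third dimension.
import numpy as np
import xarray as xr

d = xr.DataArray(np.random.rand(10, 15, 4), dims=["y", "x", "z"])
d.coords["z"] = list("abcd")

# One panel per "z" value, wrapped into a 2x2 grid of axes.
g = d.plot(x="x", y="y", col="z", col_wrap=2, cmap="cool")
print(g.axes.shape)  # (2, 2)
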
class TestPlot1D(PlotTestCase): @pytest.fixture(autouse=True) def setUp(self): d = [0, 1.1, 0, 2] - self.darray = DataArray( - d, coords={'period': range(len(d))}, dims='period') - self.darray.period.attrs['units'] = 's' + self.darray = DataArray(d, coords={"period": range(len(d))}, dims="period") + self.darray.period.attrs["units"] = "s" def test_xlabel_is_index_name(self): self.darray.plot() - assert 'period [s]' == plt.gca().get_xlabel() + assert "period [s]" == plt.gca().get_xlabel() def test_no_label_name_on_x_axis(self): - self.darray.plot(y='period') - assert '' == plt.gca().get_xlabel() + self.darray.plot(y="period") + assert "" == plt.gca().get_xlabel() def test_no_label_name_on_y_axis(self): self.darray.plot() - assert '' == plt.gca().get_ylabel() + assert "" == plt.gca().get_ylabel() def test_ylabel_is_data_name(self): - self.darray.name = 'temperature' - self.darray.attrs['units'] = 'degrees_Celsius' + self.darray.name = "temperature" + self.darray.attrs["units"] = "degrees_Celsius" self.darray.plot() - assert 'temperature [degrees_Celsius]' == plt.gca().get_ylabel() + assert "temperature [degrees_Celsius]" == plt.gca().get_ylabel() def test_xlabel_is_data_name(self): - self.darray.name = 'temperature' - self.darray.attrs['units'] = 'degrees_Celsius' - self.darray.plot(y='period') - assert 'temperature [degrees_Celsius]' == plt.gca().get_xlabel() + self.darray.name = "temperature" + self.darray.attrs["units"] = "degrees_Celsius" + self.darray.plot(y="period") + assert "temperature [degrees_Celsius]" == plt.gca().get_xlabel() def test_format_string(self): - self.darray.plot.line('ro') + self.darray.plot.line("ro") def test_can_pass_in_axis(self): self.pass_in_axis(self.darray.plot.line) def test_nonnumeric_index_raises_typeerror(self): - a = DataArray([1, 2, 3], {'letter': ['a', 'b', 'c']}, dims='letter') - with raises_regex(TypeError, r'[Pp]lot'): + a = DataArray([1, 2, 3], {"letter": ["a", "b", "c"]}, dims="letter") + with raises_regex(TypeError, r"[Pp]lot"): a.plot.line() def test_primitive_returned(self): @@ -437,8 +458,8 @@ def test_plot_nans(self): self.darray.plot.line() def test_x_ticks_are_rotated_for_time(self): - time = pd.date_range('2000-01-01', '2000-01-10') - a = DataArray(np.arange(len(time)), [('t', time)]) + time = pd.date_range("2000-01-01", "2000-01-10") + a = DataArray(np.arange(len(time)), [("t", time)]) a.plot.line() rotation = plt.gca().get_xticklabels()[0].get_rotation() assert rotation != 0 @@ -451,10 +472,10 @@ def test_xyincrease_false_changes_axes(self): assert all(x < 0 for x in diffs) def test_slice_in_title(self): - self.darray.coords['d'] = 10 + self.darray.coords["d"] = 10 self.darray.plot.line() title = plt.gca().get_title() - assert 'd = 10' == title + assert "d = 10" == title class TestPlotStep(PlotTestCase): @@ -467,7 +488,7 @@ def test_step(self): def test_coord_with_interval_step(self): bins = [-1, 0, 1, 2] - self.darray.groupby_bins('dim_0', bins).mean(xr.ALL_DIMS).plot.step() + self.darray.groupby_bins("dim_0", bins).mean(xr.ALL_DIMS).plot.step() assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) @@ -480,14 +501,14 @@ def test_3d_array(self): self.darray.plot.hist() def test_xlabel_uses_name(self): - self.darray.name = 'testpoints' - self.darray.attrs['units'] = 'testunits' + self.darray.name = "testpoints" + self.darray.attrs["units"] = "testunits" self.darray.plot.hist() - assert 'testpoints [testunits]' == plt.gca().get_xlabel() + assert "testpoints [testunits]" == plt.gca().get_xlabel() def 
test_title_is_histogram(self): self.darray.plot.hist() - assert 'Histogram' == plt.gca().get_title() + assert "Histogram" == plt.gca().get_title() def test_can_pass_in_kwargs(self): nbins = 5 @@ -507,8 +528,11 @@ def test_plot_nans(self): self.darray.plot.hist() def test_hist_coord_with_interval(self): - (self.darray.groupby_bins('dim_0', [-1, 0, 1, 2]).mean(xr.ALL_DIMS) - .plot.hist(range=(-1, 2))) + ( + self.darray.groupby_bins("dim_0", [-1, 0, 1, 2]) + .mean(xr.ALL_DIMS) + .plot.hist(range=(-1, 2)) + ) @requires_matplotlib @@ -519,35 +543,35 @@ def setUp(self): def test_robust(self): cmap_params = _determine_cmap_params(self.data, robust=True) - assert cmap_params['vmin'] == np.percentile(self.data, 2) - assert cmap_params['vmax'] == np.percentile(self.data, 98) - assert cmap_params['cmap'] == 'viridis' - assert cmap_params['extend'] == 'both' - assert cmap_params['levels'] is None - assert cmap_params['norm'] is None + assert cmap_params["vmin"] == np.percentile(self.data, 2) + assert cmap_params["vmax"] == np.percentile(self.data, 98) + assert cmap_params["cmap"] == "viridis" + assert cmap_params["extend"] == "both" + assert cmap_params["levels"] is None + assert cmap_params["norm"] is None def test_center(self): cmap_params = _determine_cmap_params(self.data, center=0.5) - assert cmap_params['vmax'] - 0.5 == 0.5 - cmap_params['vmin'] - assert cmap_params['cmap'] == 'RdBu_r' - assert cmap_params['extend'] == 'neither' - assert cmap_params['levels'] is None - assert cmap_params['norm'] is None + assert cmap_params["vmax"] - 0.5 == 0.5 - cmap_params["vmin"] + assert cmap_params["cmap"] == "RdBu_r" + assert cmap_params["extend"] == "neither" + assert cmap_params["levels"] is None + assert cmap_params["norm"] is None def test_cmap_sequential_option(self): - with xr.set_options(cmap_sequential='magma'): + with xr.set_options(cmap_sequential="magma"): cmap_params = _determine_cmap_params(self.data) - assert cmap_params['cmap'] == 'magma' + assert cmap_params["cmap"] == "magma" def test_cmap_sequential_explicit_option(self): with xr.set_options(cmap_sequential=mpl.cm.magma): cmap_params = _determine_cmap_params(self.data) - assert cmap_params['cmap'] == mpl.cm.magma + assert cmap_params["cmap"] == mpl.cm.magma def test_cmap_divergent_option(self): - with xr.set_options(cmap_divergent='magma'): + with xr.set_options(cmap_divergent="magma"): cmap_params = _determine_cmap_params(self.data, center=0.5) - assert cmap_params['cmap'] == 'magma' + assert cmap_params["cmap"] == "magma" def test_nan_inf_are_ignored(self): cmap_params1 = _determine_cmap_params(self.data) @@ -555,8 +579,8 @@ def test_nan_inf_are_ignored(self): data[50:55] = np.nan data[56:60] = np.inf cmap_params2 = _determine_cmap_params(data) - assert cmap_params1['vmin'] == cmap_params2['vmin'] - assert cmap_params1['vmax'] == cmap_params2['vmax'] + assert cmap_params1["vmin"] == cmap_params2["vmin"] + assert cmap_params1["vmax"] == cmap_params2["vmax"] @pytest.mark.slow def test_integer_levels(self): @@ -565,52 +589,49 @@ def test_integer_levels(self): # default is to cover full data range but with no guarantee on Nlevels for level in np.arange(2, 10, dtype=int): cmap_params = _determine_cmap_params(data, levels=level) - assert cmap_params['vmin'] == cmap_params['levels'][0] - assert cmap_params['vmax'] == cmap_params['levels'][-1] - assert cmap_params['extend'] == 'neither' + assert cmap_params["vmin"] == cmap_params["levels"][0] + assert cmap_params["vmax"] == cmap_params["levels"][-1] + assert cmap_params["extend"] == "neither" # 
with min max we are more strict cmap_params = _determine_cmap_params( - data, levels=5, vmin=0, vmax=5, cmap='Blues') - assert cmap_params['vmin'] == 0 - assert cmap_params['vmax'] == 5 - assert cmap_params['vmin'] == cmap_params['levels'][0] - assert cmap_params['vmax'] == cmap_params['levels'][-1] - assert cmap_params['cmap'].name == 'Blues' - assert cmap_params['extend'] == 'neither' - assert cmap_params['cmap'].N == 4 - assert cmap_params['norm'].N == 5 - - cmap_params = _determine_cmap_params( - data, levels=5, vmin=0.5, vmax=1.5) - assert cmap_params['cmap'].name == 'viridis' - assert cmap_params['extend'] == 'max' + data, levels=5, vmin=0, vmax=5, cmap="Blues" + ) + assert cmap_params["vmin"] == 0 + assert cmap_params["vmax"] == 5 + assert cmap_params["vmin"] == cmap_params["levels"][0] + assert cmap_params["vmax"] == cmap_params["levels"][-1] + assert cmap_params["cmap"].name == "Blues" + assert cmap_params["extend"] == "neither" + assert cmap_params["cmap"].N == 4 + assert cmap_params["norm"].N == 5 + + cmap_params = _determine_cmap_params(data, levels=5, vmin=0.5, vmax=1.5) + assert cmap_params["cmap"].name == "viridis" + assert cmap_params["extend"] == "max" cmap_params = _determine_cmap_params(data, levels=5, vmin=1.5) - assert cmap_params['cmap'].name == 'viridis' - assert cmap_params['extend'] == 'min' + assert cmap_params["cmap"].name == "viridis" + assert cmap_params["extend"] == "min" - cmap_params = _determine_cmap_params( - data, levels=5, vmin=1.3, vmax=1.5) - assert cmap_params['cmap'].name == 'viridis' - assert cmap_params['extend'] == 'both' + cmap_params = _determine_cmap_params(data, levels=5, vmin=1.3, vmax=1.5) + assert cmap_params["cmap"].name == "viridis" + assert cmap_params["extend"] == "both" def test_list_levels(self): data = self.data + 1 orig_levels = [0, 1, 2, 3, 4, 5] # vmin and vmax should be ignored if levels are explicitly provided - cmap_params = _determine_cmap_params( - data, levels=orig_levels, vmin=0, vmax=3) - assert cmap_params['vmin'] == 0 - assert cmap_params['vmax'] == 5 - assert cmap_params['cmap'].N == 5 - assert cmap_params['norm'].N == 6 + cmap_params = _determine_cmap_params(data, levels=orig_levels, vmin=0, vmax=3) + assert cmap_params["vmin"] == 0 + assert cmap_params["vmax"] == 5 + assert cmap_params["cmap"].N == 5 + assert cmap_params["norm"].N == 6 for wrap_levels in [list, np.array, pd.Index, DataArray]: - cmap_params = _determine_cmap_params( - data, levels=wrap_levels(orig_levels)) - assert_array_equal(cmap_params['levels'], orig_levels) + cmap_params = _determine_cmap_params(data, levels=wrap_levels(orig_levels)) + assert_array_equal(cmap_params["levels"], orig_levels) def test_divergentcontrol(self): neg = self.data - 0.1 @@ -618,91 +639,95 @@ def test_divergentcontrol(self): # Default with positive data will be a normal cmap cmap_params = _determine_cmap_params(pos) - assert cmap_params['vmin'] == 0 - assert cmap_params['vmax'] == 1 - assert cmap_params['cmap'] == "viridis" + assert cmap_params["vmin"] == 0 + assert cmap_params["vmax"] == 1 + assert cmap_params["cmap"] == "viridis" # Default with negative data will be a divergent cmap cmap_params = _determine_cmap_params(neg) - assert cmap_params['vmin'] == -0.9 - assert cmap_params['vmax'] == 0.9 - assert cmap_params['cmap'] == "RdBu_r" + assert cmap_params["vmin"] == -0.9 + assert cmap_params["vmax"] == 0.9 + assert cmap_params["cmap"] == "RdBu_r" # Setting vmin or vmax should prevent this only if center is false cmap_params = _determine_cmap_params(neg, vmin=-0.1, 
center=False) - assert cmap_params['vmin'] == -0.1 - assert cmap_params['vmax'] == 0.9 - assert cmap_params['cmap'] == "viridis" + assert cmap_params["vmin"] == -0.1 + assert cmap_params["vmax"] == 0.9 + assert cmap_params["cmap"] == "viridis" cmap_params = _determine_cmap_params(neg, vmax=0.5, center=False) - assert cmap_params['vmin'] == -0.1 - assert cmap_params['vmax'] == 0.5 - assert cmap_params['cmap'] == "viridis" + assert cmap_params["vmin"] == -0.1 + assert cmap_params["vmax"] == 0.5 + assert cmap_params["cmap"] == "viridis" # Setting center=False too cmap_params = _determine_cmap_params(neg, center=False) - assert cmap_params['vmin'] == -0.1 - assert cmap_params['vmax'] == 0.9 - assert cmap_params['cmap'] == "viridis" + assert cmap_params["vmin"] == -0.1 + assert cmap_params["vmax"] == 0.9 + assert cmap_params["cmap"] == "viridis" # However, I should still be able to set center and have a div cmap cmap_params = _determine_cmap_params(neg, center=0) - assert cmap_params['vmin'] == -0.9 - assert cmap_params['vmax'] == 0.9 - assert cmap_params['cmap'] == "RdBu_r" + assert cmap_params["vmin"] == -0.9 + assert cmap_params["vmax"] == 0.9 + assert cmap_params["cmap"] == "RdBu_r" # Setting vmin or vmax alone will force symmetric bounds around center cmap_params = _determine_cmap_params(neg, vmin=-0.1) - assert cmap_params['vmin'] == -0.1 - assert cmap_params['vmax'] == 0.1 - assert cmap_params['cmap'] == "RdBu_r" + assert cmap_params["vmin"] == -0.1 + assert cmap_params["vmax"] == 0.1 + assert cmap_params["cmap"] == "RdBu_r" cmap_params = _determine_cmap_params(neg, vmax=0.5) - assert cmap_params['vmin'] == -0.5 - assert cmap_params['vmax'] == 0.5 - assert cmap_params['cmap'] == "RdBu_r" + assert cmap_params["vmin"] == -0.5 + assert cmap_params["vmax"] == 0.5 + assert cmap_params["cmap"] == "RdBu_r" cmap_params = _determine_cmap_params(neg, vmax=0.6, center=0.1) - assert cmap_params['vmin'] == -0.4 - assert cmap_params['vmax'] == 0.6 - assert cmap_params['cmap'] == "RdBu_r" + assert cmap_params["vmin"] == -0.4 + assert cmap_params["vmax"] == 0.6 + assert cmap_params["cmap"] == "RdBu_r" # But this is only true if vmin or vmax are negative cmap_params = _determine_cmap_params(pos, vmin=-0.1) - assert cmap_params['vmin'] == -0.1 - assert cmap_params['vmax'] == 0.1 - assert cmap_params['cmap'] == "RdBu_r" + assert cmap_params["vmin"] == -0.1 + assert cmap_params["vmax"] == 0.1 + assert cmap_params["cmap"] == "RdBu_r" cmap_params = _determine_cmap_params(pos, vmin=0.1) - assert cmap_params['vmin'] == 0.1 - assert cmap_params['vmax'] == 1 - assert cmap_params['cmap'] == "viridis" + assert cmap_params["vmin"] == 0.1 + assert cmap_params["vmax"] == 1 + assert cmap_params["cmap"] == "viridis" cmap_params = _determine_cmap_params(pos, vmax=0.5) - assert cmap_params['vmin'] == 0 - assert cmap_params['vmax'] == 0.5 - assert cmap_params['cmap'] == "viridis" + assert cmap_params["vmin"] == 0 + assert cmap_params["vmax"] == 0.5 + assert cmap_params["cmap"] == "viridis" # If both vmin and vmax are provided, output is non-divergent cmap_params = _determine_cmap_params(neg, vmin=-0.2, vmax=0.6) - assert cmap_params['vmin'] == -0.2 - assert cmap_params['vmax'] == 0.6 - assert cmap_params['cmap'] == "viridis" + assert cmap_params["vmin"] == -0.2 + assert cmap_params["vmax"] == 0.6 + assert cmap_params["cmap"] == "viridis" def test_norm_sets_vmin_vmax(self): vmin = self.data.min() vmax = self.data.max() - for norm, extend in zip([mpl.colors.LogNorm(), - mpl.colors.LogNorm(vmin + 1, vmax - 1), - 
mpl.colors.LogNorm(None, vmax - 1), - mpl.colors.LogNorm(vmin + 1, None)], - ['neither', 'both', 'max', 'min']): + for norm, extend in zip( + [ + mpl.colors.LogNorm(), + mpl.colors.LogNorm(vmin + 1, vmax - 1), + mpl.colors.LogNorm(None, vmax - 1), + mpl.colors.LogNorm(vmin + 1, None), + ], + ["neither", "both", "max", "min"], + ): test_min = vmin if norm.vmin is None else norm.vmin test_max = vmax if norm.vmax is None else norm.vmax cmap_params = _determine_cmap_params(self.data, norm=norm) - assert cmap_params['vmin'] == test_min - assert cmap_params['vmax'] == test_max - assert cmap_params['extend'] == extend - assert cmap_params['norm'] == norm + assert cmap_params["vmin"] == test_min + assert cmap_params["vmax"] == test_max + assert cmap_params["extend"] == extend + assert cmap_params["norm"] == norm @requires_matplotlib @@ -713,20 +738,22 @@ def setUp(self): y = np.arange(start=9, stop=-7, step=-3) xy = np.dstack(np.meshgrid(x, y)) distance = np.linalg.norm(xy, axis=2) - self.darray = DataArray(distance, list(zip(('y', 'x'), (y, x)))) + self.darray = DataArray(distance, list(zip(("y", "x"), (y, x)))) self.data_min = distance.min() self.data_max = distance.max() @pytest.mark.slow def test_recover_from_seaborn_jet_exception(self): - pal = _color_palette('jet', 4) + pal = _color_palette("jet", 4) assert type(pal) == np.ndarray assert len(pal) == 4 @pytest.mark.slow def test_build_discrete_cmap(self): - for (cmap, levels, extend, filled) in [('jet', [0, 1], 'both', False), - ('hot', [-4, 4], 'max', True)]: + for (cmap, levels, extend, filled) in [ + ("jet", [0, 1], "both", False), + ("hot", [-4, 4], "max", True), + ]: ncmap, cnorm = _build_discrete_cmap(cmap, levels, extend, filled) assert ncmap.N == len(levels) - 1 assert len(ncmap.colors) == len(levels) - 1 @@ -737,37 +764,40 @@ def test_build_discrete_cmap(self): if filled: assert ncmap.colorbar_extend == extend else: - assert ncmap.colorbar_extend == 'max' + assert ncmap.colorbar_extend == "max" @pytest.mark.slow def test_discrete_colormap_list_of_levels(self): - for extend, levels in [('max', [-1, 2, 4, 8, 10]), - ('both', [2, 5, 10, 11]), - ('neither', [0, 5, 10, 15]), - ('min', [2, 5, 10, 15])]: - for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']: + for extend, levels in [ + ("max", [-1, 2, 4, 8, 10]), + ("both", [2, 5, 10, 11]), + ("neither", [0, 5, 10, 15]), + ("min", [2, 5, 10, 15]), + ]: + for kind in ["imshow", "pcolormesh", "contourf", "contour"]: primitive = getattr(self.darray.plot, kind)(levels=levels) assert_array_equal(levels, primitive.norm.boundaries) assert max(levels) == primitive.norm.vmax assert min(levels) == primitive.norm.vmin - if kind != 'contour': + if kind != "contour": assert extend == primitive.cmap.colorbar_extend else: - assert 'max' == primitive.cmap.colorbar_extend + assert "max" == primitive.cmap.colorbar_extend assert len(levels) - 1 == len(primitive.cmap.colors) @pytest.mark.slow def test_discrete_colormap_int_levels(self): for extend, levels, vmin, vmax, cmap in [ - ('neither', 7, None, None, None), - ('neither', 7, None, 20, mpl.cm.RdBu), - ('both', 7, 4, 8, None), - ('min', 10, 4, 15, None)]: - for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']: + ("neither", 7, None, None, None), + ("neither", 7, None, 20, mpl.cm.RdBu), + ("both", 7, 4, 8, None), + ("min", 10, 4, 15, None), + ]: + for kind in ["imshow", "pcolormesh", "contourf", "contour"]: primitive = getattr(self.darray.plot, kind)( - levels=levels, vmin=vmin, vmax=vmax, cmap=cmap) - assert levels >= \ - 
len(primitive.norm.boundaries) - 1 + levels=levels, vmin=vmin, vmax=vmax, cmap=cmap + ) + assert levels >= len(primitive.norm.boundaries) - 1 if vmax is None: assert primitive.norm.vmax >= self.data_max else: @@ -776,10 +806,10 @@ def test_discrete_colormap_int_levels(self): assert primitive.norm.vmin <= self.data_min else: assert primitive.norm.vmin <= vmin - if kind != 'contour': + if kind != "contour": assert extend == primitive.cmap.colorbar_extend else: - assert 'max' == primitive.cmap.colorbar_extend + assert "max" == primitive.cmap.colorbar_extend assert levels >= len(primitive.cmap.colors) def test_discrete_colormap_list_levels_and_vmin_or_vmax(self): @@ -804,48 +834,49 @@ class Common2dMixin: @pytest.fixture(autouse=True) def setUp(self): - da = DataArray(easy_array((10, 15), start=-1), - dims=['y', 'x'], - coords={'y': np.arange(10), - 'x': np.arange(15)}) + da = DataArray( + easy_array((10, 15), start=-1), + dims=["y", "x"], + coords={"y": np.arange(10), "x": np.arange(15)}, + ) # add 2d coords - ds = da.to_dataset(name='testvar') + ds = da.to_dataset(name="testvar") x, y = np.meshgrid(da.x.values, da.y.values) - ds['x2d'] = DataArray(x, dims=['y', 'x']) - ds['y2d'] = DataArray(y, dims=['y', 'x']) - ds = ds.set_coords(['x2d', 'y2d']) + ds["x2d"] = DataArray(x, dims=["y", "x"]) + ds["y2d"] = DataArray(y, dims=["y", "x"]) + ds = ds.set_coords(["x2d", "y2d"]) # set darray and plot method self.darray = ds.testvar # Add CF-compliant metadata - self.darray.attrs['long_name'] = 'a_long_name' - self.darray.attrs['units'] = 'a_units' - self.darray.x.attrs['long_name'] = 'x_long_name' - self.darray.x.attrs['units'] = 'x_units' - self.darray.y.attrs['long_name'] = 'y_long_name' - self.darray.y.attrs['units'] = 'y_units' + self.darray.attrs["long_name"] = "a_long_name" + self.darray.attrs["units"] = "a_units" + self.darray.x.attrs["long_name"] = "x_long_name" + self.darray.x.attrs["units"] = "x_units" + self.darray.y.attrs["long_name"] = "y_long_name" + self.darray.y.attrs["units"] = "y_units" self.plotmethod = getattr(self.darray.plot, self.plotfunc.__name__) def test_label_names(self): self.plotmethod() - assert 'x_long_name [x_units]' == plt.gca().get_xlabel() - assert 'y_long_name [y_units]' == plt.gca().get_ylabel() + assert "x_long_name [x_units]" == plt.gca().get_xlabel() + assert "y_long_name [y_units]" == plt.gca().get_ylabel() def test_1d_raises_valueerror(self): - with raises_regex(ValueError, r'DataArray must be 2d'): + with raises_regex(ValueError, r"DataArray must be 2d"): self.plotfunc(self.darray[0, :]) def test_3d_raises_valueerror(self): a = DataArray(easy_array((2, 3, 4))) - if self.plotfunc.__name__ == 'imshow': + if self.plotfunc.__name__ == "imshow": pytest.skip() - with raises_regex(ValueError, r'DataArray must be 2d'): + with raises_regex(ValueError, r"DataArray must be 2d"): self.plotfunc(a) def test_nonnumeric_index_raises_typeerror(self): - a = DataArray(easy_array((3, 2)), coords=[['a', 'b', 'c'], ['d', 'e']]) - with raises_regex(TypeError, r'[Pp]lot'): + a = DataArray(easy_array((3, 2)), coords=[["a", "b", "c"], ["d", "e"]]) + with raises_regex(TypeError, r"[Pp]lot"): self.plotfunc(a) def test_can_pass_in_axis(self): @@ -855,15 +886,13 @@ def test_xyincrease_defaults(self): # With default settings the axis must be ordered regardless # of the coords order. 
- self.plotfunc(DataArray(easy_array((3, 2)), coords=[[1, 2, 3], - [1, 2]])) + self.plotfunc(DataArray(easy_array((3, 2)), coords=[[1, 2, 3], [1, 2]])) bounds = plt.gca().get_ylim() assert bounds[0] < bounds[1] bounds = plt.gca().get_xlim() assert bounds[0] < bounds[1] # Inverted coords - self.plotfunc(DataArray(easy_array((3, 2)), coords=[[3, 2, 1], - [2, 1]])) + self.plotfunc(DataArray(easy_array((3, 2)), coords=[[3, 2, 1], [2, 1]])) bounds = plt.gca().get_ylim() assert bounds[0] < bounds[1] bounds = plt.gca().get_xlim() @@ -884,10 +913,9 @@ def test_xyincrease_true_changes_axes(self): assert all(abs(x) < 1 for x in diffs) def test_x_ticks_are_rotated_for_time(self): - time = pd.date_range('2000-01-01', '2000-01-10') - a = DataArray( - np.random.randn(2, len(time)), [('xx', [1, 2]), ('t', time)]) - a.plot(x='t') + time = pd.date_range("2000-01-01", "2000-01-10") + a = DataArray(np.random.randn(2, len(time)), [("xx", [1, 2]), ("t", time)]) + a.plot(x="t") rotation = plt.gca().get_xticklabels()[0].get_rotation() assert rotation != 0 @@ -900,41 +928,41 @@ def test_plot_nans(self): clim2 = self.plotfunc(x2).get_clim() assert clim1 == clim2 - @pytest.mark.filterwarnings('ignore::UserWarning') - @pytest.mark.filterwarnings('ignore:invalid value encountered') + @pytest.mark.filterwarnings("ignore::UserWarning") + @pytest.mark.filterwarnings("ignore:invalid value encountered") def test_can_plot_all_nans(self): # regression test for issue #1780 self.plotfunc(DataArray(np.full((2, 2), np.nan))) - @pytest.mark.filterwarnings('ignore: Attempting to set') + @pytest.mark.filterwarnings("ignore: Attempting to set") def test_can_plot_axis_size_one(self): - if self.plotfunc.__name__ not in ('contour', 'contourf'): + if self.plotfunc.__name__ not in ("contour", "contourf"): self.plotfunc(DataArray(np.ones((1, 1)))) def test_disallows_rgb_arg(self): with pytest.raises(ValueError): # Always invalid for most plots. Invalid for imshow with 2D data. 
- self.plotfunc(DataArray(np.ones((2, 2))), rgb='not None') + self.plotfunc(DataArray(np.ones((2, 2))), rgb="not None") def test_viridis_cmap(self): - cmap_name = self.plotmethod(cmap='viridis').get_cmap().name - assert 'viridis' == cmap_name + cmap_name = self.plotmethod(cmap="viridis").get_cmap().name + assert "viridis" == cmap_name def test_default_cmap(self): cmap_name = self.plotmethod().get_cmap().name - assert 'RdBu_r' == cmap_name + assert "RdBu_r" == cmap_name cmap_name = self.plotfunc(abs(self.darray)).get_cmap().name - assert 'viridis' == cmap_name + assert "viridis" == cmap_name @requires_seaborn def test_seaborn_palette_as_cmap(self): - cmap_name = self.plotmethod(levels=2, cmap='husl').get_cmap().name - assert 'husl' == cmap_name + cmap_name = self.plotmethod(levels=2, cmap="husl").get_cmap().name + assert "husl" == cmap_name def test_can_change_default_cmap(self): - cmap_name = self.plotmethod(cmap='Blues').get_cmap().name - assert 'Blues' == cmap_name + cmap_name = self.plotmethod(cmap="Blues").get_cmap().name + assert "Blues" == cmap_name def test_diverging_color_limits(self): artist = self.plotmethod() @@ -942,209 +970,203 @@ def test_diverging_color_limits(self): assert round(abs(-vmin - vmax), 7) == 0 def test_xy_strings(self): - self.plotmethod('y', 'x') + self.plotmethod("y", "x") ax = plt.gca() - assert 'y_long_name [y_units]' == ax.get_xlabel() - assert 'x_long_name [x_units]' == ax.get_ylabel() + assert "y_long_name [y_units]" == ax.get_xlabel() + assert "x_long_name [x_units]" == ax.get_ylabel() def test_positional_coord_string(self): - self.plotmethod(y='x') + self.plotmethod(y="x") ax = plt.gca() - assert 'x_long_name [x_units]' == ax.get_ylabel() - assert 'y_long_name [y_units]' == ax.get_xlabel() + assert "x_long_name [x_units]" == ax.get_ylabel() + assert "y_long_name [y_units]" == ax.get_xlabel() - self.plotmethod(x='x') + self.plotmethod(x="x") ax = plt.gca() - assert 'x_long_name [x_units]' == ax.get_xlabel() - assert 'y_long_name [y_units]' == ax.get_ylabel() + assert "x_long_name [x_units]" == ax.get_xlabel() + assert "y_long_name [y_units]" == ax.get_ylabel() def test_bad_x_string_exception(self): - with raises_regex(ValueError, 'x and y must be coordinate variables'): - self.plotmethod('not_a_real_dim', 'y') - with raises_regex(ValueError, - 'x must be a dimension name if y is not supplied'): - self.plotmethod(x='not_a_real_dim') - with raises_regex(ValueError, - 'y must be a dimension name if x is not supplied'): - self.plotmethod(y='not_a_real_dim') - self.darray.coords['z'] = 100 + with raises_regex(ValueError, "x and y must be coordinate variables"): + self.plotmethod("not_a_real_dim", "y") + with raises_regex( + ValueError, "x must be a dimension name if y is not supplied" + ): + self.plotmethod(x="not_a_real_dim") + with raises_regex( + ValueError, "y must be a dimension name if x is not supplied" + ): + self.plotmethod(y="not_a_real_dim") + self.darray.coords["z"] = 100 def test_coord_strings(self): # 1d coords (same as dims) - assert {'x', 'y'} == set(self.darray.dims) - self.plotmethod(y='y', x='x') + assert {"x", "y"} == set(self.darray.dims) + self.plotmethod(y="y", x="x") def test_non_linked_coords(self): # plot with coordinate names that are not dimensions - self.darray.coords['newy'] = self.darray.y + 150 + self.darray.coords["newy"] = self.darray.y + 150 # Normal case, without transpose - self.plotfunc(self.darray, x='x', y='newy') + self.plotfunc(self.darray, x="x", y="newy") ax = plt.gca() - assert 'x_long_name [x_units]' == 
ax.get_xlabel() - assert 'newy' == ax.get_ylabel() + assert "x_long_name [x_units]" == ax.get_xlabel() + assert "newy" == ax.get_ylabel() # ax limits might change between plotfuncs # simply ensure that these high coords were passed over - assert np.min(ax.get_ylim()) > 100. + assert np.min(ax.get_ylim()) > 100.0 def test_non_linked_coords_transpose(self): # plot with coordinate names that are not dimensions, # and with transposed y and x axes # This used to raise an error with pcolormesh and contour # https://github.com/pydata/xarray/issues/788 - self.darray.coords['newy'] = self.darray.y + 150 - self.plotfunc(self.darray, x='newy', y='x') + self.darray.coords["newy"] = self.darray.y + 150 + self.plotfunc(self.darray, x="newy", y="x") ax = plt.gca() - assert 'newy' == ax.get_xlabel() - assert 'x_long_name [x_units]' == ax.get_ylabel() + assert "newy" == ax.get_xlabel() + assert "x_long_name [x_units]" == ax.get_ylabel() # ax limits might change between plotfuncs # simply ensure that these high coords were passed over - assert np.min(ax.get_xlim()) > 100. + assert np.min(ax.get_xlim()) > 100.0 def test_default_title(self): - a = DataArray(easy_array((4, 3, 2)), dims=['a', 'b', 'c']) - a.coords['c'] = [0, 1] - a.coords['d'] = 'foo' + a = DataArray(easy_array((4, 3, 2)), dims=["a", "b", "c"]) + a.coords["c"] = [0, 1] + a.coords["d"] = "foo" self.plotfunc(a.isel(c=1)) title = plt.gca().get_title() - assert 'c = 1, d = foo' == title or 'd = foo, c = 1' == title + assert "c = 1, d = foo" == title or "d = foo, c = 1" == title def test_colorbar_default_label(self): self.plotmethod(add_colorbar=True) - assert ('a_long_name [a_units]' in text_in_fig()) + assert "a_long_name [a_units]" in text_in_fig() def test_no_labels(self): - self.darray.name = 'testvar' - self.darray.attrs['units'] = 'test_units' + self.darray.name = "testvar" + self.darray.attrs["units"] = "test_units" self.plotmethod(add_labels=False) alltxt = text_in_fig() - for string in ['x_long_name [x_units]', - 'y_long_name [y_units]', - 'testvar [test_units]']: + for string in [ + "x_long_name [x_units]", + "y_long_name [y_units]", + "testvar [test_units]", + ]: assert string not in alltxt def test_colorbar_kwargs(self): # replace label - self.darray.attrs.pop('long_name') - self.darray.attrs['units'] = 'test_units' + self.darray.attrs.pop("long_name") + self.darray.attrs["units"] = "test_units" # check default colorbar label self.plotmethod(add_colorbar=True) alltxt = text_in_fig() - assert 'testvar [test_units]' in alltxt - self.darray.attrs.pop('units') + assert "testvar [test_units]" in alltxt + self.darray.attrs.pop("units") - self.darray.name = 'testvar' - self.plotmethod(add_colorbar=True, cbar_kwargs={'label': 'MyLabel'}) + self.darray.name = "testvar" + self.plotmethod(add_colorbar=True, cbar_kwargs={"label": "MyLabel"}) alltxt = text_in_fig() - assert 'MyLabel' in alltxt - assert 'testvar' not in alltxt + assert "MyLabel" in alltxt + assert "testvar" not in alltxt # you can use anything accepted by the dict constructor as well - self.plotmethod( - add_colorbar=True, cbar_kwargs=(('label', 'MyLabel'), )) + self.plotmethod(add_colorbar=True, cbar_kwargs=(("label", "MyLabel"),)) alltxt = text_in_fig() - assert 'MyLabel' in alltxt - assert 'testvar' not in alltxt + assert "MyLabel" in alltxt + assert "testvar" not in alltxt # change cbar ax fig, (ax, cax) = plt.subplots(1, 2) self.plotmethod( - ax=ax, - cbar_ax=cax, - add_colorbar=True, - cbar_kwargs={ - 'label': 'MyBar' - }) + ax=ax, cbar_ax=cax, add_colorbar=True, 
cbar_kwargs={"label": "MyBar"} + ) assert ax.has_data() assert cax.has_data() alltxt = text_in_fig() - assert 'MyBar' in alltxt - assert 'testvar' not in alltxt + assert "MyBar" in alltxt + assert "testvar" not in alltxt # note that there are two ways to achieve this fig, (ax, cax) = plt.subplots(1, 2) self.plotmethod( - ax=ax, - add_colorbar=True, - cbar_kwargs={ - 'label': 'MyBar', - 'cax': cax - }) + ax=ax, add_colorbar=True, cbar_kwargs={"label": "MyBar", "cax": cax} + ) assert ax.has_data() assert cax.has_data() alltxt = text_in_fig() - assert 'MyBar' in alltxt - assert 'testvar' not in alltxt + assert "MyBar" in alltxt + assert "testvar" not in alltxt # see that no colorbar is respected self.plotmethod(add_colorbar=False) - assert 'testvar' not in text_in_fig() + assert "testvar" not in text_in_fig() # check that error is raised pytest.raises( ValueError, self.plotmethod, add_colorbar=False, - cbar_kwargs={ - 'label': 'label' - }) + cbar_kwargs={"label": "label"}, + ) def test_verbose_facetgrid(self): a = easy_array((10, 15, 3)) - d = DataArray(a, dims=['y', 'x', 'z']) - g = xplt.FacetGrid(d, col='z') - g.map_dataarray(self.plotfunc, 'x', 'y') + d = DataArray(a, dims=["y", "x", "z"]) + g = xplt.FacetGrid(d, col="z") + g.map_dataarray(self.plotfunc, "x", "y") for ax in g.axes.flat: assert ax.has_data() def test_2d_function_and_method_signature_same(self): func_sig = inspect.getcallargs(self.plotfunc, self.darray) method_sig = inspect.getcallargs(self.plotmethod) - del method_sig['_PlotMethods_obj'] - del func_sig['darray'] + del method_sig["_PlotMethods_obj"] + del func_sig["darray"] assert func_sig == method_sig - @pytest.mark.filterwarnings('ignore:tight_layout cannot') + @pytest.mark.filterwarnings("ignore:tight_layout cannot") def test_convenient_facetgrid(self): a = easy_array((10, 15, 4)) - d = DataArray(a, dims=['y', 'x', 'z']) - g = self.plotfunc(d, x='x', y='y', col='z', col_wrap=2) + d = DataArray(a, dims=["y", "x", "z"]) + g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2) assert_array_equal(g.axes.shape, [2, 2]) for (y, x), ax in np.ndenumerate(g.axes): assert ax.has_data() if x == 0: - assert 'y' == ax.get_ylabel() + assert "y" == ax.get_ylabel() else: - assert '' == ax.get_ylabel() + assert "" == ax.get_ylabel() if y == 1: - assert 'x' == ax.get_xlabel() + assert "x" == ax.get_xlabel() else: - assert '' == ax.get_xlabel() + assert "" == ax.get_xlabel() # Infering labels - g = self.plotfunc(d, col='z', col_wrap=2) + g = self.plotfunc(d, col="z", col_wrap=2) assert_array_equal(g.axes.shape, [2, 2]) for (y, x), ax in np.ndenumerate(g.axes): assert ax.has_data() if x == 0: - assert 'y' == ax.get_ylabel() + assert "y" == ax.get_ylabel() else: - assert '' == ax.get_ylabel() + assert "" == ax.get_ylabel() if y == 1: - assert 'x' == ax.get_xlabel() + assert "x" == ax.get_xlabel() else: - assert '' == ax.get_xlabel() + assert "" == ax.get_xlabel() - @pytest.mark.filterwarnings('ignore:tight_layout cannot') + @pytest.mark.filterwarnings("ignore:tight_layout cannot") def test_convenient_facetgrid_4d(self): a = easy_array((10, 15, 2, 3)) - d = DataArray(a, dims=['y', 'x', 'columns', 'rows']) - g = self.plotfunc(d, x='x', y='y', col='columns', row='rows') + d = DataArray(a, dims=["y", "x", "columns", "rows"]) + g = self.plotfunc(d, x="x", y="y", col="columns", row="rows") assert_array_equal(g.axes.shape, [3, 2]) for ax in g.axes.flat: assert ax.has_data() - @pytest.mark.filterwarnings('ignore:This figure includes') + @pytest.mark.filterwarnings("ignore:This figure includes") def 
test_facetgrid_map_only_appends_mappables(self): a = easy_array((10, 15, 2, 3)) - d = DataArray(a, dims=['y', 'x', 'columns', 'rows']) - g = self.plotfunc(d, x='x', y='y', col='columns', row='rows') + d = DataArray(a, dims=["y", "x", "columns", "rows"]) + g = self.plotfunc(d, x="x", y="y", col="columns", row="rows") expected = g._mappables @@ -1155,9 +1177,9 @@ def test_facetgrid_map_only_appends_mappables(self): def test_facetgrid_cmap(self): # Regression test for GH592 - data = (np.random.random(size=(20, 25, 12)) + np.linspace(-3, 3, 12)) - d = DataArray(data, dims=['x', 'y', 'time']) - fg = d.plot.pcolormesh(col='time') + data = np.random.random(size=(20, 25, 12)) + np.linspace(-3, 3, 12) + d = DataArray(data, dims=["x", "y", "time"]) + fg = d.plot.pcolormesh(col="time") # check that all color limits are the same assert len({m.get_clim() for m in fg._mappables}) == 1 # check that all colormaps are the same @@ -1165,30 +1187,36 @@ def test_facetgrid_cmap(self): def test_facetgrid_cbar_kwargs(self): a = easy_array((10, 15, 2, 3)) - d = DataArray(a, dims=['y', 'x', 'columns', 'rows']) - g = self.plotfunc(d, x='x', y='y', col='columns', row='rows', - cbar_kwargs={'label': 'test_label'}) + d = DataArray(a, dims=["y", "x", "columns", "rows"]) + g = self.plotfunc( + d, + x="x", + y="y", + col="columns", + row="rows", + cbar_kwargs={"label": "test_label"}, + ) # catch contour case - if hasattr(g, 'cbar'): - assert g.cbar._label == 'test_label' + if hasattr(g, "cbar"): + assert g.cbar._label == "test_label" def test_facetgrid_no_cbar_ax(self): a = easy_array((10, 15, 2, 3)) - d = DataArray(a, dims=['y', 'x', 'columns', 'rows']) + d = DataArray(a, dims=["y", "x", "columns", "rows"]) with pytest.raises(ValueError): - self.plotfunc(d, x='x', y='y', col='columns', row='rows', - cbar_ax=1) + self.plotfunc(d, x="x", y="y", col="columns", row="rows", cbar_ax=1) def test_cmap_and_color_both(self): with pytest.raises(ValueError): - self.plotmethod(colors='k', cmap='RdBu') + self.plotmethod(colors="k", cmap="RdBu") def test_2d_coord_with_interval(self): for dim in self.darray.dims: - gp = self.darray.groupby_bins( - dim, range(15), restore_coord_dims=True).mean(dim) - for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']: + gp = self.darray.groupby_bins(dim, range(15), restore_coord_dims=True).mean( + dim + ) + for kind in ["imshow", "pcolormesh", "contourf", "contour"]: getattr(gp.plot, kind)() def test_colormap_error_norm_and_vmin_vmax(self): @@ -1219,36 +1247,36 @@ def test_primitive_artist_returned(self): @pytest.mark.slow def test_extend(self): artist = self.plotmethod() - assert artist.extend == 'neither' + assert artist.extend == "neither" self.darray[0, 0] = -100 self.darray[-1, -1] = 100 artist = self.plotmethod(robust=True) - assert artist.extend == 'both' + assert artist.extend == "both" self.darray[0, 0] = 0 self.darray[-1, -1] = 0 artist = self.plotmethod(vmin=-0, vmax=10) - assert artist.extend == 'min' + assert artist.extend == "min" artist = self.plotmethod(vmin=-10, vmax=0) - assert artist.extend == 'max' + assert artist.extend == "max" @pytest.mark.slow def test_2d_coord_names(self): - self.plotmethod(x='x2d', y='y2d') + self.plotmethod(x="x2d", y="y2d") # make sure labels came out ok ax = plt.gca() - assert 'x2d' == ax.get_xlabel() - assert 'y2d' == ax.get_ylabel() + assert "x2d" == ax.get_xlabel() + assert "y2d" == ax.get_ylabel() @pytest.mark.slow def test_levels(self): artist = self.plotmethod(levels=[-0.5, -0.4, 0.1]) - assert artist.extend == 'both' + assert artist.extend == 
"both" artist = self.plotmethod(levels=3) - assert artist.extend == 'neither' + assert artist.extend == "neither" @pytest.mark.slow @@ -1263,38 +1291,35 @@ def _color_as_tuple(c): return tuple(c[:3]) # with single color, we don't want rgb array - artist = self.plotmethod(colors='k') - assert artist.cmap.colors[0] == 'k' + artist = self.plotmethod(colors="k") + assert artist.cmap.colors[0] == "k" - artist = self.plotmethod(colors=['k', 'b']) - assert (_color_as_tuple(artist.cmap.colors[1]) == - (0.0, 0.0, 1.0)) + artist = self.plotmethod(colors=["k", "b"]) + assert _color_as_tuple(artist.cmap.colors[1]) == (0.0, 0.0, 1.0) artist = self.darray.plot.contour( - levels=[-0.5, 0., 0.5, 1.], colors=['k', 'r', 'w', 'b']) - assert (_color_as_tuple(artist.cmap.colors[1]) == - (1.0, 0.0, 0.0)) - assert (_color_as_tuple(artist.cmap.colors[2]) == - (1.0, 1.0, 1.0)) + levels=[-0.5, 0.0, 0.5, 1.0], colors=["k", "r", "w", "b"] + ) + assert _color_as_tuple(artist.cmap.colors[1]) == (1.0, 0.0, 0.0) + assert _color_as_tuple(artist.cmap.colors[2]) == (1.0, 1.0, 1.0) # the last color is now under "over" - assert (_color_as_tuple(artist.cmap._rgba_over) == - (0.0, 0.0, 1.0)) + assert _color_as_tuple(artist.cmap._rgba_over) == (0.0, 0.0, 1.0) def test_cmap_and_color_both(self): with pytest.raises(ValueError): - self.plotmethod(colors='k', cmap='RdBu') + self.plotmethod(colors="k", cmap="RdBu") def list_of_colors_in_cmap_deprecated(self): with pytest.raises(Exception): - self.plotmethod(cmap=['k', 'b']) + self.plotmethod(cmap=["k", "b"]) @pytest.mark.slow def test_2d_coord_names(self): - self.plotmethod(x='x2d', y='y2d') + self.plotmethod(x="x2d", y="y2d") # make sure labels came out ok ax = plt.gca() - assert 'x2d' == ax.get_xlabel() - assert 'y2d' == ax.get_ylabel() + assert "x2d" == ax.get_xlabel() + assert "y2d" == ax.get_ylabel() def test_single_level(self): # this used to raise an error, but not anymore since @@ -1317,18 +1342,18 @@ def test_everything_plotted(self): @pytest.mark.slow def test_2d_coord_names(self): - self.plotmethod(x='x2d', y='y2d') + self.plotmethod(x="x2d", y="y2d") # make sure labels came out ok ax = plt.gca() - assert 'x2d' == ax.get_xlabel() - assert 'y2d' == ax.get_ylabel() + assert "x2d" == ax.get_xlabel() + assert "y2d" == ax.get_ylabel() def test_dont_infer_interval_breaks_for_cartopy(self): # Regression for GH 781 ax = plt.gca() # Simulate a Cartopy Axis - setattr(ax, 'projection', True) - artist = self.plotmethod(x='x2d', y='y2d', ax=ax) + setattr(ax, "projection", True) + artist = self.plotmethod(x="x2d", y="y2d", ax=ax) assert isinstance(artist, mpl.collections.QuadMesh) # Let cartopy handle the axis limits and artist size assert artist.get_array().size <= self.darray.size @@ -1352,17 +1377,17 @@ def test_xy_pixel_centered(self): def test_default_aspect_is_auto(self): self.darray.plot.imshow() - assert 'auto' == plt.gca().get_aspect() + assert "auto" == plt.gca().get_aspect() @pytest.mark.slow def test_cannot_change_mpl_aspect(self): - with raises_regex(ValueError, 'not available in xarray'): - self.darray.plot.imshow(aspect='equal') + with raises_regex(ValueError, "not available in xarray"): + self.darray.plot.imshow(aspect="equal") # with numbers we fall back to fig control self.darray.plot.imshow(size=5, aspect=2) - assert 'auto' == plt.gca().get_aspect() + assert "auto" == plt.gca().get_aspect() assert tuple(plt.gcf().get_size_inches()) == (10, 5) @pytest.mark.slow @@ -1374,65 +1399,62 @@ def test_primitive_artist_returned(self): @requires_seaborn def 
test_seaborn_palette_needs_levels(self): with pytest.raises(ValueError): - self.plotmethod(cmap='husl') + self.plotmethod(cmap="husl") def test_2d_coord_names(self): - with raises_regex(ValueError, 'requires 1D coordinates'): - self.plotmethod(x='x2d', y='y2d') + with raises_regex(ValueError, "requires 1D coordinates"): + self.plotmethod(x="x2d", y="y2d") def test_plot_rgb_image(self): DataArray( - easy_array((10, 15, 3), start=0), - dims=['y', 'x', 'band'], + easy_array((10, 15, 3), start=0), dims=["y", "x", "band"] ).plot.imshow() assert 0 == len(find_possible_colorbars()) def test_plot_rgb_image_explicit(self): DataArray( - easy_array((10, 15, 3), start=0), - dims=['y', 'x', 'band'], - ).plot.imshow( - y='y', x='x', rgb='band') + easy_array((10, 15, 3), start=0), dims=["y", "x", "band"] + ).plot.imshow(y="y", x="x", rgb="band") assert 0 == len(find_possible_colorbars()) def test_plot_rgb_faceted(self): DataArray( - easy_array((2, 2, 10, 15, 3), start=0), - dims=['a', 'b', 'y', 'x', 'band'], - ).plot.imshow( - row='a', col='b') + easy_array((2, 2, 10, 15, 3), start=0), dims=["a", "b", "y", "x", "band"] + ).plot.imshow(row="a", col="b") assert 0 == len(find_possible_colorbars()) def test_plot_rgba_image_transposed(self): # We can handle the color axis being in any position DataArray( - easy_array((4, 10, 15), start=0), - dims=['band', 'y', 'x'], + easy_array((4, 10, 15), start=0), dims=["band", "y", "x"] ).plot.imshow() def test_warns_ambigious_dim(self): - arr = DataArray(easy_array((3, 3, 3)), dims=['y', 'x', 'band']) + arr = DataArray(easy_array((3, 3, 3)), dims=["y", "x", "band"]) with pytest.warns(UserWarning): arr.plot.imshow() # but doesn't warn if dimensions specified - arr.plot.imshow(rgb='band') - arr.plot.imshow(x='x', y='y') + arr.plot.imshow(rgb="band") + arr.plot.imshow(x="x", y="y") def test_rgb_errors_too_many_dims(self): - arr = DataArray(easy_array((3, 3, 3, 3)), dims=['y', 'x', 'z', 'band']) + arr = DataArray(easy_array((3, 3, 3, 3)), dims=["y", "x", "z", "band"]) with pytest.raises(ValueError): - arr.plot.imshow(rgb='band') + arr.plot.imshow(rgb="band") def test_rgb_errors_bad_dim_sizes(self): - arr = DataArray(easy_array((5, 5, 5)), dims=['y', 'x', 'band']) + arr = DataArray(easy_array((5, 5, 5)), dims=["y", "x", "band"]) with pytest.raises(ValueError): - arr.plot.imshow(rgb='band') + arr.plot.imshow(rgb="band") def test_normalize_rgb_imshow(self): for kwargs in ( - dict(vmin=-1), dict(vmax=2), - dict(vmin=-1, vmax=1), dict(vmin=0, vmax=0), - dict(vmin=0, robust=True), dict(vmax=-1, robust=True), + dict(vmin=-1), + dict(vmax=2), + dict(vmin=-1, vmax=1), + dict(vmin=0, vmax=0), + dict(vmin=0, robust=True), + dict(vmax=-1, robust=True), ): da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4)) arr = da.plot.imshow(**kwargs).get_array() @@ -1449,13 +1471,13 @@ def test_normalize_rgb_one_arg_error(self): da.plot.imshow(**kwargs) def test_imshow_rgb_values_in_valid_range(self): - da = DataArray(np.arange(75, dtype='uint8').reshape((5, 5, 3))) + da = DataArray(np.arange(75, dtype="uint8").reshape((5, 5, 3))) _, ax = plt.subplots() out = da.plot.imshow(ax=ax).get_array() assert out.dtype == np.uint8 assert (out[..., :3] == da.values).all() # Compare without added alpha - @pytest.mark.filterwarnings('ignore:Several dimensions of this array') + @pytest.mark.filterwarnings("ignore:Several dimensions of this array") def test_regression_rgb_imshow_dim_size_one(self): # Regression: https://github.com/pydata/xarray/issues/1966 da = DataArray(easy_array((1, 3, 3), 
start=0.0, stop=1.0)) @@ -1463,12 +1485,12 @@ def test_regression_rgb_imshow_dim_size_one(self): def test_origin_overrides_xyincrease(self): da = DataArray(easy_array((3, 2)), coords=[[-2, 0, 2], [-1, 1]]) - da.plot.imshow(origin='upper') + da.plot.imshow(origin="upper") assert plt.xlim()[0] < 0 assert plt.ylim()[1] < 0 plt.clf() - da.plot.imshow(origin='lower') + da.plot.imshow(origin="lower") assert plt.xlim()[0] < 0 assert plt.ylim()[0] < 0 @@ -1477,46 +1499,43 @@ class TestFacetGrid(PlotTestCase): @pytest.fixture(autouse=True) def setUp(self): d = easy_array((10, 15, 3)) - self.darray = DataArray( - d, dims=['y', 'x', 'z'], coords={ - 'z': ['a', 'b', 'c'] - }) - self.g = xplt.FacetGrid(self.darray, col='z') + self.darray = DataArray(d, dims=["y", "x", "z"], coords={"z": ["a", "b", "c"]}) + self.g = xplt.FacetGrid(self.darray, col="z") @pytest.mark.slow def test_no_args(self): - self.g.map_dataarray(xplt.contourf, 'x', 'y') + self.g.map_dataarray(xplt.contourf, "x", "y") # Don't want colorbar labeled with 'None' alltxt = text_in_fig() - assert 'None' not in alltxt + assert "None" not in alltxt for ax in self.g.axes.flat: assert ax.has_data() @pytest.mark.slow def test_names_appear_somewhere(self): - self.darray.name = 'testvar' - self.g.map_dataarray(xplt.contourf, 'x', 'y') - for k, ax in zip('abc', self.g.axes.flat): - assert 'z = {}'.format(k) == ax.get_title() + self.darray.name = "testvar" + self.g.map_dataarray(xplt.contourf, "x", "y") + for k, ax in zip("abc", self.g.axes.flat): + assert "z = {}".format(k) == ax.get_title() alltxt = text_in_fig() assert self.darray.name in alltxt - for label in ['x', 'y']: + for label in ["x", "y"]: assert label in alltxt @pytest.mark.slow def test_text_not_super_long(self): - self.darray.coords['z'] = [100 * letter for letter in 'abc'] - g = xplt.FacetGrid(self.darray, col='z') - g.map_dataarray(xplt.contour, 'x', 'y') + self.darray.coords["z"] = [100 * letter for letter in "abc"] + g = xplt.FacetGrid(self.darray, col="z") + g.map_dataarray(xplt.contour, "x", "y") alltxt = text_in_fig() maxlen = max(len(txt) for txt in alltxt) assert maxlen < 50 t0 = g.axes[0, 0].get_title() - assert t0.endswith('...') + assert t0.endswith("...") @pytest.mark.slow def test_colorbar(self): @@ -1524,7 +1543,7 @@ def test_colorbar(self): vmax = self.darray.values.max() expected = np.array((vmin, vmax)) - self.g.map_dataarray(xplt.imshow, 'x', 'y') + self.g.map_dataarray(xplt.imshow, "x", "y") for image in plt.gcf().findobj(mpl.image.AxesImage): clim = np.array(image.get_clim()) @@ -1534,8 +1553,8 @@ def test_colorbar(self): @pytest.mark.slow def test_empty_cell(self): - g = xplt.FacetGrid(self.darray, col='z', col_wrap=2) - g.map_dataarray(xplt.imshow, 'x', 'y') + g = xplt.FacetGrid(self.darray, col="z", col_wrap=2) + g.map_dataarray(xplt.imshow, "x", "y") bottomright = g.axes[-1, -1] assert not bottomright.has_data() @@ -1543,12 +1562,12 @@ def test_empty_cell(self): @pytest.mark.slow def test_norow_nocol_error(self): - with raises_regex(ValueError, r'[Rr]ow'): + with raises_regex(ValueError, r"[Rr]ow"): xplt.FacetGrid(self.darray) @pytest.mark.slow def test_groups(self): - self.g.map_dataarray(xplt.imshow, 'x', 'y') + self.g.map_dataarray(xplt.imshow, "x", "y") upperleft_dict = self.g.name_dicts[0, 0] upperleft_array = self.darray.loc[upperleft_dict] z0 = self.darray.isel(z=0) @@ -1557,25 +1576,25 @@ def test_groups(self): @pytest.mark.slow def test_float_index(self): - self.darray.coords['z'] = [0.1, 0.2, 0.4] - g = xplt.FacetGrid(self.darray, col='z') - 
g.map_dataarray(xplt.imshow, 'x', 'y') + self.darray.coords["z"] = [0.1, 0.2, 0.4] + g = xplt.FacetGrid(self.darray, col="z") + g.map_dataarray(xplt.imshow, "x", "y") @pytest.mark.slow def test_nonunique_index_error(self): - self.darray.coords['z'] = [0.1, 0.2, 0.2] - with raises_regex(ValueError, r'[Uu]nique'): - xplt.FacetGrid(self.darray, col='z') + self.darray.coords["z"] = [0.1, 0.2, 0.2] + with raises_regex(ValueError, r"[Uu]nique"): + xplt.FacetGrid(self.darray, col="z") @pytest.mark.slow def test_robust(self): z = np.zeros((20, 20, 2)) - darray = DataArray(z, dims=['y', 'x', 'z']) + darray = DataArray(z, dims=["y", "x", "z"]) darray[:, :, 1] = 1 darray[2, 0, 0] = -1000 darray[3, 0, 0] = 1000 - g = xplt.FacetGrid(darray, col='z') - g.map_dataarray(xplt.imshow, 'x', 'y', robust=True) + g = xplt.FacetGrid(darray, col="z") + g.map_dataarray(xplt.imshow, "x", "y", robust=True) # Color limits should be 0, 1 # The largest number displayed in the figure should be less than 21 @@ -1593,7 +1612,7 @@ def test_robust(self): def test_can_set_vmin_vmax(self): vmin, vmax = 50.0, 1000.0 expected = np.array((vmin, vmax)) - self.g.map_dataarray(xplt.imshow, 'x', 'y', vmin=vmin, vmax=vmax) + self.g.map_dataarray(xplt.imshow, "x", "y", vmin=vmin, vmax=vmax) for image in plt.gcf().findobj(mpl.image.AxesImage): clim = np.array(image.get_clim()) @@ -1602,7 +1621,7 @@ def test_can_set_vmin_vmax(self): @pytest.mark.slow def test_can_set_norm(self): norm = mpl.colors.SymLogNorm(0.1) - self.g.map_dataarray(xplt.imshow, 'x', 'y', norm=norm) + self.g.map_dataarray(xplt.imshow, "x", "y", norm=norm) for image in plt.gcf().findobj(mpl.image.AxesImage): assert image.norm is norm @@ -1611,29 +1630,29 @@ def test_figure_size(self): assert_array_equal(self.g.fig.get_size_inches(), (10, 3)) - g = xplt.FacetGrid(self.darray, col='z', size=6) + g = xplt.FacetGrid(self.darray, col="z", size=6) assert_array_equal(g.fig.get_size_inches(), (19, 6)) - g = self.darray.plot.imshow(col='z', size=6) + g = self.darray.plot.imshow(col="z", size=6) assert_array_equal(g.fig.get_size_inches(), (19, 6)) - g = xplt.FacetGrid(self.darray, col='z', size=4, aspect=0.5) + g = xplt.FacetGrid(self.darray, col="z", size=4, aspect=0.5) assert_array_equal(g.fig.get_size_inches(), (7, 4)) - g = xplt.FacetGrid(self.darray, col='z', figsize=(9, 4)) + g = xplt.FacetGrid(self.darray, col="z", figsize=(9, 4)) assert_array_equal(g.fig.get_size_inches(), (9, 4)) with raises_regex(ValueError, "cannot provide both"): - g = xplt.plot(self.darray, row=2, col='z', figsize=(6, 4), size=6) + g = xplt.plot(self.darray, row=2, col="z", figsize=(6, 4), size=6) with raises_regex(ValueError, "Can't use"): - g = xplt.plot(self.darray, row=2, col='z', ax=plt.gca(), size=6) + g = xplt.plot(self.darray, row=2, col="z", ax=plt.gca(), size=6) @pytest.mark.slow def test_num_ticks(self): nticks = 99 maxticks = nticks + 1 - self.g.map_dataarray(xplt.imshow, 'x', 'y') + self.g.map_dataarray(xplt.imshow, "x", "y") self.g.set_ticks(max_xticks=nticks, max_yticks=nticks) for ax in self.g.axes.flat: @@ -1647,142 +1666,144 @@ def test_num_ticks(self): @pytest.mark.slow def test_map(self): assert self.g._finalized is False - self.g.map(plt.contourf, 'x', 'y', Ellipsis) + self.g.map(plt.contourf, "x", "y", Ellipsis) assert self.g._finalized is True self.g.map(lambda: None) @pytest.mark.slow def test_map_dataset(self): - g = xplt.FacetGrid(self.darray.to_dataset(name='foo'), col='z') - g.map(plt.contourf, 'x', 'y', 'foo') + g = xplt.FacetGrid(self.darray.to_dataset(name="foo"), 
col="z") + g.map(plt.contourf, "x", "y", "foo") alltxt = text_in_fig() - for label in ['x', 'y']: + for label in ["x", "y"]: assert label in alltxt # everything has a label - assert 'None' not in alltxt + assert "None" not in alltxt # colorbar can't be inferred automatically - assert 'foo' not in alltxt + assert "foo" not in alltxt assert 0 == len(find_possible_colorbars()) - g.add_colorbar(label='colors!') - assert 'colors!' in text_in_fig() + g.add_colorbar(label="colors!") + assert "colors!" in text_in_fig() assert 1 == len(find_possible_colorbars()) @pytest.mark.slow def test_set_axis_labels(self): - g = self.g.map_dataarray(xplt.contourf, 'x', 'y') - g.set_axis_labels('longitude', 'latitude') + g = self.g.map_dataarray(xplt.contourf, "x", "y") + g.set_axis_labels("longitude", "latitude") alltxt = text_in_fig() - for label in ['longitude', 'latitude']: + for label in ["longitude", "latitude"]: assert label in alltxt @pytest.mark.slow def test_facetgrid_colorbar(self): a = easy_array((10, 15, 4)) - d = DataArray(a, dims=['y', 'x', 'z'], name='foo') + d = DataArray(a, dims=["y", "x", "z"], name="foo") - d.plot.imshow(x='x', y='y', col='z') + d.plot.imshow(x="x", y="y", col="z") assert 1 == len(find_possible_colorbars()) - d.plot.imshow(x='x', y='y', col='z', add_colorbar=True) + d.plot.imshow(x="x", y="y", col="z", add_colorbar=True) assert 1 == len(find_possible_colorbars()) - d.plot.imshow(x='x', y='y', col='z', add_colorbar=False) + d.plot.imshow(x="x", y="y", col="z", add_colorbar=False) assert 0 == len(find_possible_colorbars()) @pytest.mark.slow def test_facetgrid_polar(self): # test if polar projection in FacetGrid does not raise an exception self.darray.plot.pcolormesh( - col='z', - subplot_kws=dict(projection='polar'), - sharex=False, - sharey=False) + col="z", subplot_kws=dict(projection="polar"), sharex=False, sharey=False + ) -@pytest.mark.filterwarnings('ignore:tight_layout cannot') +@pytest.mark.filterwarnings("ignore:tight_layout cannot") class TestFacetGrid4d(PlotTestCase): @pytest.fixture(autouse=True) def setUp(self): a = easy_array((10, 15, 3, 2)) - darray = DataArray(a, dims=['y', 'x', 'col', 'row']) - darray.coords['col'] = np.array( - ['col' + str(x) for x in darray.coords['col'].values]) - darray.coords['row'] = np.array( - ['row' + str(x) for x in darray.coords['row'].values]) + darray = DataArray(a, dims=["y", "x", "col", "row"]) + darray.coords["col"] = np.array( + ["col" + str(x) for x in darray.coords["col"].values] + ) + darray.coords["row"] = np.array( + ["row" + str(x) for x in darray.coords["row"].values] + ) self.darray = darray @pytest.mark.slow def test_default_labels(self): - g = xplt.FacetGrid(self.darray, col='col', row='row') + g = xplt.FacetGrid(self.darray, col="col", row="row") assert (2, 3) == g.axes.shape - g.map_dataarray(xplt.imshow, 'x', 'y') + g.map_dataarray(xplt.imshow, "x", "y") # Rightmost column should be labeled - for label, ax in zip(self.darray.coords['row'].values, g.axes[:, -1]): + for label, ax in zip(self.darray.coords["row"].values, g.axes[:, -1]): assert substring_in_axes(label, ax) # Top row should be labeled - for label, ax in zip(self.darray.coords['col'].values, g.axes[0, :]): + for label, ax in zip(self.darray.coords["col"].values, g.axes[0, :]): assert substring_in_axes(label, ax) -@pytest.mark.filterwarnings('ignore:tight_layout cannot') +@pytest.mark.filterwarnings("ignore:tight_layout cannot") class TestFacetedLinePlotsLegend(PlotTestCase): @pytest.fixture(autouse=True) def setUp(self): self.darray = 
xr.tutorial.scatter_example_dataset() def test_legend_labels(self): - fg = self.darray.A.plot.line(col='x', row='w', hue='z') + fg = self.darray.A.plot.line(col="x", row="w", hue="z") all_legend_labels = [t.get_text() for t in fg.figlegend.texts] # labels in legend should be ['0', '1', '2', '3'] - assert sorted(all_legend_labels) == ['0', '1', '2', '3'] + assert sorted(all_legend_labels) == ["0", "1", "2", "3"] -@pytest.mark.filterwarnings('ignore:tight_layout cannot') +@pytest.mark.filterwarnings("ignore:tight_layout cannot") class TestFacetedLinePlots(PlotTestCase): @pytest.fixture(autouse=True) def setUp(self): - self.darray = DataArray(np.random.randn(10, 6, 3, 4), - dims=['hue', 'x', 'col', 'row'], - coords=[range(10), range(6), - range(3), ['A', 'B', 'C', 'C++']], - name='Cornelius Ortega the 1st') - - self.darray.hue.name = 'huename' - self.darray.hue.attrs['units'] = 'hunits' - self.darray.x.attrs['units'] = 'xunits' - self.darray.col.attrs['units'] = 'colunits' - self.darray.row.attrs['units'] = 'rowunits' + self.darray = DataArray( + np.random.randn(10, 6, 3, 4), + dims=["hue", "x", "col", "row"], + coords=[range(10), range(6), range(3), ["A", "B", "C", "C++"]], + name="Cornelius Ortega the 1st", + ) + + self.darray.hue.name = "huename" + self.darray.hue.attrs["units"] = "hunits" + self.darray.x.attrs["units"] = "xunits" + self.darray.col.attrs["units"] = "colunits" + self.darray.row.attrs["units"] = "rowunits" def test_facetgrid_shape(self): - g = self.darray.plot(row='row', col='col', hue='hue') + g = self.darray.plot(row="row", col="col", hue="hue") assert g.axes.shape == (len(self.darray.row), len(self.darray.col)) - g = self.darray.plot(row='col', col='row', hue='hue') + g = self.darray.plot(row="col", col="row", hue="hue") assert g.axes.shape == (len(self.darray.col), len(self.darray.row)) def test_unnamed_args(self): - g = self.darray.plot.line('o--', row='row', col='col', hue='hue') - lines = [q for q in g.axes.flat[0].get_children() - if isinstance(q, mpl.lines.Line2D)] + g = self.darray.plot.line("o--", row="row", col="col", hue="hue") + lines = [ + q for q in g.axes.flat[0].get_children() if isinstance(q, mpl.lines.Line2D) + ] # passing 'o--' as argument should set marker and linestyle - assert lines[0].get_marker() == 'o' - assert lines[0].get_linestyle() == '--' + assert lines[0].get_marker() == "o" + assert lines[0].get_linestyle() == "--" def test_default_labels(self): - g = self.darray.plot(row='row', col='col', hue='hue') + g = self.darray.plot(row="row", col="col", hue="hue") # Rightmost column should be labeled - for label, ax in zip(self.darray.coords['row'].values, g.axes[:, -1]): + for label, ax in zip(self.darray.coords["row"].values, g.axes[:, -1]): assert substring_in_axes(label, ax) # Top row should be labeled - for label, ax in zip(self.darray.coords['col'].values, g.axes[0, :]): + for label, ax in zip(self.darray.coords["col"].values, g.axes[0, :]): assert substring_in_axes(str(label), ax) # Leftmost column should have array name @@ -1790,167 +1811,173 @@ def test_default_labels(self): assert substring_in_axes(self.darray.name, ax) def test_test_empty_cell(self): - g = self.darray.isel(row=1).drop('row').plot(col='col', - hue='hue', - col_wrap=2) + g = self.darray.isel(row=1).drop("row").plot(col="col", hue="hue", col_wrap=2) bottomright = g.axes[-1, -1] assert not bottomright.has_data() assert not bottomright.get_visible() def test_set_axis_labels(self): - g = self.darray.plot(row='row', col='col', hue='hue') - g.set_axis_labels('longitude', 
'latitude') + g = self.darray.plot(row="row", col="col", hue="hue") + g.set_axis_labels("longitude", "latitude") alltxt = text_in_fig() - assert 'longitude' in alltxt - assert 'latitude' in alltxt + assert "longitude" in alltxt + assert "latitude" in alltxt def test_axes_in_faceted_plot(self): with pytest.raises(ValueError): - self.darray.plot.line(row='row', col='col', - x='x', ax=plt.axes()) + self.darray.plot.line(row="row", col="col", x="x", ax=plt.axes()) def test_figsize_and_size(self): with pytest.raises(ValueError): - self.darray.plot.line(row='row', col='col', - x='x', size=3, figsize=4) + self.darray.plot.line(row="row", col="col", x="x", size=3, figsize=4) def test_wrong_num_of_dimensions(self): with pytest.raises(ValueError): - self.darray.plot(row='row', hue='hue') - self.darray.plot.line(row='row', hue='hue') + self.darray.plot(row="row", hue="hue") + self.darray.plot.line(row="row", hue="hue") @requires_matplotlib class TestDatasetScatterPlots(PlotTestCase): @pytest.fixture(autouse=True) def setUp(self): - das = [DataArray(np.random.randn(3, 3, 4, 4), - dims=['x', 'row', 'col', 'hue'], - coords=[range(k) for k in [3, 3, 4, 4]]) - for _ in [1, 2]] - ds = Dataset({'A': das[0], 'B': das[1]}) - ds.hue.name = 'huename' - ds.hue.attrs['units'] = 'hunits' - ds.x.attrs['units'] = 'xunits' - ds.col.attrs['units'] = 'colunits' - ds.row.attrs['units'] = 'rowunits' - ds.A.attrs['units'] = 'Aunits' - ds.B.attrs['units'] = 'Bunits' + das = [ + DataArray( + np.random.randn(3, 3, 4, 4), + dims=["x", "row", "col", "hue"], + coords=[range(k) for k in [3, 3, 4, 4]], + ) + for _ in [1, 2] + ] + ds = Dataset({"A": das[0], "B": das[1]}) + ds.hue.name = "huename" + ds.hue.attrs["units"] = "hunits" + ds.x.attrs["units"] = "xunits" + ds.col.attrs["units"] = "colunits" + ds.row.attrs["units"] = "rowunits" + ds.A.attrs["units"] = "Aunits" + ds.B.attrs["units"] = "Bunits" self.ds = ds @pytest.mark.parametrize( - 'add_guide, hue_style, legend, colorbar', [ + "add_guide, hue_style, legend, colorbar", + [ (None, None, False, True), (False, None, False, False), (True, None, False, True), (True, "continuous", False, True), (False, "discrete", False, False), - (True, "discrete", True, False)] + (True, "discrete", True, False), + ], ) def test_add_guide(self, add_guide, hue_style, legend, colorbar): - meta_data = _infer_meta_data(self.ds, x='A', y='B', hue='hue', - hue_style=hue_style, - add_guide=add_guide) - assert meta_data['add_legend'] is legend - assert meta_data['add_colorbar'] is colorbar + meta_data = _infer_meta_data( + self.ds, x="A", y="B", hue="hue", hue_style=hue_style, add_guide=add_guide + ) + assert meta_data["add_legend"] is legend + assert meta_data["add_colorbar"] is colorbar def test_facetgrid_shape(self): - g = self.ds.plot.scatter(x='A', y='B', row='row', col='col') + g = self.ds.plot.scatter(x="A", y="B", row="row", col="col") assert g.axes.shape == (len(self.ds.row), len(self.ds.col)) - g = self.ds.plot.scatter(x='A', y='B', row='col', col='row') + g = self.ds.plot.scatter(x="A", y="B", row="col", col="row") assert g.axes.shape == (len(self.ds.col), len(self.ds.row)) def test_default_labels(self): - g = self.ds.plot.scatter('A', 'B', row='row', col='col', hue='hue') + g = self.ds.plot.scatter("A", "B", row="row", col="col", hue="hue") # Top row should be labeled - for label, ax in zip(self.ds.coords['col'].values, g.axes[0, :]): + for label, ax in zip(self.ds.coords["col"].values, g.axes[0, :]): assert substring_in_axes(str(label), ax) # Bottom row should have name of x array name and 
units for ax in g.axes[-1, :]: - assert ax.get_xlabel() == 'A [Aunits]' + assert ax.get_xlabel() == "A [Aunits]" # Leftmost column should have name of y array name and units for ax in g.axes[:, 0]: - assert ax.get_ylabel() == 'B [Bunits]' + assert ax.get_ylabel() == "B [Bunits]" def test_axes_in_faceted_plot(self): with pytest.raises(ValueError): - self.ds.plot.scatter(x='A', y='B', row='row', ax=plt.axes()) + self.ds.plot.scatter(x="A", y="B", row="row", ax=plt.axes()) def test_figsize_and_size(self): with pytest.raises(ValueError): - self.ds.plot.scatter(x='A', y='B', row='row', size=3, figsize=4) - - @pytest.mark.parametrize('x, y, hue_style, add_guide', [ - ('A', 'B', 'something', True), - ('A', 'B', 'discrete', True), - ('A', 'B', None, True), - ('A', 'The Spanish Inquisition', None, None), - ('The Spanish Inquisition', 'B', None, True)]) + self.ds.plot.scatter(x="A", y="B", row="row", size=3, figsize=4) + + @pytest.mark.parametrize( + "x, y, hue_style, add_guide", + [ + ("A", "B", "something", True), + ("A", "B", "discrete", True), + ("A", "B", None, True), + ("A", "The Spanish Inquisition", None, None), + ("The Spanish Inquisition", "B", None, True), + ], + ) def test_bad_args(self, x, y, hue_style, add_guide): with pytest.raises(ValueError): - self.ds.plot.scatter(x, y, hue_style=hue_style, - add_guide=add_guide) + self.ds.plot.scatter(x, y, hue_style=hue_style, add_guide=add_guide) - @pytest.mark.xfail(reason='datetime,timedelta hue variable not supported.') - @pytest.mark.parametrize('hue_style', ['discrete', 'continuous']) + @pytest.mark.xfail(reason="datetime,timedelta hue variable not supported.") + @pytest.mark.parametrize("hue_style", ["discrete", "continuous"]) def test_datetime_hue(self, hue_style): ds2 = self.ds.copy() - ds2['hue'] = pd.date_range('2000-1-1', periods=4) - ds2.plot.scatter(x='A', y='B', hue='hue', hue_style=hue_style) + ds2["hue"] = pd.date_range("2000-1-1", periods=4) + ds2.plot.scatter(x="A", y="B", hue="hue", hue_style=hue_style) - ds2['hue'] = pd.timedelta_range('-1D', periods=4, freq='D') - ds2.plot.scatter(x='A', y='B', hue='hue', hue_style=hue_style) + ds2["hue"] = pd.timedelta_range("-1D", periods=4, freq="D") + ds2.plot.scatter(x="A", y="B", hue="hue", hue_style=hue_style) def test_facetgrid_hue_style(self): # Can't move this to pytest.mark.parametrize because py35-min # doesn't have mpl. 
- for hue_style, map_type in zip(['discrete', 'continuous'], - [list, mpl.collections.PathCollection]): - g = self.ds.plot.scatter(x='A', y='B', row='row', col='col', - hue='hue', hue_style=hue_style) + for hue_style, map_type in zip( + ["discrete", "continuous"], [list, mpl.collections.PathCollection] + ): + g = self.ds.plot.scatter( + x="A", y="B", row="row", col="col", hue="hue", hue_style=hue_style + ) # for 'discrete' a list is appended to _mappables # for 'continuous', should be single PathCollection assert isinstance(g._mappables[-1], map_type) - @pytest.mark.parametrize('x, y, hue, markersize', [ - ('A', 'B', 'x', 'col'), - ('x', 'row', 'A', 'B')]) + @pytest.mark.parametrize( + "x, y, hue, markersize", [("A", "B", "x", "col"), ("x", "row", "A", "B")] + ) def test_scatter(self, x, y, hue, markersize): self.ds.plot.scatter(x, y, hue=hue, markersize=markersize) def test_non_numeric_legend(self): ds2 = self.ds.copy() - ds2['hue'] = ['a', 'b', 'c', 'd'] - lines = ds2.plot.scatter(x='A', y='B', hue='hue') + ds2["hue"] = ["a", "b", "c", "d"] + lines = ds2.plot.scatter(x="A", y="B", hue="hue") # should make a discrete legend assert lines[0].axes.legend_ is not None # and raise an error if explicitly not allowed to do so with pytest.raises(ValueError): - ds2.plot.scatter(x='A', y='B', hue='hue', - hue_style='continuous') + ds2.plot.scatter(x="A", y="B", hue="hue", hue_style="continuous") def test_add_legend_by_default(self): - sc = self.ds.plot.scatter(x='A', y='B', hue='hue') + sc = self.ds.plot.scatter(x="A", y="B", hue="hue") assert len(sc.figure.axes) == 2 class TestDatetimePlot(PlotTestCase): @pytest.fixture(autouse=True) def setUp(self): - ''' + """ Create a DataArray with a time-axis that contains datetime objects. - ''' + """ month = np.arange(1, 13, 1) data = np.sin(2 * np.pi * month / 12.0) - darray = DataArray(data, dims=['time']) - darray.coords['time'] = np.array([datetime(2017, m, 1) for m in month]) + darray = DataArray(data, dims=["time"]) + darray.coords["time"] = np.array([datetime(2017, m, 1) for m in month]) self.darray = darray @@ -1964,18 +1991,15 @@ def test_datetime_line_plot(self): class TestCFDatetimePlot(PlotTestCase): @pytest.fixture(autouse=True) def setUp(self): - ''' + """ Create a DataArray with a time-axis that contains cftime.datetime objects. - ''' + """ # case for 1d array data = np.random.rand(4, 12) - time = xr.cftime_range(start='2017', - periods=12, - freq='1M', - calendar='noleap') - darray = DataArray(data, dims=['x', 'time']) - darray.coords['time'] = time + time = xr.cftime_range(start="2017", periods=12, freq="1M", calendar="noleap") + darray = DataArray(data, dims=["x", "time"]) + darray.coords["time"] = time self.darray = darray @@ -1990,27 +2014,25 @@ def test_cfdatetime_contour_plot(self): @requires_cftime -@pytest.mark.skipif(has_nc_time_axis, reason='nc_time_axis is installed') +@pytest.mark.skipif(has_nc_time_axis, reason="nc_time_axis is installed") class TestNcAxisNotInstalled(PlotTestCase): @pytest.fixture(autouse=True) def setUp(self): - ''' + """ Create a DataArray with a time-axis that contains cftime.datetime objects. 
- ''' + """ month = np.arange(1, 13, 1) data = np.sin(2 * np.pi * month / 12.0) - darray = DataArray(data, dims=['time']) - darray.coords['time'] = xr.cftime_range(start='2017', - periods=12, - freq='1M', - calendar='noleap') + darray = DataArray(data, dims=["time"]) + darray.coords["time"] = xr.cftime_range( + start="2017", periods=12, freq="1M", calendar="noleap" + ) self.darray = darray def test_ncaxis_notinstalled_line_plot(self): - with raises_regex(ImportError, - 'optional `nc-time-axis`'): + with raises_regex(ImportError, "optional `nc-time-axis`"): self.darray.plot.line() @@ -2026,68 +2048,71 @@ def test_import_seaborn_no_warning(): def test_plot_seaborn_no_import_warning(): # GH1633 with pytest.warns(None) as record: - _color_palette('Blues', 4) + _color_palette("Blues", 4) assert len(record) == 0 -test_da_list = [DataArray(easy_array((10, ))), - DataArray(easy_array((10, 3))), - DataArray(easy_array((10, 3, 2)))] +test_da_list = [ + DataArray(easy_array((10,))), + DataArray(easy_array((10, 3))), + DataArray(easy_array((10, 3, 2))), +] @requires_matplotlib class TestAxesKwargs: - @pytest.mark.parametrize('da', test_da_list) - @pytest.mark.parametrize('xincrease', [True, False]) + @pytest.mark.parametrize("da", test_da_list) + @pytest.mark.parametrize("xincrease", [True, False]) def test_xincrease_kwarg(self, da, xincrease): plt.clf() da.plot(xincrease=xincrease) assert plt.gca().xaxis_inverted() == (not xincrease) - @pytest.mark.parametrize('da', test_da_list) - @pytest.mark.parametrize('yincrease', [True, False]) + @pytest.mark.parametrize("da", test_da_list) + @pytest.mark.parametrize("yincrease", [True, False]) def test_yincrease_kwarg(self, da, yincrease): plt.clf() da.plot(yincrease=yincrease) assert plt.gca().yaxis_inverted() == (not yincrease) - @pytest.mark.parametrize('da', test_da_list) - @pytest.mark.parametrize('xscale', ['linear', 'log', 'logit', 'symlog']) + @pytest.mark.parametrize("da", test_da_list) + @pytest.mark.parametrize("xscale", ["linear", "log", "logit", "symlog"]) def test_xscale_kwarg(self, da, xscale): plt.clf() da.plot(xscale=xscale) assert plt.gca().get_xscale() == xscale - @pytest.mark.parametrize('da', [DataArray(easy_array((10, ))), - DataArray(easy_array((10, 3)))]) - @pytest.mark.parametrize('yscale', ['linear', 'log', 'logit', 'symlog']) + @pytest.mark.parametrize( + "da", [DataArray(easy_array((10,))), DataArray(easy_array((10, 3)))] + ) + @pytest.mark.parametrize("yscale", ["linear", "log", "logit", "symlog"]) def test_yscale_kwarg(self, da, yscale): plt.clf() da.plot(yscale=yscale) assert plt.gca().get_yscale() == yscale - @pytest.mark.parametrize('da', test_da_list) + @pytest.mark.parametrize("da", test_da_list) def test_xlim_kwarg(self, da): plt.clf() expected = (0.0, 1000.0) da.plot(xlim=[0, 1000]) assert plt.gca().get_xlim() == expected - @pytest.mark.parametrize('da', test_da_list) + @pytest.mark.parametrize("da", test_da_list) def test_ylim_kwarg(self, da): plt.clf() da.plot(ylim=[0, 1000]) expected = (0.0, 1000.0) assert plt.gca().get_ylim() == expected - @pytest.mark.parametrize('da', test_da_list) + @pytest.mark.parametrize("da", test_da_list) def test_xticks_kwarg(self, da): plt.clf() da.plot(xticks=np.arange(5)) expected = np.arange(5).tolist() assert np.all(plt.gca().get_xticks() == expected) - @pytest.mark.parametrize('da', test_da_list) + @pytest.mark.parametrize("da", test_da_list) def test_yticks_kwarg(self, da): plt.clf() da.plot(yticks=np.arange(5)) diff --git a/xarray/tests/test_print_versions.py 
b/xarray/tests/test_print_versions.py index b1755ac289c..01c30e5e301 100644 --- a/xarray/tests/test_print_versions.py +++ b/xarray/tests/test_print_versions.py @@ -6,4 +6,4 @@ def test_show_versions(): f = io.StringIO() xarray.show_versions(file=f) - assert 'INSTALLED VERSIONS' in f.getvalue() + assert "INSTALLED VERSIONS" in f.getvalue() diff --git a/xarray/tests/test_sparse.py b/xarray/tests/test_sparse.py index 329952bc064..4014d8a66e6 100644 --- a/xarray/tests/test_sparse.py +++ b/xarray/tests/test_sparse.py @@ -13,8 +13,13 @@ import xarray.ufuncs as xu from . import ( - assert_allclose, assert_array_equal, assert_equal, assert_frame_equal, - assert_identical, raises_regex) + assert_allclose, + assert_array_equal, + assert_equal, + assert_frame_equal, + assert_identical, + raises_regex, +) import pytest @@ -22,10 +27,11 @@ xfail = pytest.mark.xfail if not IS_NEP18_ACTIVE: - pytest.skip("NUMPY_EXPERIMENTAL_ARRAY_FUNCTION is not enabled", - allow_module_level=True) + pytest.skip( + "NUMPY_EXPERIMENTAL_ARRAY_FUNCTION is not enabled", allow_module_level=True + ) -sparse = pytest.importorskip('sparse') +sparse = pytest.importorskip("sparse") from sparse.utils import assert_eq as assert_sparse_eq # noqa from sparse import COO, SparseArray # noqa @@ -40,18 +46,19 @@ def make_sparray(shape): def make_xrvar(dim_lengths): return xr.Variable( - tuple(dim_lengths.keys()), - make_sparray(shape=tuple(dim_lengths.values()))) + tuple(dim_lengths.keys()), make_sparray(shape=tuple(dim_lengths.values())) + ) -def make_xrarray(dim_lengths, coords=None, name='test'): +def make_xrarray(dim_lengths, coords=None, name="test"): if coords is None: coords = {d: np.arange(n) for d, n in dim_lengths.items()} return xr.DataArray( make_sparray(shape=tuple(dim_lengths.values())), dims=tuple(coords.keys()), coords=coords, - name=name) + name=name, + ) class do: @@ -64,109 +71,169 @@ def __call__(self, obj): return getattr(obj, self.meth)(*self.args, **self.kwargs) def __repr__(self): - return 'obj.{}(*{}, **{})'.format(self.meth, self.args, self.kwargs) - - -@pytest.mark.parametrize("prop", [ - 'chunks', - 'data', - 'dims', - 'dtype', - 'encoding', - 'imag', - 'nbytes', - 'ndim', - param('values', marks=xfail(reason='Coercion to dense')) -]) + return "obj.{}(*{}, **{})".format(self.meth, self.args, self.kwargs) + + +@pytest.mark.parametrize( + "prop", + [ + "chunks", + "data", + "dims", + "dtype", + "encoding", + "imag", + "nbytes", + "ndim", + param("values", marks=xfail(reason="Coercion to dense")), + ], +) def test_variable_property(prop): - var = make_xrvar({'x': 10, 'y': 5}) + var = make_xrvar({"x": 10, "y": 5}) getattr(var, prop) -@pytest.mark.parametrize("func,sparse_output", [ - (do('all'), False), - (do('any'), False), - (do('astype', dtype=int), True), - (do('broadcast_equals', make_xrvar({'x': 10, 'y': 5})), False), - (do('clip', min=0, max=1), True), - (do('coarsen', windows={'x': 2}, func=np.sum), True), - (do('compute'), True), - (do('conj'), True), - (do('copy'), True), - (do('count'), False), - (do('equals', make_xrvar({'x': 10, 'y': 5})), False), - (do('get_axis_num', dim='x'), False), - (do('identical', other=make_xrvar({'x': 10, 'y': 5})), False), - (do('isel', x=slice(2, 4)), True), - (do('isnull'), True), - (do('load'), True), - (do('mean'), False), - (do('notnull'), True), - (do('roll'), True), - (do('round'), True), - (do('set_dims', dims=('x', 'y', 'z')), True), - (do('stack', dimensions={'flat': ('x', 'y')}), True), - (do('to_base_variable'), True), - (do('transpose'), True), - (do('unstack', 
dimensions={'x': {'x1': 5, 'x2': 2}}), True), - - param(do('argmax'), True, - marks=xfail(reason='Missing implementation for np.argmin')), - param(do('argmin'), True, - marks=xfail(reason='Missing implementation for np.argmax')), - param(do('argsort'), True, - marks=xfail(reason="'COO' object has no attribute 'argsort'")), - param(do('chunk', chunks=(5, 5)), True, - marks=xfail), - param(do('concat', variables=[make_xrvar({'x': 10, 'y': 5}), - make_xrvar({'x': 10, 'y': 5})]), True, - marks=xfail(reason='Coercion to dense')), - param(do('conjugate'), True, - marks=xfail(reason="'COO' object has no attribute 'conjugate'")), - param(do('cumprod'), True, - marks=xfail(reason='Missing implementation for np.nancumprod')), - param(do('cumsum'), True, - marks=xfail(reason='Missing implementation for np.nancumsum')), - param(do('fillna', 0), True, - marks=xfail(reason='Missing implementation for np.result_type')), - param(do('item', (1, 1)), False, - marks=xfail(reason="'COO' object has no attribute 'item'")), - param(do('max'), False, - marks=xfail(reason='Coercion to dense via bottleneck')), - param(do('median'), False, - marks=xfail(reason='Coercion to dense via bottleneck')), - param(do('min'), False, - marks=xfail(reason='Coercion to dense via bottleneck')), - param(do('no_conflicts', other=make_xrvar({'x': 10, 'y': 5})), True, - marks=xfail(reason='mixed sparse-dense operation')), - param(do('pad_with_fill_value', pad_widths={'x': (1, 1)}, fill_value=5), True, # noqa - marks=xfail(reason='Missing implementation for np.pad')), - param(do('prod'), False, - marks=xfail(reason='Missing implementation for np.result_type')), - param(do('quantile', q=0.5), True, - marks=xfail(reason='Missing implementation for np.nanpercentile')), - param(do('rank', dim='x'), False, - marks=xfail(reason='Coercion to dense via bottleneck')), - param(do('reduce', func=np.sum, dim='x'), True, - marks=xfail(reason='Coercion to dense')), - param(do('rolling_window', dim='x', window=2, window_dim='x_win'), True, - marks=xfail(reason='Missing implementation for np.pad')), - param(do('shift', x=2), True, - marks=xfail(reason='mixed sparse-dense operation')), - param(do('std'), False, - marks=xfail(reason='Coercion to dense via bottleneck')), - param(do('sum'), False, - marks=xfail(reason='Missing implementation for np.result_type')), - param(do('var'), False, - marks=xfail(reason='Coercion to dense via bottleneck')), - param(do('to_dict'), False, - marks=xfail(reason='Coercion to dense')), - param(do('where', cond=make_xrvar({'x': 10, 'y': 5}) > 0.5), True, - marks=xfail(reason='Coercion of dense to sparse when using sparse mask')), # noqa -], -ids=repr) +@pytest.mark.parametrize( + "func,sparse_output", + [ + (do("all"), False), + (do("any"), False), + (do("astype", dtype=int), True), + (do("broadcast_equals", make_xrvar({"x": 10, "y": 5})), False), + (do("clip", min=0, max=1), True), + (do("coarsen", windows={"x": 2}, func=np.sum), True), + (do("compute"), True), + (do("conj"), True), + (do("copy"), True), + (do("count"), False), + (do("equals", make_xrvar({"x": 10, "y": 5})), False), + (do("get_axis_num", dim="x"), False), + (do("identical", other=make_xrvar({"x": 10, "y": 5})), False), + (do("isel", x=slice(2, 4)), True), + (do("isnull"), True), + (do("load"), True), + (do("mean"), False), + (do("notnull"), True), + (do("roll"), True), + (do("round"), True), + (do("set_dims", dims=("x", "y", "z")), True), + (do("stack", dimensions={"flat": ("x", "y")}), True), + (do("to_base_variable"), True), + (do("transpose"), True), 
+ (do("unstack", dimensions={"x": {"x1": 5, "x2": 2}}), True), + param( + do("argmax"), + True, + marks=xfail(reason="Missing implementation for np.argmin"), + ), + param( + do("argmin"), + True, + marks=xfail(reason="Missing implementation for np.argmax"), + ), + param( + do("argsort"), + True, + marks=xfail(reason="'COO' object has no attribute 'argsort'"), + ), + param(do("chunk", chunks=(5, 5)), True, marks=xfail), + param( + do( + "concat", + variables=[ + make_xrvar({"x": 10, "y": 5}), + make_xrvar({"x": 10, "y": 5}), + ], + ), + True, + marks=xfail(reason="Coercion to dense"), + ), + param( + do("conjugate"), + True, + marks=xfail(reason="'COO' object has no attribute 'conjugate'"), + ), + param( + do("cumprod"), + True, + marks=xfail(reason="Missing implementation for np.nancumprod"), + ), + param( + do("cumsum"), + True, + marks=xfail(reason="Missing implementation for np.nancumsum"), + ), + param( + do("fillna", 0), + True, + marks=xfail(reason="Missing implementation for np.result_type"), + ), + param( + do("item", (1, 1)), + False, + marks=xfail(reason="'COO' object has no attribute 'item'"), + ), + param(do("max"), False, marks=xfail(reason="Coercion to dense via bottleneck")), + param( + do("median"), False, marks=xfail(reason="Coercion to dense via bottleneck") + ), + param(do("min"), False, marks=xfail(reason="Coercion to dense via bottleneck")), + param( + do("no_conflicts", other=make_xrvar({"x": 10, "y": 5})), + True, + marks=xfail(reason="mixed sparse-dense operation"), + ), + param( + do("pad_with_fill_value", pad_widths={"x": (1, 1)}, fill_value=5), + True, # noqa + marks=xfail(reason="Missing implementation for np.pad"), + ), + param( + do("prod"), + False, + marks=xfail(reason="Missing implementation for np.result_type"), + ), + param( + do("quantile", q=0.5), + True, + marks=xfail(reason="Missing implementation for np.nanpercentile"), + ), + param( + do("rank", dim="x"), + False, + marks=xfail(reason="Coercion to dense via bottleneck"), + ), + param( + do("reduce", func=np.sum, dim="x"), + True, + marks=xfail(reason="Coercion to dense"), + ), + param( + do("rolling_window", dim="x", window=2, window_dim="x_win"), + True, + marks=xfail(reason="Missing implementation for np.pad"), + ), + param( + do("shift", x=2), True, marks=xfail(reason="mixed sparse-dense operation") + ), + param(do("std"), False, marks=xfail(reason="Coercion to dense via bottleneck")), + param( + do("sum"), + False, + marks=xfail(reason="Missing implementation for np.result_type"), + ), + param(do("var"), False, marks=xfail(reason="Coercion to dense via bottleneck")), + param(do("to_dict"), False, marks=xfail(reason="Coercion to dense")), + param( + do("where", cond=make_xrvar({"x": 10, "y": 5}) > 0.5), + True, + marks=xfail(reason="Coercion of dense to sparse when using sparse mask"), + ), # noqa + ], + ids=repr, +) def test_variable_method(func, sparse_output): - var_s = make_xrvar({'x': 10, 'y': 5}) + var_s = make_xrvar({"x": 10, "y": 5}) var_d = xr.Variable(var_s.dims, var_s.data.todense()) ret_s = func(var_s) ret_d = func(var_d) @@ -178,18 +245,21 @@ def test_variable_method(func, sparse_output): assert np.allclose(ret_s, ret_d, equal_nan=True) -@pytest.mark.parametrize("func,sparse_output", [ - (do('squeeze'), True), - - param(do('to_index'), False, - marks=xfail(reason='Coercion to dense')), - param(do('to_index_variable'), False, - marks=xfail(reason='Coercion to dense')), - param(do('searchsorted', 0.5), True, - marks=xfail(reason="'COO' object has no attribute 'searchsorted'")), -]) 
+@pytest.mark.parametrize( + "func,sparse_output", + [ + (do("squeeze"), True), + param(do("to_index"), False, marks=xfail(reason="Coercion to dense")), + param(do("to_index_variable"), False, marks=xfail(reason="Coercion to dense")), + param( + do("searchsorted", 0.5), + True, + marks=xfail(reason="'COO' object has no attribute 'searchsorted'"), + ), + ], +) def test_1d_variable_method(func, sparse_output): - var_s = make_xrvar({'x': 10}) + var_s = make_xrvar({"x": 10}) var_d = xr.Variable(var_s.dims, var_s.data.todense()) ret_s = func(var_s) ret_d = func(var_d) @@ -205,7 +275,7 @@ class TestSparseVariable: @pytest.fixture(autouse=True) def setUp(self): self.data = sparse.random((4, 6), random_state=0, density=0.5) - self.var = xr.Variable(('x', 'y'), self.data) + self.var = xr.Variable(("x", "y"), self.data) def test_unary_op(self): assert_sparse_eq(-self.var.data, -self.data) @@ -216,15 +286,15 @@ def test_univariate_ufunc(self): assert_sparse_eq(np.sin(self.data), xu.sin(self.var).data) def test_bivariate_ufunc(self): - assert_sparse_eq(np.maximum(self.data, 0), - xu.maximum(self.var, 0).data) - assert_sparse_eq(np.maximum(self.data, 0), - xu.maximum(0, self.var).data) + assert_sparse_eq(np.maximum(self.data, 0), xu.maximum(self.var, 0).data) + assert_sparse_eq(np.maximum(self.data, 0), xu.maximum(0, self.var).data) def test_repr(self): - expected = dedent("""\ + expected = dedent( + """\ - """) + """ + ) assert expected == repr(self.var) def test_pickle(self): @@ -236,167 +306,251 @@ def test_pickle(self): def test_missing_values(self): a = np.array([0, 1, np.nan, 3]) s = COO.from_numpy(a) - var_s = Variable('x', s) + var_s = Variable("x", s) assert np.all(var_s.fillna(2).data.todense() == np.arange(4)) assert np.all(var_s.count() == 3) -@pytest.mark.parametrize("prop", [ - 'attrs', - 'chunks', - 'coords', - 'data', - 'dims', - 'dtype', - 'encoding', - 'imag', - 'indexes', - 'loc', - 'name', - 'nbytes', - 'ndim', - 'plot', - 'real', - 'shape', - 'size', - 'sizes', - 'str', - 'variable', -]) +@pytest.mark.parametrize( + "prop", + [ + "attrs", + "chunks", + "coords", + "data", + "dims", + "dtype", + "encoding", + "imag", + "indexes", + "loc", + "name", + "nbytes", + "ndim", + "plot", + "real", + "shape", + "size", + "sizes", + "str", + "variable", + ], +) def test_dataarray_property(prop): - arr = make_xrarray({'x': 10, 'y': 5}) + arr = make_xrarray({"x": 10, "y": 5}) getattr(arr, prop) -@pytest.mark.parametrize("func,sparse_output", [ - (do('all'), False), - (do('any'), False), - (do('assign_attrs', {'foo': 'bar'}), True), - (do('assign_coords', x=make_xrarray({'x': 10}).x + 1), True), - (do('astype', int), True), - (do('broadcast_equals', make_xrarray({'x': 10, 'y': 5})), False), - (do('clip', min=0, max=1), True), - (do('compute'), True), - (do('conj'), True), - (do('copy'), True), - (do('count'), False), - (do('diff', 'x'), True), - (do('drop', 'x'), True), - (do('equals', make_xrarray({'x': 10, 'y': 5})), False), - (do('expand_dims', {'z': 2}, axis=2), True), - (do('get_axis_num', 'x'), False), - (do('get_index', 'x'), False), - (do('identical', make_xrarray({'x': 5, 'y': 5})), False), - (do('integrate', 'x'), True), - (do('isel', {'x': slice(0, 3), 'y': slice(2, 4)}), True), - (do('isnull'), True), - (do('load'), True), - (do('mean'), False), - (do('persist'), True), - (do('reindex', {'x': [1, 2, 3]}), True), - (do('rename', 'foo'), True), - (do('reorder_levels'), True), - (do('reset_coords', drop=True), True), - (do('reset_index', 'x'), True), - (do('round'), True), - (do('sel', 
x=[0, 1, 2]), True), - (do('shift'), True), - (do('sortby', 'x', ascending=False), True), - (do('stack', z={'x', 'y'}), True), - (do('transpose'), True), - - # TODO - # isel_points - # sel_points - # set_index - # swap_dims - - param(do('argmax'), True, - marks=xfail(reason='Missing implementation for np.argmax')), - param(do('argmin'), True, - marks=xfail(reason='Missing implementation for np.argmin')), - param(do('argsort'), True, - marks=xfail(reason="'COO' object has no attribute 'argsort'")), - param(do('bfill', dim='x'), False, - marks=xfail(reason='Missing implementation for np.flip')), - param(do('chunk', chunks=(5, 5)), False, - marks=xfail(reason='Coercion to dense')), - param(do('combine_first', make_xrarray({'x': 10, 'y': 5})), True, - marks=xfail(reason='mixed sparse-dense operation')), - param(do('conjugate'), False, - marks=xfail(reason="'COO' object has no attribute 'conjugate'")), - param(do('cumprod'), True, - marks=xfail(reason='Missing implementation for np.nancumprod')), - param(do('cumsum'), True, - marks=xfail(reason='Missing implementation for np.nancumsum')), - param(do('differentiate', 'x'), False, - marks=xfail(reason='Missing implementation for np.gradient')), - param(do('dot', make_xrarray({'x': 10, 'y': 5})), True, - marks=xfail(reason='Missing implementation for np.einsum')), - param(do('dropna', 'x'), False, - marks=xfail(reason='Coercion to dense')), - param(do('ffill', 'x'), False, - marks=xfail(reason='Coercion to dense via bottleneck.push')), - param(do('fillna', 0), True, - marks=xfail(reason='Missing implementation for np.result_type')), - param(do('interp', coords={'x': np.arange(10) + 0.5}), True, - marks=xfail(reason='Coercion to dense')), - param(do('interp_like', - make_xrarray({'x': 10, 'y': 5}, - coords={'x': np.arange(10) + 0.5, - 'y': np.arange(5) + 0.5})), True, - marks=xfail(reason='Indexing COO with more than one iterable index')), # noqa - param(do('interpolate_na', 'x'), True, - marks=xfail(reason='Coercion to dense')), - param(do('isin', [1, 2, 3]), False, - marks=xfail(reason='Missing implementation for np.isin')), - param(do('item', (1, 1)), False, - marks=xfail(reason="'COO' object has no attribute 'item'")), - param(do('max'), False, - marks=xfail(reason='Coercion to dense via bottleneck')), - param(do('median'), False, - marks=xfail(reason='Coercion to dense via bottleneck')), - param(do('min'), False, - marks=xfail(reason='Coercion to dense via bottleneck')), - param(do('notnull'), False, - marks=xfail(reason="'COO' object has no attribute 'notnull'")), - param(do('pipe', np.sum, axis=1), True, - marks=xfail(reason='Missing implementation for np.result_type')), - param(do('prod'), False, - marks=xfail(reason='Missing implementation for np.result_type')), - param(do('quantile', q=0.5), False, - marks=xfail(reason='Missing implementation for np.nanpercentile')), - param(do('rank', 'x'), False, - marks=xfail(reason='Coercion to dense via bottleneck')), - param(do('reduce', np.sum, dim='x'), False, - marks=xfail(reason='Coercion to dense')), - param(do('reindex_like', - make_xrarray({'x': 10, 'y': 5}, - coords={'x': np.arange(10) + 0.5, - 'y': np.arange(5) + 0.5})), - True, - marks=xfail(reason='Indexing COO with more than one iterable index')), # noqa - param(do('roll', x=2), True, - marks=xfail(reason='Missing implementation for np.result_type')), - param(do('sel', x=[0, 1, 2], y=[2, 3]), True, - marks=xfail(reason='Indexing COO with more than one iterable index')), # noqa - param(do('std'), False, - marks=xfail(reason='Coercion to 
dense via bottleneck')), - param(do('sum'), False, - marks=xfail(reason='Missing implementation for np.result_type')), - param(do('var'), False, - marks=xfail(reason='Coercion to dense via bottleneck')), - param(do('where', make_xrarray({'x': 10, 'y': 5}) > 0.5), False, - marks=xfail(reason='Conversion of dense to sparse when using sparse mask')), # noqa -], -ids=repr) +@pytest.mark.parametrize( + "func,sparse_output", + [ + (do("all"), False), + (do("any"), False), + (do("assign_attrs", {"foo": "bar"}), True), + (do("assign_coords", x=make_xrarray({"x": 10}).x + 1), True), + (do("astype", int), True), + (do("broadcast_equals", make_xrarray({"x": 10, "y": 5})), False), + (do("clip", min=0, max=1), True), + (do("compute"), True), + (do("conj"), True), + (do("copy"), True), + (do("count"), False), + (do("diff", "x"), True), + (do("drop", "x"), True), + (do("equals", make_xrarray({"x": 10, "y": 5})), False), + (do("expand_dims", {"z": 2}, axis=2), True), + (do("get_axis_num", "x"), False), + (do("get_index", "x"), False), + (do("identical", make_xrarray({"x": 5, "y": 5})), False), + (do("integrate", "x"), True), + (do("isel", {"x": slice(0, 3), "y": slice(2, 4)}), True), + (do("isnull"), True), + (do("load"), True), + (do("mean"), False), + (do("persist"), True), + (do("reindex", {"x": [1, 2, 3]}), True), + (do("rename", "foo"), True), + (do("reorder_levels"), True), + (do("reset_coords", drop=True), True), + (do("reset_index", "x"), True), + (do("round"), True), + (do("sel", x=[0, 1, 2]), True), + (do("shift"), True), + (do("sortby", "x", ascending=False), True), + (do("stack", z={"x", "y"}), True), + (do("transpose"), True), + # TODO + # isel_points + # sel_points + # set_index + # swap_dims + param( + do("argmax"), + True, + marks=xfail(reason="Missing implementation for np.argmax"), + ), + param( + do("argmin"), + True, + marks=xfail(reason="Missing implementation for np.argmin"), + ), + param( + do("argsort"), + True, + marks=xfail(reason="'COO' object has no attribute 'argsort'"), + ), + param( + do("bfill", dim="x"), + False, + marks=xfail(reason="Missing implementation for np.flip"), + ), + param( + do("chunk", chunks=(5, 5)), False, marks=xfail(reason="Coercion to dense") + ), + param( + do("combine_first", make_xrarray({"x": 10, "y": 5})), + True, + marks=xfail(reason="mixed sparse-dense operation"), + ), + param( + do("conjugate"), + False, + marks=xfail(reason="'COO' object has no attribute 'conjugate'"), + ), + param( + do("cumprod"), + True, + marks=xfail(reason="Missing implementation for np.nancumprod"), + ), + param( + do("cumsum"), + True, + marks=xfail(reason="Missing implementation for np.nancumsum"), + ), + param( + do("differentiate", "x"), + False, + marks=xfail(reason="Missing implementation for np.gradient"), + ), + param( + do("dot", make_xrarray({"x": 10, "y": 5})), + True, + marks=xfail(reason="Missing implementation for np.einsum"), + ), + param(do("dropna", "x"), False, marks=xfail(reason="Coercion to dense")), + param( + do("ffill", "x"), + False, + marks=xfail(reason="Coercion to dense via bottleneck.push"), + ), + param( + do("fillna", 0), + True, + marks=xfail(reason="Missing implementation for np.result_type"), + ), + param( + do("interp", coords={"x": np.arange(10) + 0.5}), + True, + marks=xfail(reason="Coercion to dense"), + ), + param( + do( + "interp_like", + make_xrarray( + {"x": 10, "y": 5}, + coords={"x": np.arange(10) + 0.5, "y": np.arange(5) + 0.5}, + ), + ), + True, + marks=xfail(reason="Indexing COO with more than one iterable index"), + ), # noqa 
+ param(do("interpolate_na", "x"), True, marks=xfail(reason="Coercion to dense")), + param( + do("isin", [1, 2, 3]), + False, + marks=xfail(reason="Missing implementation for np.isin"), + ), + param( + do("item", (1, 1)), + False, + marks=xfail(reason="'COO' object has no attribute 'item'"), + ), + param(do("max"), False, marks=xfail(reason="Coercion to dense via bottleneck")), + param( + do("median"), False, marks=xfail(reason="Coercion to dense via bottleneck") + ), + param(do("min"), False, marks=xfail(reason="Coercion to dense via bottleneck")), + param( + do("notnull"), + False, + marks=xfail(reason="'COO' object has no attribute 'notnull'"), + ), + param( + do("pipe", np.sum, axis=1), + True, + marks=xfail(reason="Missing implementation for np.result_type"), + ), + param( + do("prod"), + False, + marks=xfail(reason="Missing implementation for np.result_type"), + ), + param( + do("quantile", q=0.5), + False, + marks=xfail(reason="Missing implementation for np.nanpercentile"), + ), + param( + do("rank", "x"), + False, + marks=xfail(reason="Coercion to dense via bottleneck"), + ), + param( + do("reduce", np.sum, dim="x"), + False, + marks=xfail(reason="Coercion to dense"), + ), + param( + do( + "reindex_like", + make_xrarray( + {"x": 10, "y": 5}, + coords={"x": np.arange(10) + 0.5, "y": np.arange(5) + 0.5}, + ), + ), + True, + marks=xfail(reason="Indexing COO with more than one iterable index"), + ), # noqa + param( + do("roll", x=2), + True, + marks=xfail(reason="Missing implementation for np.result_type"), + ), + param( + do("sel", x=[0, 1, 2], y=[2, 3]), + True, + marks=xfail(reason="Indexing COO with more than one iterable index"), + ), # noqa + param(do("std"), False, marks=xfail(reason="Coercion to dense via bottleneck")), + param( + do("sum"), + False, + marks=xfail(reason="Missing implementation for np.result_type"), + ), + param(do("var"), False, marks=xfail(reason="Coercion to dense via bottleneck")), + param( + do("where", make_xrarray({"x": 10, "y": 5}) > 0.5), + False, + marks=xfail(reason="Conversion of dense to sparse when using sparse mask"), + ), # noqa + ], + ids=repr, +) def test_dataarray_method(func, sparse_output): - arr_s = make_xrarray({'x': 10, 'y': 5}, - coords={'x': np.arange(10), 'y': np.arange(5)}) - arr_d = xr.DataArray( - arr_s.data.todense(), - coords=arr_s.coords, - dims=arr_s.dims) + arr_s = make_xrarray( + {"x": 10, "y": 5}, coords={"x": np.arange(10), "y": np.arange(5)} + ) + arr_d = xr.DataArray(arr_s.data.todense(), coords=arr_s.coords, dims=arr_s.dims) ret_s = func(arr_s) ret_d = func(arr_d) @@ -407,17 +561,20 @@ def test_dataarray_method(func, sparse_output): assert np.allclose(ret_s, ret_d, equal_nan=True) -@pytest.mark.parametrize("func,sparse_output", [ - (do('squeeze'), True), - param(do('searchsorted', [1, 2, 3]), False, - marks=xfail(reason="'COO' object has no attribute 'searchsorted'")), -]) +@pytest.mark.parametrize( + "func,sparse_output", + [ + (do("squeeze"), True), + param( + do("searchsorted", [1, 2, 3]), + False, + marks=xfail(reason="'COO' object has no attribute 'searchsorted'"), + ), + ], +) def test_datarray_1d_method(func, sparse_output): - arr_s = make_xrarray({'x': 10}, coords={'x': np.arange(10)}) - arr_d = xr.DataArray( - arr_s.data.todense(), - coords=arr_s.coords, - dims=arr_s.dims) + arr_s = make_xrarray({"x": 10}, coords={"x": np.arange(10)}) + arr_d = xr.DataArray(arr_s.data.todense(), coords=arr_s.coords, dims=arr_s.dims) ret_s = func(arr_s) ret_d = func(arr_d) @@ -432,97 +589,100 @@ class 
TestSparseDataArrayAndDataset: @pytest.fixture(autouse=True) def setUp(self): self.sp_ar = sparse.random((4, 6), random_state=0, density=0.5) - self.sp_xr = xr.DataArray(self.sp_ar, coords={'x': range(4)}, - dims=('x', 'y'), name='foo') + self.sp_xr = xr.DataArray( + self.sp_ar, coords={"x": range(4)}, dims=("x", "y"), name="foo" + ) self.ds_ar = self.sp_ar.todense() - self.ds_xr = xr.DataArray(self.ds_ar, coords={'x': range(4)}, - dims=('x', 'y'), name='foo') + self.ds_xr = xr.DataArray( + self.ds_ar, coords={"x": range(4)}, dims=("x", "y"), name="foo" + ) - @pytest.mark.xfail(reason='Missing implementation for np.result_type') + @pytest.mark.xfail(reason="Missing implementation for np.result_type") def test_to_dataset_roundtrip(self): x = self.sp_xr - assert_equal(x, x.to_dataset('x').to_array('x')) + assert_equal(x, x.to_dataset("x").to_array("x")) def test_align(self): a1 = xr.DataArray( - COO.from_numpy(np.arange(4)), - dims=['x'], - coords={'x': ['a', 'b', 'c', 'd']}) + COO.from_numpy(np.arange(4)), dims=["x"], coords={"x": ["a", "b", "c", "d"]} + ) b1 = xr.DataArray( - COO.from_numpy(np.arange(4)), - dims=['x'], - coords={'x': ['a', 'b', 'd', 'e']}) - a2, b2 = xr.align(a1, b1, join='inner') + COO.from_numpy(np.arange(4)), dims=["x"], coords={"x": ["a", "b", "d", "e"]} + ) + a2, b2 = xr.align(a1, b1, join="inner") assert isinstance(a2.data, sparse.SparseArray) assert isinstance(b2.data, sparse.SparseArray) - assert np.all(a2.coords['x'].data == ['a', 'b', 'd']) - assert np.all(b2.coords['x'].data == ['a', 'b', 'd']) + assert np.all(a2.coords["x"].data == ["a", "b", "d"]) + assert np.all(b2.coords["x"].data == ["a", "b", "d"]) @pytest.mark.xfail( reason="COO objects currently do not accept more than one " - "iterable index at a time") + "iterable index at a time" + ) def test_align_2d(self): - A1 = xr.DataArray(self.sp_ar, dims=['x', 'y'], coords={ - 'x': np.arange(self.sp_ar.shape[0]), - 'y': np.arange(self.sp_ar.shape[1]) - }) - - A2 = xr.DataArray(self.sp_ar, dims=['x', 'y'], coords={ - 'x': np.arange(1, self.sp_ar.shape[0] + 1), - 'y': np.arange(1, self.sp_ar.shape[1] + 1) - }) - - B1, B2 = xr.align(A1, A2, join='inner') - assert np.all(B1.coords['x'] == np.arange(1, self.sp_ar.shape[0])) - assert np.all(B1.coords['y'] == np.arange(1, self.sp_ar.shape[0])) - assert np.all(B1.coords['x'] == B2.coords['x']) - assert np.all(B1.coords['y'] == B2.coords['y']) + A1 = xr.DataArray( + self.sp_ar, + dims=["x", "y"], + coords={ + "x": np.arange(self.sp_ar.shape[0]), + "y": np.arange(self.sp_ar.shape[1]), + }, + ) + + A2 = xr.DataArray( + self.sp_ar, + dims=["x", "y"], + coords={ + "x": np.arange(1, self.sp_ar.shape[0] + 1), + "y": np.arange(1, self.sp_ar.shape[1] + 1), + }, + ) + + B1, B2 = xr.align(A1, A2, join="inner") + assert np.all(B1.coords["x"] == np.arange(1, self.sp_ar.shape[0])) + assert np.all(B1.coords["y"] == np.arange(1, self.sp_ar.shape[0])) + assert np.all(B1.coords["x"] == B2.coords["x"]) + assert np.all(B1.coords["y"] == B2.coords["y"]) @pytest.mark.xfail(reason="fill value leads to sparse-dense operation") def test_align_outer(self): a1 = xr.DataArray( - COO.from_numpy(np.arange(4)), - dims=['x'], - coords={'x': ['a', 'b', 'c', 'd']}) + COO.from_numpy(np.arange(4)), dims=["x"], coords={"x": ["a", "b", "c", "d"]} + ) b1 = xr.DataArray( - COO.from_numpy(np.arange(4)), - dims=['x'], - coords={'x': ['a', 'b', 'd', 'e']}) - a2, b2 = xr.align(a1, b1, join='outer') + COO.from_numpy(np.arange(4)), dims=["x"], coords={"x": ["a", "b", "d", "e"]} + ) + a2, b2 = xr.align(a1, b1, 
join="outer") assert isinstance(a2.data, sparse.SparseArray) assert isinstance(b2.data, sparse.SparseArray) - assert np.all(a2.coords['x'].data == ['a', 'b', 'c', 'd']) - assert np.all(b2.coords['x'].data == ['a', 'b', 'c', 'd']) + assert np.all(a2.coords["x"].data == ["a", "b", "c", "d"]) + assert np.all(b2.coords["x"].data == ["a", "b", "c", "d"]) - @pytest.mark.xfail(reason='Missing implementation for np.result_type') + @pytest.mark.xfail(reason="Missing implementation for np.result_type") def test_concat(self): - ds1 = xr.Dataset(data_vars={'d': self.sp_xr}) - ds2 = xr.Dataset(data_vars={'d': self.sp_xr}) - ds3 = xr.Dataset(data_vars={'d': self.sp_xr}) - out = xr.concat([ds1, ds2, ds3], dim='x') + ds1 = xr.Dataset(data_vars={"d": self.sp_xr}) + ds2 = xr.Dataset(data_vars={"d": self.sp_xr}) + ds3 = xr.Dataset(data_vars={"d": self.sp_xr}) + out = xr.concat([ds1, ds2, ds3], dim="x") assert_sparse_eq( - out['d'].data, - sparse.concatenate([self.sp_ar, self.sp_ar, self.sp_ar], axis=0) + out["d"].data, + sparse.concatenate([self.sp_ar, self.sp_ar, self.sp_ar], axis=0), ) - out = xr.concat([self.sp_xr, self.sp_xr, self.sp_xr], dim='y') + out = xr.concat([self.sp_xr, self.sp_xr, self.sp_xr], dim="y") assert_sparse_eq( - out.data, - sparse.concatenate([self.sp_ar, self.sp_ar, self.sp_ar], axis=1) + out.data, sparse.concatenate([self.sp_ar, self.sp_ar, self.sp_ar], axis=1) ) def test_stack(self): - arr = make_xrarray({'w': 2, 'x': 3, 'y': 4}) - stacked = arr.stack(z=('x', 'y')) + arr = make_xrarray({"w": 2, "x": 3, "y": 4}) + stacked = arr.stack(z=("x", "y")) - z = pd.MultiIndex.from_product( - [np.arange(3), np.arange(4)], - names=['x', 'y']) + z = pd.MultiIndex.from_product([np.arange(3), np.arange(4)], names=["x", "y"]) expected = xr.DataArray( - arr.data.reshape((2, -1)), - {'w': [0, 1], 'z': z}, - dims=['w', 'z']) + arr.data.reshape((2, -1)), {"w": [0, 1], "z": z}, dims=["w", "z"] + ) assert_equal(expected, stacked) @@ -536,50 +696,58 @@ def test_ufuncs(self): def test_dataarray_repr(self): a = xr.DataArray( COO.from_numpy(np.ones(4)), - dims=['x'], - coords={'y': ('x', COO.from_numpy(np.arange(4)))}) - expected = dedent("""\ + dims=["x"], + coords={"y": ("x", COO.from_numpy(np.arange(4)))}, + ) + expected = dedent( + """\ Coordinates: y (x) int64 ... - Dimensions without coordinates: x""") + Dimensions without coordinates: x""" + ) assert expected == repr(a) def test_dataset_repr(self): ds = xr.Dataset( - data_vars={'a': ('x', COO.from_numpy(np.ones(4)))}, - coords={'y': ('x', COO.from_numpy(np.arange(4)))}) - expected = dedent("""\ + data_vars={"a": ("x", COO.from_numpy(np.ones(4)))}, + coords={"y": ("x", COO.from_numpy(np.arange(4)))}, + ) + expected = dedent( + """\ Dimensions: (x: 4) Coordinates: y (x) int64 ... 
Dimensions without coordinates: x Data variables: - a (x) float64 ...""") + a (x) float64 ...""" + ) assert expected == repr(ds) def test_dataarray_pickle(self): a1 = xr.DataArray( COO.from_numpy(np.ones(4)), - dims=['x'], - coords={'y': ('x', COO.from_numpy(np.arange(4)))}) + dims=["x"], + coords={"y": ("x", COO.from_numpy(np.arange(4)))}, + ) a2 = pickle.loads(pickle.dumps(a1)) assert_identical(a1, a2) def test_dataset_pickle(self): ds1 = xr.Dataset( - data_vars={'a': ('x', COO.from_numpy(np.ones(4)))}, - coords={'y': ('x', COO.from_numpy(np.arange(4)))}) + data_vars={"a": ("x", COO.from_numpy(np.ones(4)))}, + coords={"y": ("x", COO.from_numpy(np.arange(4)))}, + ) ds2 = pickle.loads(pickle.dumps(ds1)) assert_identical(ds1, ds2) def test_coarsen(self): a1 = self.ds_xr a2 = self.sp_xr - m1 = a1.coarsen(x=2, boundary='trim').mean() - m2 = a2.coarsen(x=2, boundary='trim').mean() + m1 = a1.coarsen(x=2, boundary="trim").mean() + m2 = a2.coarsen(x=2, boundary="trim").mean() assert isinstance(m2.data, sparse.SparseArray) assert np.allclose(m1.data, m2.data.todense()) @@ -614,33 +782,36 @@ def test_dot(self): def test_groupby(self): x1 = self.ds_xr x2 = self.sp_xr - m1 = x1.groupby('x').mean(xr.ALL_DIMS) - m2 = x2.groupby('x').mean(xr.ALL_DIMS) + m1 = x1.groupby("x").mean(xr.ALL_DIMS) + m2 = x2.groupby("x").mean(xr.ALL_DIMS) assert isinstance(m2.data, sparse.SparseArray) assert np.allclose(m1.data, m2.data.todense()) @pytest.mark.xfail(reason="Groupby reductions produce dense output") def test_groupby_first(self): x = self.sp_xr.copy() - x.coords['ab'] = ('x', ['a', 'a', 'b', 'b']) - x.groupby('ab').first() - x.groupby('ab').first(skipna=False) + x.coords["ab"] = ("x", ["a", "a", "b", "b"]) + x.groupby("ab").first() + x.groupby("ab").first(skipna=False) @pytest.mark.xfail(reason="Groupby reductions produce dense output") def test_groupby_bins(self): x1 = self.ds_xr x2 = self.sp_xr - m1 = x1.groupby_bins('x', bins=[0, 3, 7, 10]).sum() - m2 = x2.groupby_bins('x', bins=[0, 3, 7, 10]).sum() + m1 = x1.groupby_bins("x", bins=[0, 3, 7, 10]).sum() + m2 = x2.groupby_bins("x", bins=[0, 3, 7, 10]).sum() assert isinstance(m2.data, sparse.SparseArray) assert np.allclose(m1.data, m2.data.todense()) @pytest.mark.xfail(reason="Resample produces dense output") def test_resample(self): - t1 = xr.DataArray(np.linspace(0, 11, num=12), - coords=[pd.date_range('15/12/1999', - periods=12, freq=pd.DateOffset(months=1))], - dims='time') + t1 = xr.DataArray( + np.linspace(0, 11, num=12), + coords=[ + pd.date_range("15/12/1999", periods=12, freq=pd.DateOffset(months=1)) + ], + dims="time", + ) t2 = t1.copy() t2.data = COO(t2.data) m1 = t1.resample(time="QS-DEC").mean() @@ -652,9 +823,11 @@ def test_resample(self): def test_reindex(self): x1 = self.ds_xr x2 = self.sp_xr - for kwargs in [{'x': [2, 3, 4]}, - {'x': [1, 100, 2, 101, 3]}, - {'x': [2.5, 3, 3.5], 'y': [2, 2.5, 3]}]: + for kwargs in [ + {"x": [2, 3, 4]}, + {"x": [1, 100, 2, 101, 3]}, + {"x": [2.5, 3, 3.5], "y": [2, 2.5, 3]}, + ]: m1 = x1.reindex(**kwargs) m2 = x2.reindex(**kwargs) assert np.allclose(m1, m2, equal_nan=True) @@ -662,7 +835,7 @@ def test_reindex(self): @pytest.mark.xfail def test_merge(self): x = self.sp_xr - y = xr.merge([x, x.rename('bar')]).to_array() + y = xr.merge([x, x.rename("bar")]).to_array() assert isinstance(y, sparse.SparseArray) @pytest.mark.xfail @@ -685,5 +858,6 @@ class TestSparseCoords: def test_sparse_coords(self): xr.DataArray( COO.from_numpy(np.arange(4)), - dims=['x'], - coords={'x': COO.from_numpy([1, 2, 3, 4])}) + dims=["x"], + 
coords={"x": COO.from_numpy([1, 2, 3, 4])}, + ) diff --git a/xarray/tests/test_tutorial.py b/xarray/tests/test_tutorial.py index 841f4f7c832..9bf84c9edb0 100644 --- a/xarray/tests/test_tutorial.py +++ b/xarray/tests/test_tutorial.py @@ -12,17 +12,18 @@ class TestLoadDataset: @pytest.fixture(autouse=True) def setUp(self): - self.testfile = 'tiny' - self.testfilepath = os.path.expanduser(os.sep.join( - ('~', '.xarray_tutorial_data', self.testfile))) + self.testfile = "tiny" + self.testfilepath = os.path.expanduser( + os.sep.join(("~", ".xarray_tutorial_data", self.testfile)) + ) with suppress(OSError): - os.remove('{}.nc'.format(self.testfilepath)) + os.remove("{}.nc".format(self.testfilepath)) with suppress(OSError): - os.remove('{}.md5'.format(self.testfilepath)) + os.remove("{}.md5".format(self.testfilepath)) def test_download_from_github(self): ds = tutorial.open_dataset(self.testfile).load() - tiny = DataArray(range(5), name='tiny').to_dataset() + tiny = DataArray(range(5), name="tiny").to_dataset() assert_identical(ds, tiny) def test_download_from_github_load_without_cache(self): diff --git a/xarray/tests/test_ufuncs.py b/xarray/tests/test_ufuncs.py index 2c60fb3861e..dc8ba22f57c 100644 --- a/xarray/tests/test_ufuncs.py +++ b/xarray/tests/test_ufuncs.py @@ -21,22 +21,26 @@ def assert_identical(a, b): @requires_np113 def test_unary(): - args = [0, - np.zeros(2), - xr.Variable(['x'], [0, 0]), - xr.DataArray([0, 0], dims='x'), - xr.Dataset({'y': ('x', [0, 0])})] + args = [ + 0, + np.zeros(2), + xr.Variable(["x"], [0, 0]), + xr.DataArray([0, 0], dims="x"), + xr.Dataset({"y": ("x", [0, 0])}), + ] for a in args: assert_identical(a + 1, np.cos(a)) @requires_np113 def test_binary(): - args = [0, - np.zeros(2), - xr.Variable(['x'], [0, 0]), - xr.DataArray([0, 0], dims='x'), - xr.Dataset({'y': ('x', [0, 0])})] + args = [ + 0, + np.zeros(2), + xr.Variable(["x"], [0, 0]), + xr.DataArray([0, 0], dims="x"), + xr.Dataset({"y": ("x", [0, 0])}), + ] for n, t1 in enumerate(args): for t2 in args[n:]: assert_identical(t2 + 1, np.maximum(t1, t2 + 1)) @@ -47,11 +51,13 @@ def test_binary(): @requires_np113 def test_binary_out(): - args = [1, - np.ones(2), - xr.Variable(['x'], [1, 1]), - xr.DataArray([1, 1], dims='x'), - xr.Dataset({'y': ('x', [1, 1])})] + args = [ + 1, + np.ones(2), + xr.Variable(["x"], [1, 1]), + xr.DataArray([1, 1], dims="x"), + xr.Dataset({"y": ("x", [1, 1])}), + ] for arg in args: actual_mantissa, actual_exponent = np.frexp(arg) assert_identical(actual_mantissa, 0.5 * arg) @@ -60,10 +66,10 @@ def test_binary_out(): @requires_np113 def test_groupby(): - ds = xr.Dataset({'a': ('x', [0, 0, 0])}, {'c': ('x', [0, 0, 1])}) - ds_grouped = ds.groupby('c') - group_mean = ds_grouped.mean('x') - arr_grouped = ds['a'].groupby('c') + ds = xr.Dataset({"a": ("x", [0, 0, 0])}, {"c": ("x", [0, 0, 1])}) + ds_grouped = ds.groupby("c") + group_mean = ds_grouped.mean("x") + arr_grouped = ds["a"].groupby("c") assert_identical(ds, np.maximum(ds_grouped, group_mean)) assert_identical(ds, np.maximum(group_mean, ds_grouped)) @@ -71,29 +77,30 @@ def test_groupby(): assert_identical(ds, np.maximum(arr_grouped, group_mean)) assert_identical(ds, np.maximum(group_mean, arr_grouped)) - assert_identical(ds, np.maximum(ds_grouped, group_mean['a'])) - assert_identical(ds, np.maximum(group_mean['a'], ds_grouped)) + assert_identical(ds, np.maximum(ds_grouped, group_mean["a"])) + assert_identical(ds, np.maximum(group_mean["a"], ds_grouped)) assert_identical(ds.a, np.maximum(arr_grouped, group_mean.a)) assert_identical(ds.a, 
np.maximum(group_mean.a, arr_grouped)) - with raises_regex(ValueError, 'mismatched lengths for dimension'): + with raises_regex(ValueError, "mismatched lengths for dimension"): np.maximum(ds.a.variable, ds_grouped) @requires_np113 def test_alignment(): - ds1 = xr.Dataset({'a': ('x', [1, 2])}, {'x': [0, 1]}) - ds2 = xr.Dataset({'a': ('x', [2, 3]), 'b': 4}, {'x': [1, 2]}) + ds1 = xr.Dataset({"a": ("x", [1, 2])}, {"x": [0, 1]}) + ds2 = xr.Dataset({"a": ("x", [2, 3]), "b": 4}, {"x": [1, 2]}) actual = np.add(ds1, ds2) - expected = xr.Dataset({'a': ('x', [4])}, {'x': [1]}) + expected = xr.Dataset({"a": ("x", [4])}, {"x": [1]}) assert_identical_(actual, expected) - with xr.set_options(arithmetic_join='outer'): + with xr.set_options(arithmetic_join="outer"): actual = np.add(ds1, ds2) - expected = xr.Dataset({'a': ('x', [np.nan, 4, np.nan]), 'b': np.nan}, - coords={'x': [0, 1, 2]}) + expected = xr.Dataset( + {"a": ("x", [np.nan, 4, np.nan]), "b": np.nan}, coords={"x": [0, 1, 2]} + ) assert_identical_(actual, expected) @@ -106,21 +113,20 @@ def test_kwargs(): @requires_np113 def test_xarray_defers_to_unrecognized_type(): - class Other: def __array_ufunc__(self, *args, **kwargs): - return 'other' + return "other" xarray_obj = xr.DataArray([1, 2, 3]) other = Other() - assert np.maximum(xarray_obj, other) == 'other' - assert np.sin(xarray_obj, out=other) == 'other' + assert np.maximum(xarray_obj, other) == "other" + assert np.sin(xarray_obj, out=other) == "other" @requires_np113 def test_xarray_handles_dask(): - da = pytest.importorskip('dask.array') - x = xr.DataArray(np.ones((2, 2)), dims=['x', 'y']) + da = pytest.importorskip("dask.array") + x = xr.DataArray(np.ones((2, 2)), dims=["x", "y"]) y = da.ones((2, 2), chunks=(2, 2)) result = np.add(x, y) assert result.chunks == ((2,), (2,)) @@ -129,8 +135,8 @@ def test_xarray_handles_dask(): @requires_np113 def test_dask_defers_to_xarray(): - da = pytest.importorskip('dask.array') - x = xr.DataArray(np.ones((2, 2)), dims=['x', 'y']) + da = pytest.importorskip("dask.array") + x = xr.DataArray(np.ones((2, 2)), dims=["x", "y"]) y = da.ones((2, 2), chunks=(2, 2)) result = np.add(y, x) assert result.chunks == ((2,), (2,)) @@ -140,7 +146,7 @@ def test_dask_defers_to_xarray(): @requires_np113 def test_gufunc_methods(): xarray_obj = xr.DataArray([1, 2, 3]) - with raises_regex(NotImplementedError, 'reduce method'): + with raises_regex(NotImplementedError, "reduce method"): np.add.reduce(xarray_obj, 1) @@ -149,7 +155,7 @@ def test_out(): xarray_obj = xr.DataArray([1, 2, 3]) # xarray out arguments should raise - with raises_regex(NotImplementedError, '`out` argument'): + with raises_regex(NotImplementedError, "`out` argument"): np.add(xarray_obj, 1, out=xarray_obj) # but non-xarray should be OK @@ -161,45 +167,50 @@ def test_out(): @requires_np113 def test_gufuncs(): xarray_obj = xr.DataArray([1, 2, 3]) - fake_gufunc = mock.Mock(signature='(n)->()', autospec=np.sin) - with raises_regex(NotImplementedError, 'generalized ufuncs'): - xarray_obj.__array_ufunc__(fake_gufunc, '__call__', xarray_obj) + fake_gufunc = mock.Mock(signature="(n)->()", autospec=np.sin) + with raises_regex(NotImplementedError, "generalized ufuncs"): + xarray_obj.__array_ufunc__(fake_gufunc, "__call__", xarray_obj) def test_xarray_ufuncs_deprecation(): - with pytest.warns(PendingDeprecationWarning, match='xarray.ufuncs'): + with pytest.warns(PendingDeprecationWarning, match="xarray.ufuncs"): xu.cos(xr.DataArray([0, 1])) with pytest.warns(None) as record: xu.angle(xr.DataArray([0, 1])) - record = 
[el.message for el in record - if el.category == PendingDeprecationWarning] + record = [el.message for el in record if el.category == PendingDeprecationWarning] assert len(record) == 0 @requires_np113 -@pytest.mark.filterwarnings('ignore::RuntimeWarning') +@pytest.mark.filterwarnings("ignore::RuntimeWarning") @pytest.mark.parametrize( - 'name', - [name for name in dir(xu) - if (not name.startswith('_') and hasattr(np, name) - and name not in ['print_function', 'absolute_import', 'division'])] + "name", + [ + name + for name in dir(xu) + if ( + not name.startswith("_") + and hasattr(np, name) + and name not in ["print_function", "absolute_import", "division"] + ) + ], ) def test_numpy_ufuncs(name, request): x = xr.DataArray([1, 1]) np_func = getattr(np, name) - if hasattr(np_func, 'nin') and np_func.nin == 2: + if hasattr(np_func, "nin") and np_func.nin == 2: args = (x, x) else: args = (x,) y = np_func(*args) - if name in ['angle', 'iscomplex']: + if name in ["angle", "iscomplex"]: # these functions need to be handled with __array_function__ protocol assert isinstance(y, np.ndarray) - elif name in ['frexp']: + elif name in ["frexp"]: # np.frexp returns a tuple assert not isinstance(y, xr.DataArray) else: diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index ce4c5cc8198..254983364f9 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -10,8 +10,7 @@ from xarray.core import duck_array_ops, utils from xarray.core.utils import either_dict_or_kwargs -from . import ( - assert_array_equal, has_cftime, has_cftime_or_netCDF4, requires_dask) +from . import assert_array_equal, has_cftime, has_cftime_or_netCDF4, requires_dask from .test_coding_times import _all_cftime_date_types @@ -19,28 +18,29 @@ class TestAlias: def test(self): def new_method(): pass - old_method = utils.alias(new_method, 'old_method') - assert 'deprecated' in old_method.__doc__ - with pytest.warns(Warning, match='deprecated'): + + old_method = utils.alias(new_method, "old_method") + assert "deprecated" in old_method.__doc__ + with pytest.warns(Warning, match="deprecated"): old_method() def test_safe_cast_to_index(): - dates = pd.date_range('2000-01-01', periods=10) + dates = pd.date_range("2000-01-01", periods=10) x = np.arange(5) - td = x * np.timedelta64(1, 'D') + td = x * np.timedelta64(1, "D") for expected, array in [ - (dates, dates.values), - (pd.Index(x, dtype=object), x.astype(object)), - (pd.Index(td), td), - (pd.Index(td, dtype=object), td.astype(object)), + (dates, dates.values), + (pd.Index(x, dtype=object), x.astype(object)), + (pd.Index(td), td), + (pd.Index(td, dtype=object), td.astype(object)), ]: actual = utils.safe_cast_to_index(array) assert_array_equal(expected, actual) assert expected.dtype == actual.dtype -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime not installed") def test_safe_cast_to_index_cftimeindex(): date_types = _all_cftime_date_types() for date_type in date_types.values(): @@ -58,7 +58,7 @@ def test_safe_cast_to_index_cftimeindex(): # Test that datetime.datetime objects are never used in a CFTimeIndex -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason="cftime not installed") def test_safe_cast_to_index_datetime_datetime(): dates = [datetime(1, 1, day) for day in range(1, 20)] @@ -70,26 +70,30 @@ def test_safe_cast_to_index_datetime_datetime(): def 
test_multiindex_from_product_levels(): result = utils.multiindex_from_product_levels( - [pd.Index(['b', 'a']), pd.Index([1, 3, 2])]) + [pd.Index(["b", "a"]), pd.Index([1, 3, 2])] + ) np.testing.assert_array_equal( # compat for pandas < 0.24 - result.codes if hasattr(result, 'codes') else result.labels, - [[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]]) - np.testing.assert_array_equal(result.levels[0], ['b', 'a']) + result.codes if hasattr(result, "codes") else result.labels, + [[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]], + ) + np.testing.assert_array_equal(result.levels[0], ["b", "a"]) np.testing.assert_array_equal(result.levels[1], [1, 3, 2]) - other = pd.MultiIndex.from_product([['b', 'a'], [1, 3, 2]]) + other = pd.MultiIndex.from_product([["b", "a"], [1, 3, 2]]) np.testing.assert_array_equal(result.values, other.values) def test_multiindex_from_product_levels_non_unique(): result = utils.multiindex_from_product_levels( - [pd.Index(['b', 'a']), pd.Index([1, 1, 2])]) + [pd.Index(["b", "a"]), pd.Index([1, 1, 2])] + ) np.testing.assert_array_equal( # compat for pandas < 0.24 - result.codes if hasattr(result, 'codes') else result.labels, - [[0, 0, 0, 1, 1, 1], [0, 0, 1, 0, 0, 1]]) - np.testing.assert_array_equal(result.levels[0], ['b', 'a']) + result.codes if hasattr(result, "codes") else result.labels, + [[0, 0, 0, 1, 1, 1], [0, 0, 1, 0, 0, 1]], + ) + np.testing.assert_array_equal(result.levels[0], ["b", "a"]) np.testing.assert_array_equal(result.levels[1], [1, 2]) @@ -98,17 +102,16 @@ def test_0d(self): # verify our work around for pd.isnull not working for 0-dimensional # object arrays assert duck_array_ops.array_equiv(0, np.array(0, dtype=object)) - assert duck_array_ops.array_equiv(np.nan, - np.array(np.nan, dtype=object)) + assert duck_array_ops.array_equiv(np.nan, np.array(np.nan, dtype=object)) assert not duck_array_ops.array_equiv(0, np.array(1, dtype=object)) class TestDictionaries: @pytest.fixture(autouse=True) def setup(self): - self.x = {'a': 'A', 'b': 'B'} - self.y = {'c': 'C', 'b': 'B'} - self.z = {'a': 'Z'} + self.x = {"a": "A", "b": "B"} + self.y = {"c": "C", "b": "B"} + self.z = {"a": "Z"} def test_equivalent(self): assert utils.equivalent(0, 0) @@ -128,67 +131,67 @@ def test_unsafe(self): utils.update_safety_check(self.x, self.z) def test_ordered_dict_intersection(self): - assert {'b': 'B'} == \ - utils.ordered_dict_intersection(self.x, self.y) + assert {"b": "B"} == utils.ordered_dict_intersection(self.x, self.y) assert {} == utils.ordered_dict_intersection(self.x, self.z) def test_dict_equiv(self): x = OrderedDict() - x['a'] = 3 - x['b'] = np.array([1, 2, 3]) + x["a"] = 3 + x["b"] = np.array([1, 2, 3]) y = OrderedDict() - y['b'] = np.array([1.0, 2.0, 3.0]) - y['a'] = 3 + y["b"] = np.array([1.0, 2.0, 3.0]) + y["a"] = 3 assert utils.dict_equiv(x, y) # two nparrays are equal - y['b'] = [1, 2, 3] # np.array not the same as a list + y["b"] = [1, 2, 3] # np.array not the same as a list assert utils.dict_equiv(x, y) # nparray == list - x['b'] = [1.0, 2.0, 3.0] + x["b"] = [1.0, 2.0, 3.0] assert utils.dict_equiv(x, y) # list vs. 
list - x['c'] = None + x["c"] = None assert not utils.dict_equiv(x, y) # new key in x - x['c'] = np.nan - y['c'] = np.nan + x["c"] = np.nan + y["c"] = np.nan assert utils.dict_equiv(x, y) # as intended, nan is nan - x['c'] = np.inf - y['c'] = np.inf + x["c"] = np.inf + y["c"] = np.inf assert utils.dict_equiv(x, y) # inf == inf y = dict(y) assert utils.dict_equiv(x, y) # different dictionary types are fine - y['b'] = 3 * np.arange(3) + y["b"] = 3 * np.arange(3) assert not utils.dict_equiv(x, y) # not equal when arrays differ def test_frozen(self): x = utils.Frozen(self.x) with pytest.raises(TypeError): - x['foo'] = 'bar' + x["foo"] = "bar" with pytest.raises(TypeError): - del x['a'] + del x["a"] with pytest.raises(AttributeError): x.update(self.y) assert x.mapping == self.x - assert repr(x) in ("Frozen({'a': 'A', 'b': 'B'})", - "Frozen({'b': 'B', 'a': 'A'})") + assert repr(x) in ( + "Frozen({'a': 'A', 'b': 'B'})", + "Frozen({'b': 'B', 'a': 'A'})", + ) def test_sorted_keys_dict(self): - x = {'a': 1, 'b': 2, 'c': 3} + x = {"a": 1, "b": 2, "c": 3} y = utils.SortedKeysDict(x) - assert list(y) == ['a', 'b', 'c'] - assert repr(utils.SortedKeysDict()) == \ - "SortedKeysDict({})" + assert list(y) == ["a", "b", "c"] + assert repr(utils.SortedKeysDict()) == "SortedKeysDict({})" def test_repr_object(): - obj = utils.ReprObject('foo') - assert repr(obj) == 'foo' + obj = utils.ReprObject("foo") + assert repr(obj) == "foo" assert isinstance(obj, Hashable) assert not isinstance(obj, str) def test_repr_object_magic_methods(): - o1 = utils.ReprObject('foo') - o2 = utils.ReprObject('foo') - o3 = utils.ReprObject('bar') - o4 = 'foo' + o1 = utils.ReprObject("foo") + o2 = utils.ReprObject("foo") + o3 = utils.ReprObject("bar") + o4 = "foo" assert o1 == o2 assert o1 != o3 assert o1 != o4 @@ -198,23 +201,22 @@ def test_repr_object_magic_methods(): def test_is_remote_uri(): - assert utils.is_remote_uri('http://example.com') - assert utils.is_remote_uri('https://example.com') - assert not utils.is_remote_uri(' http://example.com') - assert not utils.is_remote_uri('example.nc') + assert utils.is_remote_uri("http://example.com") + assert utils.is_remote_uri("https://example.com") + assert not utils.is_remote_uri(" http://example.com") + assert not utils.is_remote_uri("example.nc") def test_is_grib_path(): - assert not utils.is_grib_path('example.nc') - assert not utils.is_grib_path('example.grib ') - assert utils.is_grib_path('example.grib') - assert utils.is_grib_path('example.grib2') - assert utils.is_grib_path('example.grb') - assert utils.is_grib_path('example.grb2') + assert not utils.is_grib_path("example.nc") + assert not utils.is_grib_path("example.grib ") + assert utils.is_grib_path("example.grib") + assert utils.is_grib_path("example.grib2") + assert utils.is_grib_path("example.grb") + assert utils.is_grib_path("example.grb2") class Test_is_uniform_and_sorted: - def test_sorted_uniform(self): assert utils.is_uniform_spaced(np.arange(5)) @@ -235,11 +237,10 @@ def test_relative_tolerance(self): class Test_hashable: - def test_hashable(self): - for v in [False, 1, (2, ), (3, 4), 'four']: + for v in [False, 1, (2,), (3, 4), "four"]: assert utils.hashable(v) - for v in [[5, 6], ['seven', '8'], {9: 'ten'}]: + for v in [[5, 6], ["seven", "8"], {9: "ten"}]: assert not utils.hashable(v) @@ -253,9 +254,9 @@ def test_dask_array_is_scalar(): def test_hidden_key_dict(): - hidden_key = '_hidden_key' - data = {'a': 1, 'b': 2, hidden_key: 3} - data_expected = {'a': 1, 'b': 2} + hidden_key = "_hidden_key" + data = {"a": 1, 
"b": 2, hidden_key: 3} + data_expected = {"a": 1, "b": 2} hkd = utils.HiddenKeyDict(data, [hidden_key]) assert len(hkd) == 2 assert hidden_key not in hkd @@ -269,13 +270,13 @@ def test_hidden_key_dict(): def test_either_dict_or_kwargs(): - result = either_dict_or_kwargs(dict(a=1), None, 'foo') + result = either_dict_or_kwargs(dict(a=1), None, "foo") expected = dict(a=1) assert result == expected - result = either_dict_or_kwargs(None, dict(a=1), 'foo') + result = either_dict_or_kwargs(None, dict(a=1), "foo") expected = dict(a=1) assert result == expected - with pytest.raises(ValueError, match=r'foo'): - result = either_dict_or_kwargs(dict(a=1), dict(a=1), 'foo') + with pytest.raises(ValueError, match=r"foo"): + result = either_dict_or_kwargs(dict(a=1), dict(a=1), "foo") diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 3978d1c43c3..43551d62265 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -14,63 +14,76 @@ from xarray.core import dtypes, indexing from xarray.core.common import full_like, ones_like, zeros_like from xarray.core.indexing import ( - BasicIndexer, CopyOnWriteArray, DaskIndexingAdapter, - LazilyOuterIndexedArray, MemoryCachedArray, NumpyIndexingAdapter, - OuterIndexer, PandasIndexAdapter, VectorizedIndexer) + BasicIndexer, + CopyOnWriteArray, + DaskIndexingAdapter, + LazilyOuterIndexedArray, + MemoryCachedArray, + NumpyIndexingAdapter, + OuterIndexer, + PandasIndexAdapter, + VectorizedIndexer, +) from xarray.core.utils import NDArrayMixin from xarray.core.variable import as_compatible_data, as_variable from xarray.tests import requires_bottleneck from . import ( - assert_allclose, assert_array_equal, assert_equal, assert_identical, - raises_regex, requires_dask, source_ndarray) + assert_allclose, + assert_array_equal, + assert_equal, + assert_identical, + raises_regex, + requires_dask, + source_ndarray, +) class VariableSubclassobjects: def test_properties(self): data = 0.5 * np.arange(10) - v = self.cls(['time'], data, {'foo': 'bar'}) - assert v.dims == ('time',) + v = self.cls(["time"], data, {"foo": "bar"}) + assert v.dims == ("time",) assert_array_equal(v.values, data) assert v.dtype == float assert v.shape == (10,) assert v.size == 10 - assert v.sizes == {'time': 10} + assert v.sizes == {"time": 10} assert v.nbytes == 80 assert v.ndim == 1 assert len(v) == 10 - assert v.attrs == {'foo': 'bar'} + assert v.attrs == {"foo": "bar"} def test_attrs(self): - v = self.cls(['time'], 0.5 * np.arange(10)) + v = self.cls(["time"], 0.5 * np.arange(10)) assert v.attrs == {} - attrs = {'foo': 'bar'} + attrs = {"foo": "bar"} v.attrs = attrs assert v.attrs == attrs assert isinstance(v.attrs, OrderedDict) - v.attrs['foo'] = 'baz' - assert v.attrs['foo'] == 'baz' + v.attrs["foo"] = "baz" + assert v.attrs["foo"] == "baz" def test_getitem_dict(self): - v = self.cls(['x'], np.random.randn(5)) - actual = v[{'x': 0}] + v = self.cls(["x"], np.random.randn(5)) + actual = v[{"x": 0}] expected = v[0] assert_identical(expected, actual) def test_getitem_1d(self): data = np.array([0, 1, 2]) - v = self.cls(['x'], data) + v = self.cls(["x"], data) v_new = v[dict(x=[0, 1])] - assert v_new.dims == ('x', ) + assert v_new.dims == ("x",) assert_array_equal(v_new, data[[0, 1]]) v_new = v[dict(x=slice(None))] - assert v_new.dims == ('x', ) + assert v_new.dims == ("x",) assert_array_equal(v_new, data) - v_new = v[dict(x=Variable('a', [0, 1]))] - assert v_new.dims == ('a', ) + v_new = v[dict(x=Variable("a", [0, 1]))] + assert v_new.dims == ("a",) 
assert_array_equal(v_new, data[[0, 1]]) v_new = v[dict(x=1)] @@ -79,52 +92,55 @@ def test_getitem_1d(self): # tuple argument v_new = v[slice(None)] - assert v_new.dims == ('x', ) + assert v_new.dims == ("x",) assert_array_equal(v_new, data) def test_getitem_1d_fancy(self): - v = self.cls(['x'], [0, 1, 2]) + v = self.cls(["x"], [0, 1, 2]) # 1d-variable should be indexable by multi-dimensional Variable - ind = Variable(('a', 'b'), [[0, 1], [0, 1]]) + ind = Variable(("a", "b"), [[0, 1], [0, 1]]) v_new = v[ind] - assert v_new.dims == ('a', 'b') - expected = np.array(v._data)[([0, 1], [0, 1]), ] + assert v_new.dims == ("a", "b") + expected = np.array(v._data)[([0, 1], [0, 1]),] # noqa assert_array_equal(v_new, expected) # boolean indexing - ind = Variable(('x', ), [True, False, True]) + ind = Variable(("x",), [True, False, True]) v_new = v[ind] assert_identical(v[[0, 2]], v_new) v_new = v[[True, False, True]] assert_identical(v[[0, 2]], v_new) with raises_regex(IndexError, "Boolean indexer should"): - ind = Variable(('a', ), [True, False, True]) + ind = Variable(("a",), [True, False, True]) v[ind] def test_getitem_with_mask(self): - v = self.cls(['x'], [0, 1, 2]) + v = self.cls(["x"], [0, 1, 2]) assert_identical(v._getitem_with_mask(-1), Variable((), np.nan)) - assert_identical(v._getitem_with_mask([0, -1, 1]), - self.cls(['x'], [0, np.nan, 1])) - assert_identical(v._getitem_with_mask(slice(2)), - self.cls(['x'], [0, 1])) - assert_identical(v._getitem_with_mask([0, -1, 1], fill_value=-99), - self.cls(['x'], [0, -99, 1])) + assert_identical( + v._getitem_with_mask([0, -1, 1]), self.cls(["x"], [0, np.nan, 1]) + ) + assert_identical(v._getitem_with_mask(slice(2)), self.cls(["x"], [0, 1])) + assert_identical( + v._getitem_with_mask([0, -1, 1], fill_value=-99), + self.cls(["x"], [0, -99, 1]), + ) def test_getitem_with_mask_size_zero(self): - v = self.cls(['x'], []) + v = self.cls(["x"], []) assert_identical(v._getitem_with_mask(-1), Variable((), np.nan)) - assert_identical(v._getitem_with_mask([-1, -1, -1]), - self.cls(['x'], [np.nan, np.nan, np.nan])) + assert_identical( + v._getitem_with_mask([-1, -1, -1]), + self.cls(["x"], [np.nan, np.nan, np.nan]), + ) def test_getitem_with_mask_nd_indexer(self): - v = self.cls(['x'], [0, 1, 2]) - indexer = Variable(('x', 'y'), [[0, -1], [-1, 2]]) + v = self.cls(["x"], [0, 1, 2]) + indexer = Variable(("x", "y"), [[0, -1], [-1, 2]]) assert_identical(v._getitem_with_mask(indexer, fill_value=-1), indexer) - def _assertIndexedLikeNDArray(self, variable, expected_value0, - expected_dtype=None): + def _assertIndexedLikeNDArray(self, variable, expected_value0, expected_dtype=None): """Given a 1-dimensional variable, verify that the variable is indexed like a numpy.ndarray. 
""" @@ -136,7 +152,7 @@ def _assertIndexedLikeNDArray(self, variable, expected_value0, assert variable.identical(variable.copy()) # check value is equal for both ndarray and Variable with warnings.catch_warnings(): - warnings.filterwarnings('ignore', "In the future, 'NAT == x'") + warnings.filterwarnings("ignore", "In the future, 'NAT == x'") np.testing.assert_equal(variable.values[0], expected_value0) np.testing.assert_equal(variable[0].values, expected_value0) # check type or dtype is consistent for both ndarray and Variable @@ -149,52 +165,47 @@ def _assertIndexedLikeNDArray(self, variable, expected_value0, assert variable[0].values.dtype == expected_dtype def test_index_0d_int(self): - for value, dtype in [(0, np.int_), - (np.int32(0), np.int32)]: - x = self.cls(['x'], [value]) + for value, dtype in [(0, np.int_), (np.int32(0), np.int32)]: + x = self.cls(["x"], [value]) self._assertIndexedLikeNDArray(x, value, dtype) def test_index_0d_float(self): - for value, dtype in [(0.5, np.float_), - (np.float32(0.5), np.float32)]: - x = self.cls(['x'], [value]) + for value, dtype in [(0.5, np.float_), (np.float32(0.5), np.float32)]: + x = self.cls(["x"], [value]) self._assertIndexedLikeNDArray(x, value, dtype) def test_index_0d_string(self): - value = 'foo' - dtype = np.dtype('U3') - x = self.cls(['x'], [value]) + value = "foo" + dtype = np.dtype("U3") + x = self.cls(["x"], [value]) self._assertIndexedLikeNDArray(x, value, dtype) def test_index_0d_datetime(self): d = datetime(2000, 1, 1) - x = self.cls(['x'], [d]) + x = self.cls(["x"], [d]) self._assertIndexedLikeNDArray(x, np.datetime64(d)) - x = self.cls(['x'], [np.datetime64(d)]) - self._assertIndexedLikeNDArray(x, np.datetime64(d), 'datetime64[ns]') + x = self.cls(["x"], [np.datetime64(d)]) + self._assertIndexedLikeNDArray(x, np.datetime64(d), "datetime64[ns]") - x = self.cls(['x'], pd.DatetimeIndex([d])) - self._assertIndexedLikeNDArray(x, np.datetime64(d), 'datetime64[ns]') + x = self.cls(["x"], pd.DatetimeIndex([d])) + self._assertIndexedLikeNDArray(x, np.datetime64(d), "datetime64[ns]") def test_index_0d_timedelta64(self): td = timedelta(hours=1) - x = self.cls(['x'], [np.timedelta64(td)]) - self._assertIndexedLikeNDArray( - x, np.timedelta64(td), 'timedelta64[ns]') + x = self.cls(["x"], [np.timedelta64(td)]) + self._assertIndexedLikeNDArray(x, np.timedelta64(td), "timedelta64[ns]") - x = self.cls(['x'], pd.to_timedelta([td])) - self._assertIndexedLikeNDArray( - x, np.timedelta64(td), 'timedelta64[ns]') + x = self.cls(["x"], pd.to_timedelta([td])) + self._assertIndexedLikeNDArray(x, np.timedelta64(td), "timedelta64[ns]") def test_index_0d_not_a_time(self): - d = np.datetime64('NaT', 'ns') - x = self.cls(['x'], [d]) + d = np.datetime64("NaT", "ns") + x = self.cls(["x"], [d]) self._assertIndexedLikeNDArray(x, d) def test_index_0d_object(self): - class HashableItemWrapper: def __init__(self, item): self.item = item @@ -206,51 +217,52 @@ def __hash__(self): return hash(self.item) def __repr__(self): - return '%s(item=%r)' % (type(self).__name__, self.item) + return "%s(item=%r)" % (type(self).__name__, self.item) item = HashableItemWrapper((1, 2, 3)) - x = self.cls('x', [item]) + x = self.cls("x", [item]) self._assertIndexedLikeNDArray(x, item, expected_dtype=False) def test_0d_object_array_with_list(self): listarray = np.empty((1,), dtype=object) listarray[0] = [1, 2, 3] - x = self.cls('x', listarray) + x = self.cls("x", listarray) assert_array_equal(x.data, listarray) assert_array_equal(x[0].data, listarray.squeeze()) 
assert_array_equal(x.squeeze().data, listarray.squeeze()) def test_index_and_concat_datetime(self): # regression test for #125 - date_range = pd.date_range('2011-09-01', periods=10) - for dates in [date_range, date_range.values, - date_range.to_pydatetime()]: - expected = self.cls('t', dates) - for times in [[expected[i] for i in range(10)], - [expected[i:(i + 1)] for i in range(10)], - [expected[[i]] for i in range(10)]]: - actual = Variable.concat(times, 't') + date_range = pd.date_range("2011-09-01", periods=10) + for dates in [date_range, date_range.values, date_range.to_pydatetime()]: + expected = self.cls("t", dates) + for times in [ + [expected[i] for i in range(10)], + [expected[i : (i + 1)] for i in range(10)], + [expected[[i]] for i in range(10)], + ]: + actual = Variable.concat(times, "t") assert expected.dtype == actual.dtype assert_array_equal(expected, actual) def test_0d_time_data(self): # regression test for #105 - x = self.cls('time', pd.date_range('2000-01-01', periods=5)) - expected = np.datetime64('2000-01-01', 'ns') + x = self.cls("time", pd.date_range("2000-01-01", periods=5)) + expected = np.datetime64("2000-01-01", "ns") assert x[0].values == expected def test_datetime64_conversion(self): - times = pd.date_range('2000-01-01', periods=3) + times = pd.date_range("2000-01-01", periods=3) for values, preserve_source in [ (times, True), (times.values, True), - (times.values.astype('datetime64[s]'), False), + (times.values.astype("datetime64[s]"), False), (times.to_pydatetime(), False), ]: - v = self.cls(['t'], values) - assert v.dtype == np.dtype('datetime64[ns]') + v = self.cls(["t"], values) + assert v.dtype == np.dtype("datetime64[ns]") assert_array_equal(v.values, times.values) - assert v.values.dtype == np.dtype('datetime64[ns]') + assert v.values.dtype == np.dtype("datetime64[ns]") same_source = source_ndarray(v.values) is source_ndarray(values) assert preserve_source == same_source @@ -259,32 +271,31 @@ def test_timedelta64_conversion(self): for values, preserve_source in [ (times, True), (times.values, True), - (times.values.astype('timedelta64[s]'), False), + (times.values.astype("timedelta64[s]"), False), (times.to_pytimedelta(), False), ]: - v = self.cls(['t'], values) - assert v.dtype == np.dtype('timedelta64[ns]') + v = self.cls(["t"], values) + assert v.dtype == np.dtype("timedelta64[ns]") assert_array_equal(v.values, times.values) - assert v.values.dtype == np.dtype('timedelta64[ns]') + assert v.values.dtype == np.dtype("timedelta64[ns]") same_source = source_ndarray(v.values) is source_ndarray(values) assert preserve_source == same_source def test_object_conversion(self): data = np.arange(5).astype(str).astype(object) - actual = self.cls('x', data) + actual = self.cls("x", data) assert actual.dtype == data.dtype def test_pandas_data(self): - v = self.cls(['x'], pd.Series([0, 1, 2], index=[3, 2, 1])) + v = self.cls(["x"], pd.Series([0, 1, 2], index=[3, 2, 1])) assert_identical(v, v[[0, 1, 2]]) - v = self.cls(['x'], pd.Index([0, 1, 2])) + v = self.cls(["x"], pd.Index([0, 1, 2])) assert v[0].values == v.values[0] def test_pandas_period_index(self): - v = self.cls(['x'], pd.period_range(start='2000', periods=20, - freq='B')) + v = self.cls(["x"], pd.period_range(start="2000", periods=20, freq="B")) v = v.load() # for dask-based Variable - assert v[0] == pd.Period('2000', freq='B') + assert v[0] == pd.Period("2000", freq="B") assert "Period('2000-01-03', 'B')" in repr(v) def test_1d_math(self): @@ -293,7 +304,7 @@ def test_1d_math(self): # should we need 
`.to_base_variable()`? # probably a break that `+v` changes type? - v = self.cls(['x'], x) + v = self.cls(["x"], x) base_v = v.to_base_variable() # unary ops assert_identical(base_v, +v) @@ -309,12 +320,12 @@ def test_1d_math(self): assert_array_equal(v - y, v - 1) assert_array_equal(y - v, 1 - v) # verify attributes are dropped - v2 = self.cls(['x'], x, {'units': 'meters'}) + v2 = self.cls(["x"], x, {"units": "meters"}) assert_identical(base_v, +v2) # binary ops with all variables assert_array_equal(v + v, 2 * v) - w = self.cls(['x'], y, {'foo': 'bar'}) - assert_identical(v + w, self.cls(['x'], x + y).to_base_variable()) + w = self.cls(["x"], y, {"foo": "bar"}) + assert_identical(v + w, self.cls(["x"], x + y).to_base_variable()) assert_array_equal((v * w).values, x * y) # something complicated @@ -332,7 +343,7 @@ def test_1d_math(self): def test_1d_reduce(self): x = np.arange(5) - v = self.cls(['x'], x) + v = self.cls(["x"], x) actual = v.sum() expected = Variable((), 10) assert_identical(expected, actual) @@ -340,27 +351,27 @@ def test_1d_reduce(self): def test_array_interface(self): x = np.arange(5) - v = self.cls(['x'], x) + v = self.cls(["x"], x) assert_array_equal(np.asarray(v), x) # test patched in methods assert_array_equal(v.astype(float), x.astype(float)) # think this is a break, that argsort changes the type assert_identical(v.argsort(), v.to_base_variable()) - assert_identical(v.clip(2, 3), - self.cls('x', x.clip(2, 3)).to_base_variable()) + assert_identical(v.clip(2, 3), self.cls("x", x.clip(2, 3)).to_base_variable()) # test ufuncs - assert_identical(np.sin(v), - self.cls(['x'], np.sin(x)).to_base_variable()) + assert_identical(np.sin(v), self.cls(["x"], np.sin(x)).to_base_variable()) assert isinstance(np.sin(v), Variable) assert not isinstance(np.sin(v), IndexVariable) def example_1d_objects(self): - for data in [range(3), - 0.5 * np.arange(3), - 0.5 * np.arange(3, dtype=np.float32), - pd.date_range('2000-01-01', periods=3), - np.array(['a', 'b', 'c'], dtype=object)]: - yield (self.cls('x', data), data) + for data in [ + range(3), + 0.5 * np.arange(3), + 0.5 * np.arange(3, dtype=np.float32), + pd.date_range("2000-01-01", periods=3), + np.array(["a", "b", "c"], dtype=object), + ]: + yield (self.cls("x", data), data) def test___array__(self): for v, data in self.example_1d_objects(): @@ -385,118 +396,120 @@ def test_equals_all_dtypes(self): def test_eq_all_dtypes(self): # ensure that we don't choke on comparisons for which numpy returns # scalars - expected = Variable('x', 3 * [False]) + expected = Variable("x", 3 * [False]) for v, _ in self.example_1d_objects(): - actual = 'z' == v + actual = "z" == v assert_identical(expected, actual) - actual = ~('z' != v) + actual = ~("z" != v) assert_identical(expected, actual) def test_encoding_preserved(self): - expected = self.cls('x', range(3), {'foo': 1}, {'bar': 2}) - for actual in [expected.T, - expected[...], - expected.squeeze(), - expected.isel(x=slice(None)), - expected.set_dims({'x': 3}), - expected.copy(deep=True), - expected.copy(deep=False)]: - - assert_identical(expected.to_base_variable(), - actual.to_base_variable()) + expected = self.cls("x", range(3), {"foo": 1}, {"bar": 2}) + for actual in [ + expected.T, + expected[...], + expected.squeeze(), + expected.isel(x=slice(None)), + expected.set_dims({"x": 3}), + expected.copy(deep=True), + expected.copy(deep=False), + ]: + + assert_identical(expected.to_base_variable(), actual.to_base_variable()) assert expected.encoding == actual.encoding def test_concat(self): x = 
np.arange(5) y = np.arange(5, 10) - v = self.cls(['a'], x) - w = self.cls(['a'], y) - assert_identical(Variable(['b', 'a'], np.array([x, y])), - Variable.concat([v, w], 'b')) - assert_identical(Variable(['b', 'a'], np.array([x, y])), - Variable.concat((v, w), 'b')) - assert_identical(Variable(['b', 'a'], np.array([x, y])), - Variable.concat((v, w), 'b')) - with raises_regex(ValueError, 'inconsistent dimensions'): - Variable.concat([v, Variable(['c'], y)], 'b') + v = self.cls(["a"], x) + w = self.cls(["a"], y) + assert_identical( + Variable(["b", "a"], np.array([x, y])), Variable.concat([v, w], "b") + ) + assert_identical( + Variable(["b", "a"], np.array([x, y])), Variable.concat((v, w), "b") + ) + assert_identical( + Variable(["b", "a"], np.array([x, y])), Variable.concat((v, w), "b") + ) + with raises_regex(ValueError, "inconsistent dimensions"): + Variable.concat([v, Variable(["c"], y)], "b") # test indexers actual = Variable.concat( - [v, w], - positions=[np.arange(0, 10, 2), np.arange(1, 10, 2)], - dim='a') - expected = Variable('a', np.array([x, y]).ravel(order='F')) + [v, w], positions=[np.arange(0, 10, 2), np.arange(1, 10, 2)], dim="a" + ) + expected = Variable("a", np.array([x, y]).ravel(order="F")) assert_identical(expected, actual) # test concatenating along a dimension - v = Variable(['time', 'x'], np.random.random((10, 8))) - assert_identical(v, Variable.concat([v[:5], v[5:]], 'time')) - assert_identical(v, Variable.concat([v[:5], v[5:6], v[6:]], 'time')) - assert_identical(v, Variable.concat([v[:1], v[1:]], 'time')) + v = Variable(["time", "x"], np.random.random((10, 8))) + assert_identical(v, Variable.concat([v[:5], v[5:]], "time")) + assert_identical(v, Variable.concat([v[:5], v[5:6], v[6:]], "time")) + assert_identical(v, Variable.concat([v[:1], v[1:]], "time")) # test dimension order - assert_identical(v, Variable.concat([v[:, :5], v[:, 5:]], 'x')) - with raises_regex(ValueError, 'all input arrays must have'): - Variable.concat([v[:, 0], v[:, 1:]], 'x') + assert_identical(v, Variable.concat([v[:, :5], v[:, 5:]], "x")) + with raises_regex(ValueError, "all input arrays must have"): + Variable.concat([v[:, 0], v[:, 1:]], "x") def test_concat_attrs(self): # different or conflicting attributes should be removed - v = self.cls('a', np.arange(5), {'foo': 'bar'}) - w = self.cls('a', np.ones(5)) + v = self.cls("a", np.arange(5), {"foo": "bar"}) + w = self.cls("a", np.ones(5)) expected = self.cls( - 'a', np.concatenate([np.arange(5), np.ones(5)])).to_base_variable() - assert_identical(expected, Variable.concat([v, w], 'a')) - w.attrs['foo'] = 2 - assert_identical(expected, Variable.concat([v, w], 'a')) - w.attrs['foo'] = 'bar' - expected.attrs['foo'] = 'bar' - assert_identical(expected, Variable.concat([v, w], 'a')) + "a", np.concatenate([np.arange(5), np.ones(5)]) + ).to_base_variable() + assert_identical(expected, Variable.concat([v, w], "a")) + w.attrs["foo"] = 2 + assert_identical(expected, Variable.concat([v, w], "a")) + w.attrs["foo"] = "bar" + expected.attrs["foo"] = "bar" + assert_identical(expected, Variable.concat([v, w], "a")) def test_concat_fixed_len_str(self): # regression test for #217 - for kind in ['S', 'U']: - x = self.cls('animal', np.array(['horse'], dtype=kind)) - y = self.cls('animal', np.array(['aardvark'], dtype=kind)) - actual = Variable.concat([x, y], 'animal') - expected = Variable( - 'animal', np.array(['horse', 'aardvark'], dtype=kind)) + for kind in ["S", "U"]: + x = self.cls("animal", np.array(["horse"], dtype=kind)) + y = self.cls("animal", 
np.array(["aardvark"], dtype=kind)) + actual = Variable.concat([x, y], "animal") + expected = Variable("animal", np.array(["horse", "aardvark"], dtype=kind)) assert_equal(expected, actual) def test_concat_number_strings(self): # regression test for #305 - a = self.cls('x', ['0', '1', '2']) - b = self.cls('x', ['3', '4']) - actual = Variable.concat([a, b], dim='x') - expected = Variable('x', np.arange(5).astype(str)) + a = self.cls("x", ["0", "1", "2"]) + b = self.cls("x", ["3", "4"]) + actual = Variable.concat([a, b], dim="x") + expected = Variable("x", np.arange(5).astype(str)) assert_identical(expected, actual) assert actual.dtype.kind == expected.dtype.kind def test_concat_mixed_dtypes(self): - a = self.cls('x', [0, 1]) - b = self.cls('x', ['two']) - actual = Variable.concat([a, b], dim='x') - expected = Variable('x', np.array([0, 1, 'two'], dtype=object)) + a = self.cls("x", [0, 1]) + b = self.cls("x", ["two"]) + actual = Variable.concat([a, b], dim="x") + expected = Variable("x", np.array([0, 1, "two"], dtype=object)) assert_identical(expected, actual) assert actual.dtype == object - @pytest.mark.parametrize('deep', [True, False]) - @pytest.mark.parametrize('astype', [float, int, str]) + @pytest.mark.parametrize("deep", [True, False]) + @pytest.mark.parametrize("astype", [float, int, str]) def test_copy(self, deep, astype): - v = self.cls('x', (0.5 * np.arange(10)).astype(astype), {'foo': 'bar'}) + v = self.cls("x", (0.5 * np.arange(10)).astype(astype), {"foo": "bar"}) w = v.copy(deep=deep) assert type(v) is type(w) assert_identical(v, w) assert v.dtype == w.dtype if self.cls is Variable: if deep: - assert (source_ndarray(v.values) is not - source_ndarray(w.values)) + assert source_ndarray(v.values) is not source_ndarray(w.values) else: - assert (source_ndarray(v.values) is - source_ndarray(w.values)) + assert source_ndarray(v.values) is source_ndarray(w.values) assert_identical(v, copy(v)) def test_copy_index(self): - midx = pd.MultiIndex.from_product([['a', 'b'], [1, 2], [-1, -2]], - names=('one', 'two', 'three')) - v = self.cls('x', midx) + midx = pd.MultiIndex.from_product( + [["a", "b"], [1, 2], [-1, -2]], names=("one", "two", "three") + ) + v = self.cls("x", midx) for deep in [True, False]: w = v.copy(deep=deep) assert isinstance(w._data, PandasIndexAdapter) @@ -504,7 +517,7 @@ def test_copy_index(self): assert_array_equal(v._data.array, w._data.array) def test_copy_with_data(self): - orig = Variable(('x', 'y'), [[1.5, 2.0], [3.1, 4.3]], {'foo': 'bar'}) + orig = Variable(("x", "y"), [[1.5, 2.0], [3.1, 4.3]], {"foo": "bar"}) new_data = np.array([[2.5, 5.0], [7.1, 43]]) actual = orig.copy(data=new_data) expected = orig.copy() @@ -512,13 +525,13 @@ def test_copy_with_data(self): assert_identical(expected, actual) def test_copy_with_data_errors(self): - orig = Variable(('x', 'y'), [[1.5, 2.0], [3.1, 4.3]], {'foo': 'bar'}) + orig = Variable(("x", "y"), [[1.5, 2.0], [3.1, 4.3]], {"foo": "bar"}) new_data = [2.5, 5.0] - with raises_regex(ValueError, 'must match shape of object'): + with raises_regex(ValueError, "must match shape of object"): orig.copy(data=new_data) def test_copy_index_with_data(self): - orig = IndexVariable('x', np.arange(5)) + orig = IndexVariable("x", np.arange(5)) new_data = np.arange(5, 10) actual = orig.copy(data=new_data) expected = orig.copy() @@ -526,53 +539,55 @@ def test_copy_index_with_data(self): assert_identical(expected, actual) def test_copy_index_with_data_errors(self): - orig = IndexVariable('x', np.arange(5)) + orig = IndexVariable("x", np.arange(5)) 
new_data = np.arange(5, 20) - with raises_regex(ValueError, 'must match shape of object'): + with raises_regex(ValueError, "must match shape of object"): orig.copy(data=new_data) def test_real_and_imag(self): - v = self.cls('x', np.arange(3) - 1j * np.arange(3), {'foo': 'bar'}) - expected_re = self.cls('x', np.arange(3), {'foo': 'bar'}) + v = self.cls("x", np.arange(3) - 1j * np.arange(3), {"foo": "bar"}) + expected_re = self.cls("x", np.arange(3), {"foo": "bar"}) assert_identical(v.real, expected_re) - expected_im = self.cls('x', -np.arange(3), {'foo': 'bar'}) + expected_im = self.cls("x", -np.arange(3), {"foo": "bar"}) assert_identical(v.imag, expected_im) - expected_abs = self.cls( - 'x', np.sqrt(2 * np.arange(3) ** 2)).to_base_variable() + expected_abs = self.cls("x", np.sqrt(2 * np.arange(3) ** 2)).to_base_variable() assert_allclose(abs(v), expected_abs) def test_aggregate_complex(self): # should skip NaNs - v = self.cls('x', [1, 2j, np.nan]) + v = self.cls("x", [1, 2j, np.nan]) expected = Variable((), 0.5 + 1j) assert_allclose(v.mean(), expected) def test_pandas_cateogrical_dtype(self): - data = pd.Categorical(np.arange(10, dtype='int64')) - v = self.cls('x', data) + data = pd.Categorical(np.arange(10, dtype="int64")) + v = self.cls("x", data) print(v) # should not error - assert v.dtype == 'int64' + assert v.dtype == "int64" def test_pandas_datetime64_with_tz(self): - data = pd.date_range(start='2000-01-01', - tz=pytz.timezone('America/New_York'), - periods=10, freq='1h') - v = self.cls('x', data) + data = pd.date_range( + start="2000-01-01", + tz=pytz.timezone("America/New_York"), + periods=10, + freq="1h", + ) + v = self.cls("x", data) print(v) # should not error - if 'America/New_York' in str(data.dtype): + if "America/New_York" in str(data.dtype): # pandas is new enough that it has datetime64 with timezone dtype - assert v.dtype == 'object' + assert v.dtype == "object" def test_multiindex(self): - idx = pd.MultiIndex.from_product([list('abc'), [0, 1]]) - v = self.cls('x', idx) - assert_identical(Variable((), ('a', 0)), v[0]) + idx = pd.MultiIndex.from_product([list("abc"), [0, 1]]) + v = self.cls("x", idx) + assert_identical(Variable((), ("a", 0)), v[0]) assert_identical(v, v[:]) def test_load(self): - array = self.cls('x', np.arange(5)) + array = self.cls("x", np.arange(5)) orig_data = array._data copied = array.copy(deep=True) if array.chunks is None: @@ -582,27 +597,27 @@ def test_load(self): assert_identical(array, copied) def test_getitem_advanced(self): - v = self.cls(['x', 'y'], [[0, 1, 2], [3, 4, 5]]) + v = self.cls(["x", "y"], [[0, 1, 2], [3, 4, 5]]) v_data = v.compute().data # orthogonal indexing v_new = v[([0, 1], [1, 0])] - assert v_new.dims == ('x', 'y') + assert v_new.dims == ("x", "y") assert_array_equal(v_new, v_data[[0, 1]][:, [1, 0]]) v_new = v[[0, 1]] - assert v_new.dims == ('x', 'y') + assert v_new.dims == ("x", "y") assert_array_equal(v_new, v_data[[0, 1]]) # with mixed arguments - ind = Variable(['a'], [0, 1]) + ind = Variable(["a"], [0, 1]) v_new = v[dict(x=[0, 1], y=ind)] - assert v_new.dims == ('x', 'a') + assert v_new.dims == ("x", "a") assert_array_equal(v_new, v_data[[0, 1]][:, [0, 1]]) # boolean indexing v_new = v[dict(x=[True, False], y=[False, True, False])] - assert v_new.dims == ('x', 'y') + assert v_new.dims == ("x", "y") assert_array_equal(v_new, v_data[0][1]) # with scalar variable @@ -613,17 +628,17 @@ def test_getitem_advanced(self): # with boolean variable with wrong shape ind = np.array([True, False]) - with raises_regex(IndexError, 
'Boolean array size 2 is '): - v[Variable(('a', 'b'), [[0, 1]]), ind] + with raises_regex(IndexError, "Boolean array size 2 is "): + v[Variable(("a", "b"), [[0, 1]]), ind] # boolean indexing with different dimension - ind = Variable(['a'], [True, False, False]) - with raises_regex(IndexError, 'Boolean indexer should be'): + ind = Variable(["a"], [True, False, False]) + with raises_regex(IndexError, "Boolean indexer should be"): v[dict(y=ind)] def test_getitem_uint_1d(self): # regression test for #1405 - v = self.cls(['x'], [0, 1, 2]) + v = self.cls(["x"], [0, 1, 2]) v_data = v.compute().data v_new = v[np.array([0])] @@ -633,7 +648,7 @@ def test_getitem_uint_1d(self): def test_getitem_uint(self): # regression test for #1405 - v = self.cls(['x', 'y'], [[0, 1, 2], [3, 4, 5]]) + v = self.cls(["x", "y"], [[0, 1, 2], [3, 4, 5]]) v_data = v.compute().data v_new = v[np.array([0])] @@ -646,7 +661,7 @@ def test_getitem_uint(self): def test_getitem_0d_array(self): # make sure 0d-np.array can be used as an indexer - v = self.cls(['x'], [0, 1, 2]) + v = self.cls(["x"], [0, 1, 2]) v_data = v.compute().data v_new = v[np.array([0])[0]] @@ -659,84 +674,84 @@ def test_getitem_0d_array(self): assert_array_equal(v_new, v_data[0]) def test_getitem_fancy(self): - v = self.cls(['x', 'y'], [[0, 1, 2], [3, 4, 5]]) + v = self.cls(["x", "y"], [[0, 1, 2], [3, 4, 5]]) v_data = v.compute().data - ind = Variable(['a', 'b'], [[0, 1, 1], [1, 1, 0]]) + ind = Variable(["a", "b"], [[0, 1, 1], [1, 1, 0]]) v_new = v[ind] - assert v_new.dims == ('a', 'b', 'y') + assert v_new.dims == ("a", "b", "y") assert_array_equal(v_new, v_data[[[0, 1, 1], [1, 1, 0]], :]) # It would be ok if indexed with the multi-dimensional array including # the same name - ind = Variable(['x', 'b'], [[0, 1, 1], [1, 1, 0]]) + ind = Variable(["x", "b"], [[0, 1, 1], [1, 1, 0]]) v_new = v[ind] - assert v_new.dims == ('x', 'b', 'y') + assert v_new.dims == ("x", "b", "y") assert_array_equal(v_new, v_data[[[0, 1, 1], [1, 1, 0]], :]) - ind = Variable(['a', 'b'], [[0, 1, 2], [2, 1, 0]]) + ind = Variable(["a", "b"], [[0, 1, 2], [2, 1, 0]]) v_new = v[dict(y=ind)] - assert v_new.dims == ('x', 'a', 'b') + assert v_new.dims == ("x", "a", "b") assert_array_equal(v_new, v_data[:, ([0, 1, 2], [2, 1, 0])]) - ind = Variable(['a', 'b'], [[0, 0], [1, 1]]) + ind = Variable(["a", "b"], [[0, 0], [1, 1]]) v_new = v[dict(x=[1, 0], y=ind)] - assert v_new.dims == ('x', 'a', 'b') + assert v_new.dims == ("x", "a", "b") assert_array_equal(v_new, v_data[[1, 0]][:, ind]) # along diagonal - ind = Variable(['a'], [0, 1]) + ind = Variable(["a"], [0, 1]) v_new = v[ind, ind] - assert v_new.dims == ('a',) + assert v_new.dims == ("a",) assert_array_equal(v_new, v_data[[0, 1], [0, 1]]) # with integer - ind = Variable(['a', 'b'], [[0, 0], [1, 1]]) + ind = Variable(["a", "b"], [[0, 0], [1, 1]]) v_new = v[dict(x=0, y=ind)] - assert v_new.dims == ('a', 'b') + assert v_new.dims == ("a", "b") assert_array_equal(v_new[0], v_data[0][[0, 0]]) assert_array_equal(v_new[1], v_data[0][[1, 1]]) # with slice - ind = Variable(['a', 'b'], [[0, 0], [1, 1]]) + ind = Variable(["a", "b"], [[0, 0], [1, 1]]) v_new = v[dict(x=slice(None), y=ind)] - assert v_new.dims == ('x', 'a', 'b') + assert v_new.dims == ("x", "a", "b") assert_array_equal(v_new, v_data[:, [[0, 0], [1, 1]]]) - ind = Variable(['a', 'b'], [[0, 0], [1, 1]]) + ind = Variable(["a", "b"], [[0, 0], [1, 1]]) v_new = v[dict(x=ind, y=slice(None))] - assert v_new.dims == ('a', 'b', 'y') + assert v_new.dims == ("a", "b", "y") assert_array_equal(v_new, 
v_data[[[0, 0], [1, 1]], :]) - ind = Variable(['a', 'b'], [[0, 0], [1, 1]]) + ind = Variable(["a", "b"], [[0, 0], [1, 1]]) v_new = v[dict(x=ind, y=slice(None, 1))] - assert v_new.dims == ('a', 'b', 'y') + assert v_new.dims == ("a", "b", "y") assert_array_equal(v_new, v_data[[[0, 0], [1, 1]], slice(None, 1)]) # slice matches explicit dimension - ind = Variable(['y'], [0, 1]) + ind = Variable(["y"], [0, 1]) v_new = v[ind, :2] - assert v_new.dims == ('y',) + assert v_new.dims == ("y",) assert_array_equal(v_new, v_data[[0, 1], [0, 1]]) # with multiple slices - v = self.cls(['x', 'y', 'z'], [[[1, 2, 3], [4, 5, 6]]]) - ind = Variable(['a', 'b'], [[0]]) + v = self.cls(["x", "y", "z"], [[[1, 2, 3], [4, 5, 6]]]) + ind = Variable(["a", "b"], [[0]]) v_new = v[ind, :, :] - expected = Variable(['a', 'b', 'y', 'z'], v.data[np.newaxis, ...]) + expected = Variable(["a", "b", "y", "z"], v.data[np.newaxis, ...]) assert_identical(v_new, expected) - v = Variable(['w', 'x', 'y', 'z'], [[[[1, 2, 3], [4, 5, 6]]]]) - ind = Variable(['y'], [0]) + v = Variable(["w", "x", "y", "z"], [[[[1, 2, 3], [4, 5, 6]]]]) + ind = Variable(["y"], [0]) v_new = v[ind, :, 1:2, 2] - expected = Variable(['y', 'x'], [[6]]) + expected = Variable(["y", "x"], [[6]]) assert_identical(v_new, expected) # slice and vector mixed indexing resulting in the same dimension - v = Variable(['x', 'y', 'z'], np.arange(60).reshape(3, 4, 5)) - ind = Variable(['x'], [0, 1, 2]) + v = Variable(["x", "y", "z"], np.arange(60).reshape(3, 4, 5)) + ind = Variable(["x"], [0, 1, 2]) v_new = v[:, ind] - expected = Variable(('x', 'z'), np.zeros((3, 5))) + expected = Variable(("x", "z"), np.zeros((3, 5))) expected[0] = v.data[0, 0] expected[1] = v.data[1, 1] expected[2] = v.data[2, 2] @@ -746,64 +761,70 @@ def test_getitem_fancy(self): assert v_new.shape == (3, 3, 5) def test_getitem_error(self): - v = self.cls(['x', 'y'], [[0, 1, 2], [3, 4, 5]]) + v = self.cls(["x", "y"], [[0, 1, 2], [3, 4, 5]]) with raises_regex(IndexError, "labeled multi-"): v[[[0, 1], [1, 2]]] - ind_x = Variable(['a'], [0, 1, 1]) - ind_y = Variable(['a'], [0, 1]) + ind_x = Variable(["a"], [0, 1, 1]) + ind_y = Variable(["a"], [0, 1]) with raises_regex(IndexError, "Dimensions of indexers "): v[ind_x, ind_y] - ind = Variable(['a', 'b'], [[True, False], [False, True]]) - with raises_regex(IndexError, '2-dimensional boolean'): + ind = Variable(["a", "b"], [[True, False], [False, True]]) + with raises_regex(IndexError, "2-dimensional boolean"): v[dict(x=ind)] - v = Variable(['x', 'y', 'z'], np.arange(60).reshape(3, 4, 5)) - ind = Variable(['x'], [0, 1]) - with raises_regex(IndexError, 'Dimensions of indexers mis'): + v = Variable(["x", "y", "z"], np.arange(60).reshape(3, 4, 5)) + ind = Variable(["x"], [0, 1]) + with raises_regex(IndexError, "Dimensions of indexers mis"): v[:, ind] def test_pad(self): data = np.arange(4 * 3 * 2).reshape(4, 3, 2) - v = self.cls(['x', 'y', 'z'], data) - - xr_args = [{'x': (2, 1)}, {'y': (0, 3)}, {'x': (3, 1), 'z': (2, 0)}] - np_args = [((2, 1), (0, 0), (0, 0)), ((0, 0), (0, 3), (0, 0)), - ((3, 1), (0, 0), (2, 0))] + v = self.cls(["x", "y", "z"], data) + + xr_args = [{"x": (2, 1)}, {"y": (0, 3)}, {"x": (3, 1), "z": (2, 0)}] + np_args = [ + ((2, 1), (0, 0), (0, 0)), + ((0, 0), (0, 3), (0, 0)), + ((3, 1), (0, 0), (2, 0)), + ] for xr_arg, np_arg in zip(xr_args, np_args): actual = v.pad_with_fill_value(**xr_arg) - expected = np.pad(np.array(v.data.astype(float)), np_arg, - mode='constant', constant_values=np.nan) + expected = np.pad( + np.array(v.data.astype(float)), + 
np_arg, + mode="constant", + constant_values=np.nan, + ) assert_array_equal(actual, expected) assert isinstance(actual._data, type(v._data)) # for the boolean array, we pad False data = np.full_like(data, False, dtype=bool).reshape(4, 3, 2) - v = self.cls(['x', 'y', 'z'], data) + v = self.cls(["x", "y", "z"], data) for xr_arg, np_arg in zip(xr_args, np_args): actual = v.pad_with_fill_value(fill_value=False, **xr_arg) - expected = np.pad(np.array(v.data), np_arg, - mode='constant', constant_values=False) + expected = np.pad( + np.array(v.data), np_arg, mode="constant", constant_values=False + ) assert_array_equal(actual, expected) def test_rolling_window(self): # Just a working test. See test_nputils for the algorithm validation - v = self.cls(['x', 'y', 'z'], - np.arange(40 * 30 * 2).reshape(40, 30, 2)) - for (d, w) in [('x', 3), ('y', 5)]: - v_rolling = v.rolling_window(d, w, d + '_window') - assert v_rolling.dims == ('x', 'y', 'z', d + '_window') - assert v_rolling.shape == v.shape + (w, ) + v = self.cls(["x", "y", "z"], np.arange(40 * 30 * 2).reshape(40, 30, 2)) + for (d, w) in [("x", 3), ("y", 5)]: + v_rolling = v.rolling_window(d, w, d + "_window") + assert v_rolling.dims == ("x", "y", "z", d + "_window") + assert v_rolling.shape == v.shape + (w,) - v_rolling = v.rolling_window(d, w, d + '_window', center=True) - assert v_rolling.dims == ('x', 'y', 'z', d + '_window') - assert v_rolling.shape == v.shape + (w, ) + v_rolling = v.rolling_window(d, w, d + "_window", center=True) + assert v_rolling.dims == ("x", "y", "z", d + "_window") + assert v_rolling.shape == v.shape + (w,) # dask and numpy result should be the same - v_loaded = v.load().rolling_window(d, w, d + '_window', - center=True) + v_loaded = v.load().rolling_window(d, w, d + "_window", center=True) assert_array_equal(v_rolling, v_loaded) # numpy backend should not be over-written @@ -820,7 +841,7 @@ def setup(self): self.d = np.random.random((10, 3)).astype(np.float64) def test_data_and_values(self): - v = Variable(['time', 'x'], self.d) + v = Variable(["time", "x"], self.d) assert_array_equal(v.data, self.d) assert_array_equal(v.values, self.d) assert source_ndarray(v.values) is self.d @@ -839,67 +860,65 @@ def test_numpy_same_methods(self): assert v.item() == 0 assert type(v.item()) is float - v = IndexVariable('x', np.arange(5)) + v = IndexVariable("x", np.arange(5)) assert 2 == v.searchsorted(2) def test_datetime64_conversion_scalar(self): - expected = np.datetime64('2000-01-01', 'ns') + expected = np.datetime64("2000-01-01", "ns") for values in [ - np.datetime64('2000-01-01'), - pd.Timestamp('2000-01-01T00'), + np.datetime64("2000-01-01"), + pd.Timestamp("2000-01-01T00"), datetime(2000, 1, 1), ]: v = Variable([], values) - assert v.dtype == np.dtype('datetime64[ns]') + assert v.dtype == np.dtype("datetime64[ns]") assert v.values == expected - assert v.values.dtype == np.dtype('datetime64[ns]') + assert v.values.dtype == np.dtype("datetime64[ns]") def test_timedelta64_conversion_scalar(self): - expected = np.timedelta64(24 * 60 * 60 * 10 ** 9, 'ns') + expected = np.timedelta64(24 * 60 * 60 * 10 ** 9, "ns") for values in [ - np.timedelta64(1, 'D'), - pd.Timedelta('1 day'), + np.timedelta64(1, "D"), + pd.Timedelta("1 day"), timedelta(days=1), ]: v = Variable([], values) - assert v.dtype == np.dtype('timedelta64[ns]') + assert v.dtype == np.dtype("timedelta64[ns]") assert v.values == expected - assert v.values.dtype == np.dtype('timedelta64[ns]') + assert v.values.dtype == np.dtype("timedelta64[ns]") def test_0d_str(self): - 
v = Variable([], 'foo') - assert v.dtype == np.dtype('U3') - assert v.values == 'foo' + v = Variable([], "foo") + assert v.dtype == np.dtype("U3") + assert v.values == "foo" - v = Variable([], np.string_('foo')) - assert v.dtype == np.dtype('S3') - assert v.values == bytes('foo', 'ascii') + v = Variable([], np.string_("foo")) + assert v.dtype == np.dtype("S3") + assert v.values == bytes("foo", "ascii") def test_0d_datetime(self): - v = Variable([], pd.Timestamp('2000-01-01')) - assert v.dtype == np.dtype('datetime64[ns]') - assert v.values == np.datetime64('2000-01-01', 'ns') + v = Variable([], pd.Timestamp("2000-01-01")) + assert v.dtype == np.dtype("datetime64[ns]") + assert v.values == np.datetime64("2000-01-01", "ns") def test_0d_timedelta(self): - for td in [pd.to_timedelta('1s'), np.timedelta64(1, 's')]: + for td in [pd.to_timedelta("1s"), np.timedelta64(1, "s")]: v = Variable([], td) - assert v.dtype == np.dtype('timedelta64[ns]') - assert v.values == np.timedelta64(10 ** 9, 'ns') + assert v.dtype == np.dtype("timedelta64[ns]") + assert v.values == np.timedelta64(10 ** 9, "ns") def test_equals_and_identical(self): d = np.random.rand(10, 3) d[0, 0] = np.nan - v1 = Variable(('dim1', 'dim2'), data=d, - attrs={'att1': 3, 'att2': [1, 2, 3]}) - v2 = Variable(('dim1', 'dim2'), data=d, - attrs={'att1': 3, 'att2': [1, 2, 3]}) + v1 = Variable(("dim1", "dim2"), data=d, attrs={"att1": 3, "att2": [1, 2, 3]}) + v2 = Variable(("dim1", "dim2"), data=d, attrs={"att1": 3, "att2": [1, 2, 3]}) assert v1.equals(v2) assert v1.identical(v2) - v3 = Variable(('dim1', 'dim3'), data=d) + v3 = Variable(("dim1", "dim3"), data=d) assert not v1.equals(v3) - v4 = Variable(('dim1', 'dim2'), data=d) + v4 = Variable(("dim1", "dim2"), data=d) assert v1.equals(v4) assert not v1.identical(v4) @@ -915,24 +934,24 @@ def test_equals_and_identical(self): def test_broadcast_equals(self): v1 = Variable((), np.nan) - v2 = Variable(('x'), [np.nan, np.nan]) + v2 = Variable(("x"), [np.nan, np.nan]) assert v1.broadcast_equals(v2) assert not v1.equals(v2) assert not v1.identical(v2) - v3 = Variable(('x'), [np.nan]) + v3 = Variable(("x"), [np.nan]) assert v1.broadcast_equals(v3) assert not v1.equals(v3) assert not v1.identical(v3) assert not v1.broadcast_equals(None) - v4 = Variable(('x'), [np.nan] * 3) + v4 = Variable(("x"), [np.nan] * 3) assert not v2.broadcast_equals(v4) def test_no_conflicts(self): - v1 = Variable(('x'), [1, 2, np.nan, np.nan]) - v2 = Variable(('x'), [np.nan, 2, 3, np.nan]) + v1 = Variable(("x"), [1, 2, np.nan, np.nan]) + v2 = Variable(("x"), [np.nan, 2, 3, np.nan]) assert v1.no_conflicts(v2) assert not v1.equals(v2) assert not v1.broadcast_equals(v2) @@ -940,43 +959,47 @@ def test_no_conflicts(self): assert not v1.no_conflicts(None) - v3 = Variable(('y'), [np.nan, 2, 3, np.nan]) + v3 = Variable(("y"), [np.nan, 2, 3, np.nan]) assert not v3.no_conflicts(v1) d = np.array([1, 2, np.nan, np.nan]) assert not v1.no_conflicts(d) assert not v2.no_conflicts(d) - v4 = Variable(('w', 'x'), [d]) + v4 = Variable(("w", "x"), [d]) assert v1.no_conflicts(v4) def test_as_variable(self): data = np.arange(10) - expected = Variable('x', data) - expected_extra = Variable('x', data, attrs={'myattr': 'val'}, - encoding={'scale_factor': 1}) + expected = Variable("x", data) + expected_extra = Variable( + "x", data, attrs={"myattr": "val"}, encoding={"scale_factor": 1} + ) assert_identical(expected, as_variable(expected)) - ds = Dataset({'x': expected}) - var = as_variable(ds['x']).to_base_variable() + ds = Dataset({"x": expected}) + var 
= as_variable(ds["x"]).to_base_variable() assert_identical(expected, var) - assert not isinstance(ds['x'], Variable) - assert isinstance(as_variable(ds['x']), Variable) - - xarray_tuple = (expected_extra.dims, expected_extra.values, - expected_extra.attrs, expected_extra.encoding) + assert not isinstance(ds["x"], Variable) + assert isinstance(as_variable(ds["x"]), Variable) + + xarray_tuple = ( + expected_extra.dims, + expected_extra.values, + expected_extra.attrs, + expected_extra.encoding, + ) assert_identical(expected_extra, as_variable(xarray_tuple)) - with raises_regex(TypeError, 'tuple of form'): + with raises_regex(TypeError, "tuple of form"): as_variable(tuple(data)) - with raises_regex(ValueError, 'tuple of form'): # GH1016 - as_variable(('five', 'six', 'seven')) - with raises_regex( - TypeError, 'without an explicit list of dimensions'): + with raises_regex(ValueError, "tuple of form"): # GH1016 + as_variable(("five", "six", "seven")) + with raises_regex(TypeError, "without an explicit list of dimensions"): as_variable(data) - actual = as_variable(data, name='x') + actual = as_variable(data, name="x") assert_identical(expected.to_index_variable(), actual) actual = as_variable(0) @@ -984,41 +1007,40 @@ def test_as_variable(self): assert_identical(expected, actual) data = np.arange(9).reshape((3, 3)) - expected = Variable(('x', 'y'), data) - with raises_regex( - ValueError, 'without explicit dimension names'): - as_variable(data, name='x') - with raises_regex( - ValueError, 'has more than 1-dimension'): - as_variable(expected, name='x') + expected = Variable(("x", "y"), data) + with raises_regex(ValueError, "without explicit dimension names"): + as_variable(data, name="x") + with raises_regex(ValueError, "has more than 1-dimension"): + as_variable(expected, name="x") # test datetime, timedelta conversion - dt = np.array([datetime(1999, 1, 1) + timedelta(days=x) - for x in range(10)]) - assert as_variable(dt, 'time').dtype.kind == 'M' + dt = np.array([datetime(1999, 1, 1) + timedelta(days=x) for x in range(10)]) + assert as_variable(dt, "time").dtype.kind == "M" td = np.array([timedelta(days=x) for x in range(10)]) - assert as_variable(td, 'time').dtype.kind == 'm' + assert as_variable(td, "time").dtype.kind == "m" def test_repr(self): - v = Variable(['time', 'x'], [[1, 2, 3], [4, 5, 6]], {'foo': 'bar'}) - expected = dedent(""" + v = Variable(["time", "x"], [[1, 2, 3], [4, 5, 6]], {"foo": "bar"}) + expected = dedent( + """ array([[1, 2, 3], [4, 5, 6]]) Attributes: foo: bar - """).strip() + """ + ).strip() assert expected == repr(v) def test_repr_lazy_data(self): - v = Variable('x', LazilyOuterIndexedArray(np.arange(2e5))) - assert '200000 values with dtype' in repr(v) + v = Variable("x", LazilyOuterIndexedArray(np.arange(2e5))) + assert "200000 values with dtype" in repr(v) assert isinstance(v._data, LazilyOuterIndexedArray) def test_detect_indexer_type(self): """ Tests indexer type was correctly detected. 
""" data = np.random.random((10, 11)) - v = Variable(['x', 'y'], data) + v = Variable(["x", "y"], data) _, ind, _ = v._broadcast_indexes((0, 1)) assert type(ind) == indexing.BasicIndexer @@ -1038,26 +1060,26 @@ def test_detect_indexer_type(self): _, ind, _ = v._broadcast_indexes(([0, 1], slice(0, 8, 2))) assert type(ind) == indexing.OuterIndexer - vind = Variable(('a', ), [0, 1]) + vind = Variable(("a",), [0, 1]) _, ind, _ = v._broadcast_indexes((vind, slice(0, 8, 2))) assert type(ind) == indexing.OuterIndexer - vind = Variable(('y', ), [0, 1]) + vind = Variable(("y",), [0, 1]) _, ind, _ = v._broadcast_indexes((vind, 3)) assert type(ind) == indexing.OuterIndexer - vind = Variable(('a', ), [0, 1]) + vind = Variable(("a",), [0, 1]) _, ind, _ = v._broadcast_indexes((vind, vind)) assert type(ind) == indexing.VectorizedIndexer - vind = Variable(('a', 'b'), [[0, 2], [1, 3]]) + vind = Variable(("a", "b"), [[0, 2], [1, 3]]) _, ind, _ = v._broadcast_indexes((vind, 3)) assert type(ind) == indexing.VectorizedIndexer def test_indexer_type(self): # GH:issue:1688. Wrong indexer type induces NotImplementedError data = np.random.random((10, 11)) - v = Variable(['x', 'y'], data) + v = Variable(["x", "y"], data) def assert_indexer_type(key, object_type): dims, index_tuple, new_order = v._broadcast_indexes(key) @@ -1072,34 +1094,37 @@ def assert_indexer_type(key, object_type): # should return OuterIndexer assert_indexer_type(([0, 1], 1), OuterIndexer) assert_indexer_type(([0, 1], [1, 2]), OuterIndexer) - assert_indexer_type((Variable(('x'), [0, 1]), 1), OuterIndexer) - assert_indexer_type((Variable(('x'), [0, 1]), slice(None, None)), - OuterIndexer) - assert_indexer_type((Variable(('x'), [0, 1]), Variable(('y'), [0, 1])), - OuterIndexer) + assert_indexer_type((Variable(("x"), [0, 1]), 1), OuterIndexer) + assert_indexer_type((Variable(("x"), [0, 1]), slice(None, None)), OuterIndexer) + assert_indexer_type( + (Variable(("x"), [0, 1]), Variable(("y"), [0, 1])), OuterIndexer + ) # should return VectorizedIndexer - assert_indexer_type((Variable(('y'), [0, 1]), [0, 1]), - VectorizedIndexer) - assert_indexer_type((Variable(('z'), [0, 1]), Variable(('z'), [0, 1])), - VectorizedIndexer) - assert_indexer_type((Variable(('a', 'b'), [[0, 1], [1, 2]]), - Variable(('a', 'b'), [[0, 1], [1, 2]])), - VectorizedIndexer) + assert_indexer_type((Variable(("y"), [0, 1]), [0, 1]), VectorizedIndexer) + assert_indexer_type( + (Variable(("z"), [0, 1]), Variable(("z"), [0, 1])), VectorizedIndexer + ) + assert_indexer_type( + ( + Variable(("a", "b"), [[0, 1], [1, 2]]), + Variable(("a", "b"), [[0, 1], [1, 2]]), + ), + VectorizedIndexer, + ) def test_items(self): data = np.random.random((10, 11)) - v = Variable(['x', 'y'], data) + v = Variable(["x", "y"], data) # test slicing assert_identical(v, v[:]) assert_identical(v, v[...]) - assert_identical(Variable(['y'], data[0]), v[0]) - assert_identical(Variable(['x'], data[:, 0]), v[:, 0]) - assert_identical(Variable(['x', 'y'], data[:3, :2]), - v[:3, :2]) + assert_identical(Variable(["y"], data[0]), v[0]) + assert_identical(Variable(["x"], data[:, 0]), v[:, 0]) + assert_identical(Variable(["x", "y"], data[:3, :2]), v[:3, :2]) # test array indexing - x = Variable(['x'], np.arange(10)) - y = Variable(['y'], np.arange(11)) + x = Variable(["x"], np.arange(10)) + y = Variable(["y"], np.arange(11)) assert_identical(v, v[x.values]) assert_identical(v, v[x]) assert_identical(v[:3], v[x < 3]) @@ -1109,8 +1134,8 @@ def test_items(self): assert_identical(v[:3, :2], v[range(3), range(2)]) # test 
iteration for n, item in enumerate(v): - assert_identical(Variable(['y'], data[n]), item) - with raises_regex(TypeError, 'iteration over a 0-d'): + assert_identical(Variable(["y"], data[n]), item) + with raises_regex(TypeError, "iteration over a 0-d"): iter(Variable([], 0)) # test setting v.values[:] = 0 @@ -1120,14 +1145,14 @@ def test_items(self): assert_array_equal(v.values, np.ones((10, 11))) def test_getitem_basic(self): - v = self.cls(['x', 'y'], [[0, 1, 2], [3, 4, 5]]) + v = self.cls(["x", "y"], [[0, 1, 2], [3, 4, 5]]) v_new = v[dict(x=0)] - assert v_new.dims == ('y', ) + assert v_new.dims == ("y",) assert_array_equal(v_new, v._data[0]) v_new = v[dict(x=0, y=slice(None))] - assert v_new.dims == ('y', ) + assert v_new.dims == ("y",) assert_array_equal(v_new, v._data[0]) v_new = v[dict(x=0, y=1)] @@ -1135,12 +1160,12 @@ def test_getitem_basic(self): assert_array_equal(v_new, v._data[0, 1]) v_new = v[dict(y=1)] - assert v_new.dims == ('x', ) + assert v_new.dims == ("x",) assert_array_equal(v_new, v._data[:, 1]) # tuple argument v_new = v[(slice(None), 1)] - assert v_new.dims == ('x', ) + assert v_new.dims == ("x",) assert_array_equal(v_new, v._data[:, 1]) # test that we obtain a modifiable view when taking a 0d slice @@ -1149,42 +1174,44 @@ def test_getitem_basic(self): assert_array_equal(v_new, v._data[0, 0]) def test_getitem_with_mask_2d_input(self): - v = Variable(('x', 'y'), [[0, 1, 2], [3, 4, 5]]) - assert_identical(v._getitem_with_mask(([-1, 0], [1, -1])), - Variable(('x', 'y'), [[np.nan, np.nan], [1, np.nan]])) + v = Variable(("x", "y"), [[0, 1, 2], [3, 4, 5]]) + assert_identical( + v._getitem_with_mask(([-1, 0], [1, -1])), + Variable(("x", "y"), [[np.nan, np.nan], [1, np.nan]]), + ) assert_identical(v._getitem_with_mask((slice(2), [0, 1, 2])), v) def test_isel(self): - v = Variable(['time', 'x'], self.d) + v = Variable(["time", "x"], self.d) assert_identical(v.isel(time=slice(None)), v) assert_identical(v.isel(time=0), v[0]) assert_identical(v.isel(time=slice(0, 3)), v[:3]) assert_identical(v.isel(x=0), v[:, 0]) - with raises_regex(ValueError, 'do not exist'): + with raises_regex(ValueError, "do not exist"): v.isel(not_a_dim=0) def test_index_0d_numpy_string(self): # regression test to verify our work around for indexing 0d strings - v = Variable([], np.string_('asdf')) + v = Variable([], np.string_("asdf")) assert_identical(v[()], v) - v = Variable([], np.unicode_('asdf')) + v = Variable([], np.unicode_("asdf")) assert_identical(v[()], v) def test_indexing_0d_unicode(self): # regression test for GH568 - actual = Variable(('x'), ['tmax'])[0][()] - expected = Variable((), 'tmax') + actual = Variable(("x"), ["tmax"])[0][()] + expected = Variable((), "tmax") assert_identical(actual, expected) - @pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0]) + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) def test_shift(self, fill_value): - v = Variable('x', [1, 2, 3, 4, 5]) + v = Variable("x", [1, 2, 3, 4, 5]) assert_identical(v, v.shift(x=0)) assert v is not v.shift(x=0) - expected = Variable('x', [np.nan, np.nan, 1, 2, 3]) + expected = Variable("x", [np.nan, np.nan, 1, 2, 3]) assert_identical(expected, v.shift(x=2)) if fill_value == dtypes.NA: @@ -1194,242 +1221,235 @@ def test_shift(self, fill_value): else: fill_value_exp = fill_value - expected = Variable('x', [fill_value_exp, 1, 2, 3, 4]) + expected = Variable("x", [fill_value_exp, 1, 2, 3, 4]) assert_identical(expected, v.shift(x=1, fill_value=fill_value)) - expected = Variable('x', [2, 3, 4, 5, fill_value_exp]) + 
expected = Variable("x", [2, 3, 4, 5, fill_value_exp]) assert_identical(expected, v.shift(x=-1, fill_value=fill_value)) - expected = Variable('x', [fill_value_exp] * 5) + expected = Variable("x", [fill_value_exp] * 5) assert_identical(expected, v.shift(x=5, fill_value=fill_value)) assert_identical(expected, v.shift(x=6, fill_value=fill_value)) - with raises_regex(ValueError, 'dimension'): + with raises_regex(ValueError, "dimension"): v.shift(z=0) - v = Variable('x', [1, 2, 3, 4, 5], {'foo': 'bar'}) + v = Variable("x", [1, 2, 3, 4, 5], {"foo": "bar"}) assert_identical(v, v.shift(x=0)) - expected = Variable('x', [fill_value_exp, 1, 2, 3, 4], {'foo': 'bar'}) + expected = Variable("x", [fill_value_exp, 1, 2, 3, 4], {"foo": "bar"}) assert_identical(expected, v.shift(x=1, fill_value=fill_value)) def test_shift2d(self): - v = Variable(('x', 'y'), [[1, 2], [3, 4]]) - expected = Variable(('x', 'y'), [[np.nan, np.nan], [np.nan, 1]]) + v = Variable(("x", "y"), [[1, 2], [3, 4]]) + expected = Variable(("x", "y"), [[np.nan, np.nan], [np.nan, 1]]) assert_identical(expected, v.shift(x=1, y=1)) def test_roll(self): - v = Variable('x', [1, 2, 3, 4, 5]) + v = Variable("x", [1, 2, 3, 4, 5]) assert_identical(v, v.roll(x=0)) assert v is not v.roll(x=0) - expected = Variable('x', [5, 1, 2, 3, 4]) + expected = Variable("x", [5, 1, 2, 3, 4]) assert_identical(expected, v.roll(x=1)) assert_identical(expected, v.roll(x=-4)) assert_identical(expected, v.roll(x=6)) - expected = Variable('x', [4, 5, 1, 2, 3]) + expected = Variable("x", [4, 5, 1, 2, 3]) assert_identical(expected, v.roll(x=2)) assert_identical(expected, v.roll(x=-3)) - with raises_regex(ValueError, 'dimension'): + with raises_regex(ValueError, "dimension"): v.roll(z=0) def test_roll_consistency(self): - v = Variable(('x', 'y'), np.random.randn(5, 6)) + v = Variable(("x", "y"), np.random.randn(5, 6)) - for axis, dim in [(0, 'x'), (1, 'y')]: + for axis, dim in [(0, "x"), (1, "y")]: for shift in [-3, 0, 1, 7, 11]: expected = np.roll(v.values, shift, axis=axis) actual = v.roll(**{dim: shift}).values assert_array_equal(expected, actual) def test_transpose(self): - v = Variable(['time', 'x'], self.d) - v2 = Variable(['x', 'time'], self.d.T) + v = Variable(["time", "x"], self.d) + v2 = Variable(["x", "time"], self.d.T) assert_identical(v, v2.transpose()) assert_identical(v.transpose(), v.T) x = np.random.randn(2, 3, 4, 5) - w = Variable(['a', 'b', 'c', 'd'], x) - w2 = Variable(['d', 'b', 'c', 'a'], np.einsum('abcd->dbca', x)) + w = Variable(["a", "b", "c", "d"], x) + w2 = Variable(["d", "b", "c", "a"], np.einsum("abcd->dbca", x)) assert w2.shape == (5, 3, 4, 2) - assert_identical(w2, w.transpose('d', 'b', 'c', 'a')) - assert_identical(w, w2.transpose('a', 'b', 'c', 'd')) - w3 = Variable(['b', 'c', 'd', 'a'], np.einsum('abcd->bcda', x)) - assert_identical(w, w3.transpose('a', 'b', 'c', 'd')) + assert_identical(w2, w.transpose("d", "b", "c", "a")) + assert_identical(w, w2.transpose("a", "b", "c", "d")) + w3 = Variable(["b", "c", "d", "a"], np.einsum("abcd->bcda", x)) + assert_identical(w, w3.transpose("a", "b", "c", "d")) def test_transpose_0d(self): for value in [ - 3.5, - ('a', 1), - np.datetime64('2000-01-01'), - np.timedelta64(1, 'h'), - None, - object(), + 3.5, + ("a", 1), + np.datetime64("2000-01-01"), + np.timedelta64(1, "h"), + None, + object(), ]: variable = Variable([], value) actual = variable.transpose() assert actual.identical(variable) def test_squeeze(self): - v = Variable(['x', 'y'], [[1]]) + v = Variable(["x", "y"], [[1]]) 
assert_identical(Variable([], 1), v.squeeze()) - assert_identical(Variable(['y'], [1]), v.squeeze('x')) - assert_identical(Variable(['y'], [1]), v.squeeze(['x'])) - assert_identical(Variable(['x'], [1]), v.squeeze('y')) - assert_identical(Variable([], 1), v.squeeze(['x', 'y'])) + assert_identical(Variable(["y"], [1]), v.squeeze("x")) + assert_identical(Variable(["y"], [1]), v.squeeze(["x"])) + assert_identical(Variable(["x"], [1]), v.squeeze("y")) + assert_identical(Variable([], 1), v.squeeze(["x", "y"])) - v = Variable(['x', 'y'], [[1, 2]]) - assert_identical(Variable(['y'], [1, 2]), v.squeeze()) - assert_identical(Variable(['y'], [1, 2]), v.squeeze('x')) - with raises_regex(ValueError, 'cannot select a dimension'): - v.squeeze('y') + v = Variable(["x", "y"], [[1, 2]]) + assert_identical(Variable(["y"], [1, 2]), v.squeeze()) + assert_identical(Variable(["y"], [1, 2]), v.squeeze("x")) + with raises_regex(ValueError, "cannot select a dimension"): + v.squeeze("y") def test_get_axis_num(self): - v = Variable(['x', 'y', 'z'], np.random.randn(2, 3, 4)) - assert v.get_axis_num('x') == 0 - assert v.get_axis_num(['x']) == (0,) - assert v.get_axis_num(['x', 'y']) == (0, 1) - assert v.get_axis_num(['z', 'y', 'x']) == (2, 1, 0) - with raises_regex(ValueError, 'not found in array dim'): - v.get_axis_num('foobar') + v = Variable(["x", "y", "z"], np.random.randn(2, 3, 4)) + assert v.get_axis_num("x") == 0 + assert v.get_axis_num(["x"]) == (0,) + assert v.get_axis_num(["x", "y"]) == (0, 1) + assert v.get_axis_num(["z", "y", "x"]) == (2, 1, 0) + with raises_regex(ValueError, "not found in array dim"): + v.get_axis_num("foobar") def test_set_dims(self): - v = Variable(['x'], [0, 1]) - actual = v.set_dims(['x', 'y']) - expected = Variable(['x', 'y'], [[0], [1]]) + v = Variable(["x"], [0, 1]) + actual = v.set_dims(["x", "y"]) + expected = Variable(["x", "y"], [[0], [1]]) assert_identical(actual, expected) - actual = v.set_dims(['y', 'x']) + actual = v.set_dims(["y", "x"]) assert_identical(actual, expected.T) - actual = v.set_dims(OrderedDict([('x', 2), ('y', 2)])) - expected = Variable(['x', 'y'], [[0, 0], [1, 1]]) + actual = v.set_dims(OrderedDict([("x", 2), ("y", 2)])) + expected = Variable(["x", "y"], [[0, 0], [1, 1]]) assert_identical(actual, expected) - v = Variable(['foo'], [0, 1]) - actual = v.set_dims('foo') + v = Variable(["foo"], [0, 1]) + actual = v.set_dims("foo") expected = v assert_identical(actual, expected) - with raises_regex(ValueError, 'must be a superset'): - v.set_dims(['z']) + with raises_regex(ValueError, "must be a superset"): + v.set_dims(["z"]) def test_set_dims_object_dtype(self): - v = Variable([], ('a', 1)) - actual = v.set_dims(('x',), (3,)) + v = Variable([], ("a", 1)) + actual = v.set_dims(("x",), (3,)) exp_values = np.empty((3,), dtype=object) for i in range(3): - exp_values[i] = ('a', 1) - expected = Variable(['x'], exp_values) + exp_values[i] = ("a", 1) + expected = Variable(["x"], exp_values) assert actual.identical(expected) def test_stack(self): - v = Variable(['x', 'y'], [[0, 1], [2, 3]], {'foo': 'bar'}) - actual = v.stack(z=('x', 'y')) - expected = Variable('z', [0, 1, 2, 3], v.attrs) + v = Variable(["x", "y"], [[0, 1], [2, 3]], {"foo": "bar"}) + actual = v.stack(z=("x", "y")) + expected = Variable("z", [0, 1, 2, 3], v.attrs) assert_identical(actual, expected) - actual = v.stack(z=('x',)) - expected = Variable(('y', 'z'), v.data.T, v.attrs) + actual = v.stack(z=("x",)) + expected = Variable(("y", "z"), v.data.T, v.attrs) assert_identical(actual, expected) - actual = 
v.stack(z=(),) + actual = v.stack(z=()) assert_identical(actual, v) - actual = v.stack(X=('x',), Y=('y',)).transpose('X', 'Y') - expected = Variable(('X', 'Y'), v.data, v.attrs) + actual = v.stack(X=("x",), Y=("y",)).transpose("X", "Y") + expected = Variable(("X", "Y"), v.data, v.attrs) assert_identical(actual, expected) def test_stack_errors(self): - v = Variable(['x', 'y'], [[0, 1], [2, 3]], {'foo': 'bar'}) + v = Variable(["x", "y"], [[0, 1], [2, 3]], {"foo": "bar"}) - with raises_regex(ValueError, 'invalid existing dim'): - v.stack(z=('x1',)) - with raises_regex(ValueError, 'cannot create a new dim'): - v.stack(x=('x',)) + with raises_regex(ValueError, "invalid existing dim"): + v.stack(z=("x1",)) + with raises_regex(ValueError, "cannot create a new dim"): + v.stack(x=("x",)) def test_unstack(self): - v = Variable('z', [0, 1, 2, 3], {'foo': 'bar'}) - actual = v.unstack(z=OrderedDict([('x', 2), ('y', 2)])) - expected = Variable(('x', 'y'), [[0, 1], [2, 3]], v.attrs) + v = Variable("z", [0, 1, 2, 3], {"foo": "bar"}) + actual = v.unstack(z=OrderedDict([("x", 2), ("y", 2)])) + expected = Variable(("x", "y"), [[0, 1], [2, 3]], v.attrs) assert_identical(actual, expected) - actual = v.unstack(z=OrderedDict([('x', 4), ('y', 1)])) - expected = Variable(('x', 'y'), [[0], [1], [2], [3]], v.attrs) + actual = v.unstack(z=OrderedDict([("x", 4), ("y", 1)])) + expected = Variable(("x", "y"), [[0], [1], [2], [3]], v.attrs) assert_identical(actual, expected) - actual = v.unstack(z=OrderedDict([('x', 4)])) - expected = Variable('x', [0, 1, 2, 3], v.attrs) + actual = v.unstack(z=OrderedDict([("x", 4)])) + expected = Variable("x", [0, 1, 2, 3], v.attrs) assert_identical(actual, expected) def test_unstack_errors(self): - v = Variable('z', [0, 1, 2, 3]) - with raises_regex(ValueError, 'invalid existing dim'): - v.unstack(foo={'x': 4}) - with raises_regex(ValueError, 'cannot create a new dim'): - v.stack(z=('z',)) - with raises_regex(ValueError, 'the product of the new dim'): - v.unstack(z={'x': 5}) + v = Variable("z", [0, 1, 2, 3]) + with raises_regex(ValueError, "invalid existing dim"): + v.unstack(foo={"x": 4}) + with raises_regex(ValueError, "cannot create a new dim"): + v.stack(z=("z",)) + with raises_regex(ValueError, "the product of the new dim"): + v.unstack(z={"x": 5}) def test_unstack_2d(self): - v = Variable(['x', 'y'], [[0, 1], [2, 3]]) - actual = v.unstack(y={'z': 2}) - expected = Variable(['x', 'z'], v.data) + v = Variable(["x", "y"], [[0, 1], [2, 3]]) + actual = v.unstack(y={"z": 2}) + expected = Variable(["x", "z"], v.data) assert_identical(actual, expected) - actual = v.unstack(x={'z': 2}) - expected = Variable(['y', 'z'], v.data.T) + actual = v.unstack(x={"z": 2}) + expected = Variable(["y", "z"], v.data.T) assert_identical(actual, expected) def test_stack_unstack_consistency(self): - v = Variable(['x', 'y'], [[0, 1], [2, 3]]) - actual = (v.stack(z=('x', 'y')) - .unstack(z=OrderedDict([('x', 2), ('y', 2)]))) + v = Variable(["x", "y"], [[0, 1], [2, 3]]) + actual = v.stack(z=("x", "y")).unstack(z=OrderedDict([("x", 2), ("y", 2)])) assert_identical(actual, v) def test_broadcasting_math(self): x = np.random.randn(2, 3) - v = Variable(['a', 'b'], x) + v = Variable(["a", "b"], x) # 1d to 2d broadcasting + assert_identical(v * v, Variable(["a", "b"], np.einsum("ab,ab->ab", x, x))) + assert_identical(v * v[0], Variable(["a", "b"], np.einsum("ab,b->ab", x, x[0]))) + assert_identical(v[0] * v, Variable(["b", "a"], np.einsum("b,ab->ba", x[0], x))) assert_identical( - v * v, - Variable(['a', 'b'], 
np.einsum('ab,ab->ab', x, x))) - assert_identical( - v * v[0], - Variable(['a', 'b'], np.einsum('ab,b->ab', x, x[0]))) - assert_identical( - v[0] * v, - Variable(['b', 'a'], np.einsum('b,ab->ba', x[0], x))) - assert_identical( - v[0] * v[:, 0], - Variable(['b', 'a'], np.einsum('b,a->ba', x[0], x[:, 0]))) + v[0] * v[:, 0], Variable(["b", "a"], np.einsum("b,a->ba", x[0], x[:, 0])) + ) # higher dim broadcasting y = np.random.randn(3, 4, 5) - w = Variable(['b', 'c', 'd'], y) + w = Variable(["b", "c", "d"], y) assert_identical( - v * w, Variable(['a', 'b', 'c', 'd'], - np.einsum('ab,bcd->abcd', x, y))) + v * w, Variable(["a", "b", "c", "d"], np.einsum("ab,bcd->abcd", x, y)) + ) assert_identical( - w * v, Variable(['b', 'c', 'd', 'a'], - np.einsum('bcd,ab->bcda', y, x))) + w * v, Variable(["b", "c", "d", "a"], np.einsum("bcd,ab->bcda", y, x)) + ) assert_identical( - v * w[0], Variable(['a', 'b', 'c', 'd'], - np.einsum('ab,cd->abcd', x, y[0]))) + v * w[0], Variable(["a", "b", "c", "d"], np.einsum("ab,cd->abcd", x, y[0])) + ) def test_broadcasting_failures(self): - a = Variable(['x'], np.arange(10)) - b = Variable(['x'], np.arange(5)) - c = Variable(['x', 'x'], np.arange(100).reshape(10, 10)) - with raises_regex(ValueError, 'mismatched lengths'): + a = Variable(["x"], np.arange(10)) + b = Variable(["x"], np.arange(5)) + c = Variable(["x", "x"], np.arange(100).reshape(10, 10)) + with raises_regex(ValueError, "mismatched lengths"): a + b - with raises_regex(ValueError, 'duplicate dimensions'): + with raises_regex(ValueError, "duplicate dimensions"): a + c def test_inplace_math(self): x = np.arange(5) - v = Variable(['x'], x) + v = Variable(["x"], x) v2 = v v2 += 1 assert v is v2 @@ -1437,133 +1457,136 @@ def test_inplace_math(self): assert source_ndarray(v.values) is x assert_array_equal(v.values, np.arange(5) + 1) - with raises_regex(ValueError, 'dimensions cannot change'): - v += Variable('y', np.arange(5)) + with raises_regex(ValueError, "dimensions cannot change"): + v += Variable("y", np.arange(5)) def test_reduce(self): - v = Variable(['x', 'y'], self.d, {'ignored': 'attributes'}) - assert_identical(v.reduce(np.std, 'x'), - Variable(['y'], self.d.std(axis=0))) - assert_identical(v.reduce(np.std, axis=0), - v.reduce(np.std, dim='x')) - assert_identical(v.reduce(np.std, ['y', 'x']), - Variable([], self.d.std(axis=(0, 1)))) - assert_identical(v.reduce(np.std), - Variable([], self.d.std())) + v = Variable(["x", "y"], self.d, {"ignored": "attributes"}) + assert_identical(v.reduce(np.std, "x"), Variable(["y"], self.d.std(axis=0))) + assert_identical(v.reduce(np.std, axis=0), v.reduce(np.std, dim="x")) + assert_identical( + v.reduce(np.std, ["y", "x"]), Variable([], self.d.std(axis=(0, 1))) + ) + assert_identical(v.reduce(np.std), Variable([], self.d.std())) assert_identical( - v.reduce(np.mean, 'x').reduce(np.std, 'y'), - Variable([], self.d.mean(axis=0).std())) - assert_allclose(v.mean('x'), v.reduce(np.mean, 'x')) + v.reduce(np.mean, "x").reduce(np.std, "y"), + Variable([], self.d.mean(axis=0).std()), + ) + assert_allclose(v.mean("x"), v.reduce(np.mean, "x")) - with raises_regex(ValueError, 'cannot supply both'): - v.mean(dim='x', axis=0) + with raises_regex(ValueError, "cannot supply both"): + v.mean(dim="x", axis=0) def test_quantile(self): - v = Variable(['x', 'y'], self.d) + v = Variable(["x", "y"], self.d) for q in [0.25, [0.50], [0.25, 0.75]]: - for axis, dim in zip([None, 0, [0], [0, 1]], - [None, 'x', ['x'], ['x', 'y']]): + for axis, dim in zip( + [None, 0, [0], [0, 1]], [None, "x", ["x"], 
["x", "y"]] + ): actual = v.quantile(q, dim=dim) - expected = np.nanpercentile(self.d, np.array(q) * 100, - axis=axis) + expected = np.nanpercentile(self.d, np.array(q) * 100, axis=axis) np.testing.assert_allclose(actual.values, expected) @requires_dask def test_quantile_dask_raises(self): # regression for GH1524 - v = Variable(['x', 'y'], self.d).chunk(2) + v = Variable(["x", "y"], self.d).chunk(2) - with raises_regex(TypeError, 'arrays stored as dask'): - v.quantile(0.5, dim='x') + with raises_regex(TypeError, "arrays stored as dask"): + v.quantile(0.5, dim="x") @requires_dask @requires_bottleneck def test_rank_dask_raises(self): - v = Variable(['x'], [3.0, 1.0, np.nan, 2.0, 4.0]).chunk(2) - with raises_regex(TypeError, 'arrays stored as dask'): - v.rank('x') + v = Variable(["x"], [3.0, 1.0, np.nan, 2.0, 4.0]).chunk(2) + with raises_regex(TypeError, "arrays stored as dask"): + v.rank("x") @requires_bottleneck def test_rank(self): import bottleneck as bn + # floats - v = Variable(['x', 'y'], [[3, 4, np.nan, 1]]) + v = Variable(["x", "y"], [[3, 4, np.nan, 1]]) expect_0 = bn.nanrankdata(v.data, axis=0) expect_1 = bn.nanrankdata(v.data, axis=1) - np.testing.assert_allclose(v.rank('x').values, expect_0) - np.testing.assert_allclose(v.rank('y').values, expect_1) + np.testing.assert_allclose(v.rank("x").values, expect_0) + np.testing.assert_allclose(v.rank("y").values, expect_1) # int - v = Variable(['x'], [3, 2, 1]) + v = Variable(["x"], [3, 2, 1]) expect = bn.rankdata(v.data, axis=0) - np.testing.assert_allclose(v.rank('x').values, expect) + np.testing.assert_allclose(v.rank("x").values, expect) # str - v = Variable(['x'], ['c', 'b', 'a']) + v = Variable(["x"], ["c", "b", "a"]) expect = bn.rankdata(v.data, axis=0) - np.testing.assert_allclose(v.rank('x').values, expect) + np.testing.assert_allclose(v.rank("x").values, expect) # pct - v = Variable(['x'], [3.0, 1.0, np.nan, 2.0, 4.0]) - v_expect = Variable(['x'], [0.75, 0.25, np.nan, 0.5, 1.0]) - assert_equal(v.rank('x', pct=True), v_expect) + v = Variable(["x"], [3.0, 1.0, np.nan, 2.0, 4.0]) + v_expect = Variable(["x"], [0.75, 0.25, np.nan, 0.5, 1.0]) + assert_equal(v.rank("x", pct=True), v_expect) # invalid dim - with raises_regex(ValueError, 'not found'): - v.rank('y') + with raises_regex(ValueError, "not found"): + v.rank("y") def test_big_endian_reduce(self): # regression test for GH489 - data = np.ones(5, dtype='>f4') - v = Variable(['x'], data) + data = np.ones(5, dtype=">f4") + v = Variable(["x"], data) expected = Variable([], 5) assert_identical(expected, v.sum()) def test_reduce_funcs(self): - v = Variable('x', np.array([1, np.nan, 2, 3])) + v = Variable("x", np.array([1, np.nan, 2, 3])) assert_identical(v.mean(), Variable([], 2)) assert_identical(v.mean(skipna=True), Variable([], 2)) assert_identical(v.mean(skipna=False), Variable([], np.nan)) assert_identical(np.mean(v), Variable([], 2)) assert_identical(v.prod(), Variable([], 6)) - assert_identical(v.cumsum(axis=0), - Variable('x', np.array([1, 1, 3, 6]))) - assert_identical(v.cumprod(axis=0), - Variable('x', np.array([1, 1, 2, 6]))) + assert_identical(v.cumsum(axis=0), Variable("x", np.array([1, 1, 3, 6]))) + assert_identical(v.cumprod(axis=0), Variable("x", np.array([1, 1, 2, 6]))) assert_identical(v.var(), Variable([], 2.0 / 3)) assert_identical(v.median(), Variable([], 2)) - v = Variable('x', [True, False, False]) + v = Variable("x", [True, False, False]) assert_identical(v.any(), Variable([], True)) - assert_identical(v.all(dim='x'), Variable([], False)) + 
assert_identical(v.all(dim="x"), Variable([], False)) - v = Variable('t', pd.date_range('2000-01-01', periods=3)) + v = Variable("t", pd.date_range("2000-01-01", periods=3)) assert v.argmax(skipna=True) == 2 - assert_identical( - v.max(), Variable([], pd.Timestamp('2000-01-03'))) + assert_identical(v.max(), Variable([], pd.Timestamp("2000-01-03"))) def test_reduce_keepdims(self): - v = Variable(['x', 'y'], self.d) - - assert_identical(v.mean(keepdims=True), - Variable(v.dims, np.mean(self.d, keepdims=True))) - assert_identical(v.mean(dim='x', keepdims=True), - Variable(v.dims, np.mean(self.d, axis=0, - keepdims=True))) - assert_identical(v.mean(dim='y', keepdims=True), - Variable(v.dims, np.mean(self.d, axis=1, - keepdims=True))) - assert_identical(v.mean(dim=['y', 'x'], keepdims=True), - Variable(v.dims, np.mean(self.d, axis=(1, 0), - keepdims=True))) + v = Variable(["x", "y"], self.d) + + assert_identical( + v.mean(keepdims=True), Variable(v.dims, np.mean(self.d, keepdims=True)) + ) + assert_identical( + v.mean(dim="x", keepdims=True), + Variable(v.dims, np.mean(self.d, axis=0, keepdims=True)), + ) + assert_identical( + v.mean(dim="y", keepdims=True), + Variable(v.dims, np.mean(self.d, axis=1, keepdims=True)), + ) + assert_identical( + v.mean(dim=["y", "x"], keepdims=True), + Variable(v.dims, np.mean(self.d, axis=(1, 0), keepdims=True)), + ) v = Variable([], 1.0) - assert_identical(v.mean(keepdims=True), - Variable([], np.mean(v.data, keepdims=True))) + assert_identical( + v.mean(keepdims=True), Variable([], np.mean(v.data, keepdims=True)) + ) @requires_dask def test_reduce_keepdims_dask(self): import dask.array - v = Variable(['x', 'y'], self.d).chunk() + + v = Variable(["x", "y"], self.d).chunk() actual = v.mean(keepdims=True) assert isinstance(actual.data, dask.array.Array) @@ -1571,16 +1594,16 @@ def test_reduce_keepdims_dask(self): expected = Variable(v.dims, np.mean(self.d, keepdims=True)) assert_identical(actual, expected) - actual = v.mean(dim='y', keepdims=True) + actual = v.mean(dim="y", keepdims=True) assert isinstance(actual.data, dask.array.Array) expected = Variable(v.dims, np.mean(self.d, axis=1, keepdims=True)) assert_identical(actual, expected) def test_reduce_keep_attrs(self): - _attrs = {'units': 'test', 'long_name': 'testing'} + _attrs = {"units": "test", "long_name": "testing"} - v = Variable(['x', 'y'], self.d, _attrs) + v = Variable(["x", "y"], self.d, _attrs) # Test dropped attrs vm = v.mean() @@ -1593,11 +1616,11 @@ def test_reduce_keep_attrs(self): assert vm.attrs == _attrs def test_binary_ops_keep_attrs(self): - _attrs = {'units': 'test', 'long_name': 'testing'} - a = Variable(['x', 'y'], np.random.randn(3, 3), _attrs) - b = Variable(['x', 'y'], np.random.randn(3, 3), _attrs) + _attrs = {"units": "test", "long_name": "testing"} + a = Variable(["x", "y"], np.random.randn(3, 3), _attrs) + b = Variable(["x", "y"], np.random.randn(3, 3), _attrs) # Test dropped attrs - d = a - b # just one operation + d = a - b # just one operation assert d.attrs == OrderedDict() # Test kept attrs with set_options(keep_attrs=True): @@ -1606,36 +1629,36 @@ def test_binary_ops_keep_attrs(self): def test_count(self): expected = Variable([], 3) - actual = Variable(['x'], [1, 2, 3, np.nan]).count() + actual = Variable(["x"], [1, 2, 3, np.nan]).count() assert_identical(expected, actual) - v = Variable(['x'], np.array(['1', '2', '3', np.nan], dtype=object)) + v = Variable(["x"], np.array(["1", "2", "3", np.nan], dtype=object)) actual = v.count() assert_identical(expected, actual) - actual 
= Variable(['x'], [True, False, True]).count() + actual = Variable(["x"], [True, False, True]).count() assert_identical(expected, actual) assert actual.dtype == int - expected = Variable(['x'], [2, 3]) - actual = Variable(['x', 'y'], [[1, 0, np.nan], [1, 1, 1]]).count('y') + expected = Variable(["x"], [2, 3]) + actual = Variable(["x", "y"], [[1, 0, np.nan], [1, 1, 1]]).count("y") assert_identical(expected, actual) def test_setitem(self): - v = Variable(['x', 'y'], [[0, 3, 2], [3, 4, 5]]) + v = Variable(["x", "y"], [[0, 3, 2], [3, 4, 5]]) v[0, 1] = 1 assert v[0, 1] == 1 - v = Variable(['x', 'y'], [[0, 3, 2], [3, 4, 5]]) + v = Variable(["x", "y"], [[0, 3, 2], [3, 4, 5]]) v[dict(x=[0, 1])] = 1 assert_array_equal(v[[0, 1]], np.ones_like(v[[0, 1]])) # boolean indexing - v = Variable(['x', 'y'], [[0, 3, 2], [3, 4, 5]]) + v = Variable(["x", "y"], [[0, 3, 2], [3, 4, 5]]) v[dict(x=[True, False])] = 1 assert_array_equal(v[0], np.ones_like(v[0])) - v = Variable(['x', 'y'], [[0, 3, 2], [3, 4, 5]]) + v = Variable(["x", "y"], [[0, 3, 2], [3, 4, 5]]) v[dict(x=[True, False], y=[False, True, False])] = 1 assert v[0, 1] == 1 @@ -1644,126 +1667,142 @@ def test_setitem_fancy(self): def assert_assigned_2d(array, key_x, key_y, values): expected = array.copy() expected[key_x, key_y] = values - v = Variable(['x', 'y'], array) + v = Variable(["x", "y"], array) v[dict(x=key_x, y=key_y)] = values assert_array_equal(expected, v) # 1d vectorized indexing - assert_assigned_2d(np.random.randn(4, 3), - key_x=Variable(['a'], [0, 1]), - key_y=Variable(['a'], [0, 1]), - values=0) - assert_assigned_2d(np.random.randn(4, 3), - key_x=Variable(['a'], [0, 1]), - key_y=Variable(['a'], [0, 1]), - values=Variable((), 0)) - assert_assigned_2d(np.random.randn(4, 3), - key_x=Variable(['a'], [0, 1]), - key_y=Variable(['a'], [0, 1]), - values=Variable(('a'), [3, 2])) - assert_assigned_2d(np.random.randn(4, 3), - key_x=slice(None), - key_y=Variable(['a'], [0, 1]), - values=Variable(('a'), [3, 2])) + assert_assigned_2d( + np.random.randn(4, 3), + key_x=Variable(["a"], [0, 1]), + key_y=Variable(["a"], [0, 1]), + values=0, + ) + assert_assigned_2d( + np.random.randn(4, 3), + key_x=Variable(["a"], [0, 1]), + key_y=Variable(["a"], [0, 1]), + values=Variable((), 0), + ) + assert_assigned_2d( + np.random.randn(4, 3), + key_x=Variable(["a"], [0, 1]), + key_y=Variable(["a"], [0, 1]), + values=Variable(("a"), [3, 2]), + ) + assert_assigned_2d( + np.random.randn(4, 3), + key_x=slice(None), + key_y=Variable(["a"], [0, 1]), + values=Variable(("a"), [3, 2]), + ) # 2d-vectorized indexing - assert_assigned_2d(np.random.randn(4, 3), - key_x=Variable(['a', 'b'], [[0, 1]]), - key_y=Variable(['a', 'b'], [[1, 0]]), - values=0) - assert_assigned_2d(np.random.randn(4, 3), - key_x=Variable(['a', 'b'], [[0, 1]]), - key_y=Variable(['a', 'b'], [[1, 0]]), - values=[0]) - assert_assigned_2d(np.random.randn(5, 4), - key_x=Variable(['a', 'b'], [[0, 1], [2, 3]]), - key_y=Variable(['a', 'b'], [[1, 0], [3, 3]]), - values=[2, 3]) + assert_assigned_2d( + np.random.randn(4, 3), + key_x=Variable(["a", "b"], [[0, 1]]), + key_y=Variable(["a", "b"], [[1, 0]]), + values=0, + ) + assert_assigned_2d( + np.random.randn(4, 3), + key_x=Variable(["a", "b"], [[0, 1]]), + key_y=Variable(["a", "b"], [[1, 0]]), + values=[0], + ) + assert_assigned_2d( + np.random.randn(5, 4), + key_x=Variable(["a", "b"], [[0, 1], [2, 3]]), + key_y=Variable(["a", "b"], [[1, 0], [3, 3]]), + values=[2, 3], + ) # vindex with slice - v = Variable(['x', 'y', 'z'], np.ones((4, 3, 2))) - ind = Variable(['a'], [0, 
1]) + v = Variable(["x", "y", "z"], np.ones((4, 3, 2))) + ind = Variable(["a"], [0, 1]) v[dict(x=ind, z=ind)] = 0 - expected = Variable(['x', 'y', 'z'], np.ones((4, 3, 2))) + expected = Variable(["x", "y", "z"], np.ones((4, 3, 2))) expected[0, :, 0] = 0 expected[1, :, 1] = 0 assert_identical(expected, v) # dimension broadcast - v = Variable(['x', 'y'], np.ones((3, 2))) - ind = Variable(['a', 'b'], [[0, 1]]) + v = Variable(["x", "y"], np.ones((3, 2))) + ind = Variable(["a", "b"], [[0, 1]]) v[ind, :] = 0 - expected = Variable(['x', 'y'], [[0, 0], [0, 0], [1, 1]]) + expected = Variable(["x", "y"], [[0, 0], [0, 0], [1, 1]]) assert_identical(expected, v) with raises_regex(ValueError, "shape mismatch"): v[ind, ind] = np.zeros((1, 2, 1)) - v = Variable(['x', 'y'], [[0, 3, 2], [3, 4, 5]]) - ind = Variable(['a'], [0, 1]) - v[dict(x=ind)] = Variable(['a', 'y'], np.ones((2, 3), dtype=int) * 10) + v = Variable(["x", "y"], [[0, 3, 2], [3, 4, 5]]) + ind = Variable(["a"], [0, 1]) + v[dict(x=ind)] = Variable(["a", "y"], np.ones((2, 3), dtype=int) * 10) assert_array_equal(v[0], np.ones_like(v[0]) * 10) assert_array_equal(v[1], np.ones_like(v[1]) * 10) - assert v.dims == ('x', 'y') # dimension should not change + assert v.dims == ("x", "y") # dimension should not change # increment - v = Variable(['x', 'y'], np.arange(6).reshape(3, 2)) - ind = Variable(['a'], [0, 1]) + v = Variable(["x", "y"], np.arange(6).reshape(3, 2)) + ind = Variable(["a"], [0, 1]) v[dict(x=ind)] += 1 - expected = Variable(['x', 'y'], [[1, 2], [3, 4], [4, 5]]) + expected = Variable(["x", "y"], [[1, 2], [3, 4], [4, 5]]) assert_identical(v, expected) - ind = Variable(['a'], [0, 0]) + ind = Variable(["a"], [0, 0]) v[dict(x=ind)] += 1 - expected = Variable(['x', 'y'], [[2, 3], [3, 4], [4, 5]]) + expected = Variable(["x", "y"], [[2, 3], [3, 4], [4, 5]]) assert_identical(v, expected) def test_coarsen(self): - v = self.cls(['x'], [0, 1, 2, 3, 4]) - actual = v.coarsen({'x': 2}, boundary='pad', func='mean') - expected = self.cls(['x'], [0.5, 2.5, 4]) + v = self.cls(["x"], [0, 1, 2, 3, 4]) + actual = v.coarsen({"x": 2}, boundary="pad", func="mean") + expected = self.cls(["x"], [0.5, 2.5, 4]) assert_identical(actual, expected) - actual = v.coarsen({'x': 2}, func='mean', boundary='pad', - side='right') - expected = self.cls(['x'], [0, 1.5, 3.5]) + actual = v.coarsen({"x": 2}, func="mean", boundary="pad", side="right") + expected = self.cls(["x"], [0, 1.5, 3.5]) assert_identical(actual, expected) - actual = v.coarsen({'x': 2}, func=np.mean, side='right', - boundary='trim') - expected = self.cls(['x'], [1.5, 3.5]) + actual = v.coarsen({"x": 2}, func=np.mean, side="right", boundary="trim") + expected = self.cls(["x"], [1.5, 3.5]) assert_identical(actual, expected) # working test - v = self.cls(['x', 'y', 'z'], - np.arange(40 * 30 * 2).reshape(40, 30, 2)) + v = self.cls(["x", "y", "z"], np.arange(40 * 30 * 2).reshape(40, 30, 2)) for windows, func, side, boundary in [ - ({'x': 2}, np.mean, 'left', 'trim'), - ({'x': 2}, np.median, {'x': 'left'}, 'pad'), - ({'x': 2, 'y': 3}, np.max, 'left', {'x': 'pad', 'y': 'trim'})]: + ({"x": 2}, np.mean, "left", "trim"), + ({"x": 2}, np.median, {"x": "left"}, "pad"), + ({"x": 2, "y": 3}, np.max, "left", {"x": "pad", "y": "trim"}), + ]: v.coarsen(windows, func, boundary, side) def test_coarsen_2d(self): # 2d-mean should be the same with the successive 1d-mean - v = self.cls(['x', 'y'], np.arange(6 * 12).reshape(6, 12)) - actual = v.coarsen({'x': 3, 'y': 4}, func='mean') - expected = v.coarsen({'x': 3}, 
func='mean').coarsen( - {'y': 4}, func='mean') + v = self.cls(["x", "y"], np.arange(6 * 12).reshape(6, 12)) + actual = v.coarsen({"x": 3, "y": 4}, func="mean") + expected = v.coarsen({"x": 3}, func="mean").coarsen({"y": 4}, func="mean") assert_equal(actual, expected) - v = self.cls(['x', 'y'], np.arange(7 * 12).reshape(7, 12)) - actual = v.coarsen({'x': 3, 'y': 4}, func='mean', boundary='trim') - expected = v.coarsen({'x': 3}, func='mean', boundary='trim').coarsen( - {'y': 4}, func='mean', boundary='trim') + v = self.cls(["x", "y"], np.arange(7 * 12).reshape(7, 12)) + actual = v.coarsen({"x": 3, "y": 4}, func="mean", boundary="trim") + expected = v.coarsen({"x": 3}, func="mean", boundary="trim").coarsen( + {"y": 4}, func="mean", boundary="trim" + ) assert_equal(actual, expected) # if there is nan, the two should be different - v = self.cls(['x', 'y'], 1.0 * np.arange(6 * 12).reshape(6, 12)) + v = self.cls(["x", "y"], 1.0 * np.arange(6 * 12).reshape(6, 12)) v[2, 4] = np.nan v[3, 5] = np.nan - actual = v.coarsen({'x': 3, 'y': 4}, func='mean', boundary='trim') - expected = v.coarsen({'x': 3}, func='sum', boundary='trim').coarsen( - {'y': 4}, func='sum', boundary='trim') / 12 + actual = v.coarsen({"x": 3, "y": 4}, func="mean", boundary="trim") + expected = ( + v.coarsen({"x": 3}, func="sum", boundary="trim").coarsen( + {"y": 4}, func="sum", boundary="trim" + ) + / 12 + ) assert not actual.equals(expected) # adjusting the nan count expected[0, 1] *= 12 / 11 @@ -1800,102 +1839,106 @@ def test_getitem_1d_fancy(self): def test_equals_all_dtypes(self): import dask - if '0.18.2' <= LooseVersion(dask.__version__) < '0.19.1': - pytest.xfail('https://github.com/pydata/xarray/issues/2318') + + if "0.18.2" <= LooseVersion(dask.__version__) < "0.19.1": + pytest.xfail("https://github.com/pydata/xarray/issues/2318") super().test_equals_all_dtypes() def test_getitem_with_mask_nd_indexer(self): import dask.array as da - v = Variable(['x'], da.arange(3, chunks=3)) - indexer = Variable(('x', 'y'), [[0, -1], [-1, 2]]) - assert_identical(v._getitem_with_mask(indexer, fill_value=-1), - self.cls(('x', 'y'), [[0, -1], [-1, 2]])) + + v = Variable(["x"], da.arange(3, chunks=3)) + indexer = Variable(("x", "y"), [[0, -1], [-1, 2]]) + assert_identical( + v._getitem_with_mask(indexer, fill_value=-1), + self.cls(("x", "y"), [[0, -1], [-1, 2]]), + ) class TestIndexVariable(VariableSubclassobjects): cls = staticmethod(IndexVariable) def test_init(self): - with raises_regex(ValueError, 'must be 1-dimensional'): + with raises_regex(ValueError, "must be 1-dimensional"): IndexVariable((), 0) def test_to_index(self): data = 0.5 * np.arange(10) - v = IndexVariable(['time'], data, {'foo': 'bar'}) - assert pd.Index(data, name='time').identical(v.to_index()) + v = IndexVariable(["time"], data, {"foo": "bar"}) + assert pd.Index(data, name="time").identical(v.to_index()) def test_multiindex_default_level_names(self): - midx = pd.MultiIndex.from_product([['a', 'b'], [1, 2]]) - v = IndexVariable(['x'], midx, {'foo': 'bar'}) - assert v.to_index().names == ('x_level_0', 'x_level_1') + midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]]) + v = IndexVariable(["x"], midx, {"foo": "bar"}) + assert v.to_index().names == ("x_level_0", "x_level_1") def test_data(self): - x = IndexVariable('x', np.arange(3.0)) + x = IndexVariable("x", np.arange(3.0)) assert isinstance(x._data, PandasIndexAdapter) assert isinstance(x.data, np.ndarray) assert float == x.dtype assert_array_equal(np.arange(3), x) assert float == x.values.dtype - 
with raises_regex(TypeError, 'cannot be modified'): + with raises_regex(TypeError, "cannot be modified"): x[:] = 0 def test_name(self): - coord = IndexVariable('x', [10.0]) - assert coord.name == 'x' + coord = IndexVariable("x", [10.0]) + assert coord.name == "x" with pytest.raises(AttributeError): - coord.name = 'y' + coord.name = "y" def test_level_names(self): - midx = pd.MultiIndex.from_product([['a', 'b'], [1, 2]], - names=['level_1', 'level_2']) - x = IndexVariable('x', midx) + midx = pd.MultiIndex.from_product( + [["a", "b"], [1, 2]], names=["level_1", "level_2"] + ) + x = IndexVariable("x", midx) assert x.level_names == midx.names - assert IndexVariable('y', [10.0]).level_names is None + assert IndexVariable("y", [10.0]).level_names is None def test_get_level_variable(self): - midx = pd.MultiIndex.from_product([['a', 'b'], [1, 2]], - names=['level_1', 'level_2']) - x = IndexVariable('x', midx) - level_1 = IndexVariable('x', midx.get_level_values('level_1')) - assert_identical(x.get_level_variable('level_1'), level_1) + midx = pd.MultiIndex.from_product( + [["a", "b"], [1, 2]], names=["level_1", "level_2"] + ) + x = IndexVariable("x", midx) + level_1 = IndexVariable("x", midx.get_level_values("level_1")) + assert_identical(x.get_level_variable("level_1"), level_1) - with raises_regex(ValueError, 'has no MultiIndex'): - IndexVariable('y', [10.0]).get_level_variable('level') + with raises_regex(ValueError, "has no MultiIndex"): + IndexVariable("y", [10.0]).get_level_variable("level") def test_concat_periods(self): - periods = pd.period_range('2000-01-01', periods=10) - coords = [IndexVariable('t', periods[:5]), - IndexVariable('t', periods[5:])] - expected = IndexVariable('t', periods) - actual = IndexVariable.concat(coords, dim='t') + periods = pd.period_range("2000-01-01", periods=10) + coords = [IndexVariable("t", periods[:5]), IndexVariable("t", periods[5:])] + expected = IndexVariable("t", periods) + actual = IndexVariable.concat(coords, dim="t") assert actual.identical(expected) assert isinstance(actual.to_index(), pd.PeriodIndex) positions = [list(range(5)), list(range(5, 10))] - actual = IndexVariable.concat(coords, dim='t', positions=positions) + actual = IndexVariable.concat(coords, dim="t", positions=positions) assert actual.identical(expected) assert isinstance(actual.to_index(), pd.PeriodIndex) def test_concat_multiindex(self): - idx = pd.MultiIndex.from_product([[0, 1, 2], ['a', 'b']]) - coords = [IndexVariable('x', idx[:2]), IndexVariable('x', idx[2:])] - expected = IndexVariable('x', idx) - actual = IndexVariable.concat(coords, dim='x') + idx = pd.MultiIndex.from_product([[0, 1, 2], ["a", "b"]]) + coords = [IndexVariable("x", idx[:2]), IndexVariable("x", idx[2:])] + expected = IndexVariable("x", idx) + actual = IndexVariable.concat(coords, dim="x") assert actual.identical(expected) assert isinstance(actual.to_index(), pd.MultiIndex) def test_coordinate_alias(self): - with pytest.warns(Warning, match='deprecated'): - x = Coordinate('x', [1, 2, 3]) + with pytest.warns(Warning, match="deprecated"): + x = Coordinate("x", [1, 2, 3]) assert isinstance(x, IndexVariable) def test_datetime64(self): # GH:1932 Make sure indexing keeps precision - t = np.array([1518418799999986560, 1518418799999996560], - dtype='datetime64[ns]') - v = IndexVariable('t', t) + t = np.array([1518418799999986560, 1518418799999996560], dtype="datetime64[ns]") + v = IndexVariable("t", t) assert v[0].data == t[0] # These tests make use of multi-dimensional variables, which are not valid @@ -1933,12 
+1976,13 @@ class TestAsCompatibleData: def test_unchanged_types(self): types = (np.asarray, PandasIndexAdapter, LazilyOuterIndexedArray) for t in types: - for data in [np.arange(3), - pd.date_range('2000-01-01', periods=3), - pd.date_range('2000-01-01', periods=3).values]: + for data in [ + np.arange(3), + pd.date_range("2000-01-01", periods=3), + pd.date_range("2000-01-01", periods=3).values, + ]: x = t(data) - assert source_ndarray(x) is \ - source_ndarray(as_compatible_data(x)) + assert source_ndarray(x) is source_ndarray(as_compatible_data(x)) def test_converted_types(self): for input_array in [[[0, 1, 2]], pd.DataFrame([[0, 1, 2]])]: @@ -1962,35 +2006,36 @@ def test_masked_array(self): assert np.dtype(float) == actual.dtype def test_datetime(self): - expected = np.datetime64('2000-01-01') + expected = np.datetime64("2000-01-01") actual = as_compatible_data(expected) assert expected == actual assert np.ndarray == type(actual) - assert np.dtype('datetime64[ns]') == actual.dtype + assert np.dtype("datetime64[ns]") == actual.dtype - expected = np.array([np.datetime64('2000-01-01')]) + expected = np.array([np.datetime64("2000-01-01")]) actual = as_compatible_data(expected) assert np.asarray(expected) == actual assert np.ndarray == type(actual) - assert np.dtype('datetime64[ns]') == actual.dtype + assert np.dtype("datetime64[ns]") == actual.dtype - expected = np.array([np.datetime64('2000-01-01', 'ns')]) + expected = np.array([np.datetime64("2000-01-01", "ns")]) actual = as_compatible_data(expected) assert np.asarray(expected) == actual assert np.ndarray == type(actual) - assert np.dtype('datetime64[ns]') == actual.dtype + assert np.dtype("datetime64[ns]") == actual.dtype assert expected is source_ndarray(np.asarray(actual)) - expected = np.datetime64('2000-01-01', 'ns') + expected = np.datetime64("2000-01-01", "ns") actual = as_compatible_data(datetime(2000, 1, 1)) assert np.asarray(expected) == actual assert np.ndarray == type(actual) - assert np.dtype('datetime64[ns]') == actual.dtype + assert np.dtype("datetime64[ns]") == actual.dtype def test_full_like(self): # For more thorough tests, see test_variable.py - orig = Variable(dims=('x', 'y'), data=[[1.5, 2.0], [3.1, 4.3]], - attrs={'foo': 'bar'}) + orig = Variable( + dims=("x", "y"), data=[[1.5, 2.0], [3.1, 4.3]], attrs={"foo": "bar"} + ) expect = orig.copy(deep=True) expect.values = [[2.0, 2.0], [2.0, 2.0]] @@ -2003,8 +2048,9 @@ def test_full_like(self): @requires_dask def test_full_like_dask(self): - orig = Variable(dims=('x', 'y'), data=[[1.5, 2.0], [3.1, 4.3]], - attrs={'foo': 'bar'}).chunk(((1, 1), (2,))) + orig = Variable( + dims=("x", "y"), data=[[1.5, 2.0], [3.1, 4.3]], attrs={"foo": "bar"} + ).chunk(((1, 1), (2,))) def check(actual, expect_dtype, expect_values): assert actual.dtype == expect_dtype @@ -2014,11 +2060,13 @@ def check(actual, expect_dtype, expect_values): assert actual.chunks == orig.chunks assert_array_equal(actual.values, expect_values) - check(full_like(orig, 2), - orig.dtype, np.full_like(orig.values, 2)) + check(full_like(orig, 2), orig.dtype, np.full_like(orig.values, 2)) # override dtype - check(full_like(orig, True, dtype=bool), - bool, np.full_like(orig.values, True, dtype=bool)) + check( + full_like(orig, True, dtype=bool), + bool, + np.full_like(orig.values, True, dtype=bool), + ) # Check that there's no array stored inside dask # (e.g. we didn't create a numpy array and then we chunked it!) 
@@ -2031,20 +2079,18 @@ def check(actual, expect_dtype, expect_values): assert not isinstance(v, np.ndarray) def test_zeros_like(self): - orig = Variable(dims=('x', 'y'), data=[[1.5, 2.0], [3.1, 4.3]], - attrs={'foo': 'bar'}) - assert_identical(zeros_like(orig), - full_like(orig, 0)) - assert_identical(zeros_like(orig, dtype=int), - full_like(orig, 0, dtype=int)) + orig = Variable( + dims=("x", "y"), data=[[1.5, 2.0], [3.1, 4.3]], attrs={"foo": "bar"} + ) + assert_identical(zeros_like(orig), full_like(orig, 0)) + assert_identical(zeros_like(orig, dtype=int), full_like(orig, 0, dtype=int)) def test_ones_like(self): - orig = Variable(dims=('x', 'y'), data=[[1.5, 2.0], [3.1, 4.3]], - attrs={'foo': 'bar'}) - assert_identical(ones_like(orig), - full_like(orig, 1)) - assert_identical(ones_like(orig, dtype=int), - full_like(orig, 1, dtype=int)) + orig = Variable( + dims=("x", "y"), data=[[1.5, 2.0], [3.1, 4.3]], attrs={"foo": "bar"} + ) + assert_identical(ones_like(orig), full_like(orig, 1)) + assert_identical(ones_like(orig, dtype=int), full_like(orig, 1, dtype=int)) def test_unsupported_type(self): # Non indexable type @@ -2056,17 +2102,17 @@ class CustomIndexable(CustomArray, indexing.ExplicitlyIndexed): pass array = CustomArray(np.arange(3)) - orig = Variable(dims=('x'), data=array, attrs={'foo': 'bar'}) + orig = Variable(dims=("x"), data=array, attrs={"foo": "bar"}) assert isinstance(orig._data, np.ndarray) # should not be CustomArray array = CustomIndexable(np.arange(3)) - orig = Variable(dims=('x'), data=array, attrs={'foo': 'bar'}) + orig = Variable(dims=("x"), data=array, attrs={"foo": "bar"}) assert isinstance(orig._data, CustomIndexable) def test_raise_no_warning_for_nan_in_binary_ops(): with pytest.warns(None) as record: - Variable('x', [1, 2, np.NaN]) > 0 + Variable("x", [1, 2, np.NaN]) > 0 assert len(record) == 0 @@ -2078,68 +2124,68 @@ def setUp(self): self.d = np.random.random((10, 3)).astype(np.float64) def check_orthogonal_indexing(self, v): - assert np.allclose(v.isel(x=[8, 3], y=[2, 1]), - self.d[[8, 3]][:, [2, 1]]) + assert np.allclose(v.isel(x=[8, 3], y=[2, 1]), self.d[[8, 3]][:, [2, 1]]) def check_vectorized_indexing(self, v): - ind_x = Variable('z', [0, 2]) - ind_y = Variable('z', [2, 1]) + ind_x = Variable("z", [0, 2]) + ind_y = Variable("z", [2, 1]) assert np.allclose(v.isel(x=ind_x, y=ind_y), self.d[ind_x, ind_y]) def test_NumpyIndexingAdapter(self): - v = Variable(dims=('x', 'y'), data=NumpyIndexingAdapter(self.d)) + v = Variable(dims=("x", "y"), data=NumpyIndexingAdapter(self.d)) self.check_orthogonal_indexing(v) self.check_vectorized_indexing(v) # could not doubly wrapping - with raises_regex(TypeError, 'NumpyIndexingAdapter only wraps '): - v = Variable(dims=('x', 'y'), data=NumpyIndexingAdapter( - NumpyIndexingAdapter(self.d))) + with raises_regex(TypeError, "NumpyIndexingAdapter only wraps "): + v = Variable( + dims=("x", "y"), data=NumpyIndexingAdapter(NumpyIndexingAdapter(self.d)) + ) def test_LazilyOuterIndexedArray(self): - v = Variable(dims=('x', 'y'), data=LazilyOuterIndexedArray(self.d)) + v = Variable(dims=("x", "y"), data=LazilyOuterIndexedArray(self.d)) self.check_orthogonal_indexing(v) self.check_vectorized_indexing(v) # doubly wrapping v = Variable( - dims=('x', 'y'), - data=LazilyOuterIndexedArray(LazilyOuterIndexedArray(self.d))) + dims=("x", "y"), + data=LazilyOuterIndexedArray(LazilyOuterIndexedArray(self.d)), + ) self.check_orthogonal_indexing(v) # hierarchical wrapping v = Variable( - dims=('x', 'y'), - 
data=LazilyOuterIndexedArray(NumpyIndexingAdapter(self.d))) + dims=("x", "y"), data=LazilyOuterIndexedArray(NumpyIndexingAdapter(self.d)) + ) self.check_orthogonal_indexing(v) def test_CopyOnWriteArray(self): - v = Variable(dims=('x', 'y'), data=CopyOnWriteArray(self.d)) + v = Variable(dims=("x", "y"), data=CopyOnWriteArray(self.d)) self.check_orthogonal_indexing(v) self.check_vectorized_indexing(v) # doubly wrapping v = Variable( - dims=('x', 'y'), - data=CopyOnWriteArray(LazilyOuterIndexedArray(self.d))) + dims=("x", "y"), data=CopyOnWriteArray(LazilyOuterIndexedArray(self.d)) + ) self.check_orthogonal_indexing(v) self.check_vectorized_indexing(v) def test_MemoryCachedArray(self): - v = Variable(dims=('x', 'y'), data=MemoryCachedArray(self.d)) + v = Variable(dims=("x", "y"), data=MemoryCachedArray(self.d)) self.check_orthogonal_indexing(v) self.check_vectorized_indexing(v) # doubly wrapping - v = Variable(dims=('x', 'y'), - data=CopyOnWriteArray(MemoryCachedArray(self.d))) + v = Variable(dims=("x", "y"), data=CopyOnWriteArray(MemoryCachedArray(self.d))) self.check_orthogonal_indexing(v) self.check_vectorized_indexing(v) @requires_dask def test_DaskIndexingAdapter(self): import dask.array as da + da = da.asarray(self.d) - v = Variable(dims=('x', 'y'), data=DaskIndexingAdapter(da)) + v = Variable(dims=("x", "y"), data=DaskIndexingAdapter(da)) self.check_orthogonal_indexing(v) self.check_vectorized_indexing(v) # doubly wrapping - v = Variable(dims=('x', 'y'), - data=CopyOnWriteArray(DaskIndexingAdapter(da))) + v = Variable(dims=("x", "y"), data=CopyOnWriteArray(DaskIndexingAdapter(da))) self.check_orthogonal_indexing(v) self.check_vectorized_indexing(v) diff --git a/xarray/tutorial.py b/xarray/tutorial.py index 6056bb8b9ae..88ca8d3ab4f 100644 --- a/xarray/tutorial.py +++ b/xarray/tutorial.py @@ -1,10 +1,10 @@ -''' +""" Useful for: * users learning xarray * building tutorials in the documentation. -''' +""" import hashlib import os as _os from urllib.request import urlretrieve @@ -15,7 +15,7 @@ from .core.dataarray import DataArray from .core.dataset import Dataset -_default_cache_dir = _os.sep.join(('~', '.xarray_tutorial_data')) +_default_cache_dir = _os.sep.join(("~", ".xarray_tutorial_data")) def file_md5_checksum(fname): @@ -26,9 +26,14 @@ def file_md5_checksum(fname): # idea borrowed from Seaborn -def open_dataset(name, cache=True, cache_dir=_default_cache_dir, - github_url='https://github.com/pydata/xarray-data', - branch='master', **kws): +def open_dataset( + name, + cache=True, + cache_dir=_default_cache_dir, + github_url="https://github.com/pydata/xarray-data", + branch="master", + **kws +): """ Open a dataset from the online repository (requires internet). 
@@ -56,9 +61,9 @@ def open_dataset(name, cache=True, cache_dir=_default_cache_dir, """ longdir = _os.path.expanduser(cache_dir) - fullname = name + '.nc' + fullname = name + ".nc" localfile = _os.sep.join((longdir, fullname)) - md5name = name + '.md5' + md5name = name + ".md5" md5file = _os.sep.join((longdir, md5name)) if not _os.path.exists(localfile): @@ -68,13 +73,13 @@ def open_dataset(name, cache=True, cache_dir=_default_cache_dir, if not _os.path.isdir(longdir): _os.mkdir(longdir) - url = '/'.join((github_url, 'raw', branch, fullname)) + url = "/".join((github_url, "raw", branch, fullname)) urlretrieve(url, localfile) - url = '/'.join((github_url, 'raw', branch, md5name)) + url = "/".join((github_url, "raw", branch, md5name)) urlretrieve(url, md5file) localmd5 = file_md5_checksum(localfile) - with open(md5file, 'r') as f: + with open(md5file, "r") as f: remotemd5 = f.read() if localmd5 != remotemd5: _os.remove(localfile) @@ -106,23 +111,27 @@ def load_dataset(*args, **kwargs): def scatter_example_dataset(): - A = DataArray(np.zeros([3, 11, 4, 4]), - dims=['x', 'y', 'z', 'w'], - coords=[np.arange(3), - np.linspace(0, 1, 11), - np.arange(4), - 0.1 * np.random.randn(4)]) - B = 0.1 * A.x**2 + A.y**2.5 + 0.1 * A.z * A.w + A = DataArray( + np.zeros([3, 11, 4, 4]), + dims=["x", "y", "z", "w"], + coords=[ + np.arange(3), + np.linspace(0, 1, 11), + np.arange(4), + 0.1 * np.random.randn(4), + ], + ) + B = 0.1 * A.x ** 2 + A.y ** 2.5 + 0.1 * A.z * A.w A = -0.1 * A.x + A.y / (5 + A.z) + A.w - ds = Dataset({'A': A, 'B': B}) - ds['w'] = ['one', 'two', 'three', 'five'] + ds = Dataset({"A": A, "B": B}) + ds["w"] = ["one", "two", "three", "five"] - ds.x.attrs['units'] = 'xunits' - ds.y.attrs['units'] = 'yunits' - ds.z.attrs['units'] = 'zunits' - ds.w.attrs['units'] = 'wunits' + ds.x.attrs["units"] = "xunits" + ds.y.attrs["units"] = "yunits" + ds.z.attrs["units"] = "zunits" + ds.w.attrs["units"] = "wunits" - ds.A.attrs['units'] = 'Aunits' - ds.B.attrs['units'] = 'Bunits' + ds.A.attrs["units"] = "Aunits" + ds.B.attrs["units"] = "Bunits" return ds diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py index c4261d465e9..7b9ca1878f7 100644 --- a/xarray/ufuncs.py +++ b/xarray/ufuncs.py @@ -42,18 +42,21 @@ def __init__(self, name): self._name = name def __call__(self, *args, **kwargs): - if self._name not in ['angle', 'iscomplex']: + if self._name not in ["angle", "iscomplex"]: _warnings.warn( - 'xarray.ufuncs will be deprecated when xarray no longer ' - 'supports versions of numpy older than v1.17. Instead, use ' - 'numpy ufuncs directly.', - PendingDeprecationWarning, stacklevel=2) + "xarray.ufuncs will be deprecated when xarray no longer " + "supports versions of numpy older than v1.17. 
Instead, use " + "numpy ufuncs directly.", + PendingDeprecationWarning, + stacklevel=2, + ) new_args = args f = _dask_or_eager_func(self._name, array_args=slice(len(args))) if len(args) > 2 or len(args) == 0: - raise TypeError('cannot handle %s arguments for %r' % - (len(args), self._name)) + raise TypeError( + "cannot handle %s arguments for %r" % (len(args), self._name) + ) elif len(args) == 1: if isinstance(args[0], _xarray_types): f = args[0]._unary_op(self) @@ -68,8 +71,10 @@ def __call__(self, *args, **kwargs): new_args = tuple(reversed(args)) res = f(*new_args, **kwargs) if res is NotImplemented: - raise TypeError('%r not implemented for types (%r, %r)' - % (self._name, type(args[0]), type(args[1]))) + raise TypeError( + "%r not implemented for types (%r, %r)" + % (self._name, type(args[0]), type(args[1])) + ) return res @@ -77,11 +82,13 @@ def _create_op(name): func = _UFuncDispatcher(name) func.__name__ = name doc = getattr(_np, name).__doc__ - func.__doc__ = ('xarray specific variant of numpy.%s. Handles ' - 'xarray.Dataset, xarray.DataArray, xarray.Variable, ' - 'numpy.ndarray and dask.array.Array objects with ' - 'automatic dispatching.\n\n' - 'Documentation from numpy:\n\n%s' % (name, doc)) + func.__doc__ = ( + "xarray specific variant of numpy.%s. Handles " + "xarray.Dataset, xarray.DataArray, xarray.Variable, " + "numpy.ndarray and dask.array.Array objects with " + "automatic dispatching.\n\n" + "Documentation from numpy:\n\n%s" % (name, doc) + ) return func diff --git a/xarray/util/print_versions.py b/xarray/util/print_versions.py index d9e8c6a27bb..85bb9db8360 100755 --- a/xarray/util/print_versions.py +++ b/xarray/util/print_versions.py @@ -18,9 +18,11 @@ def get_sys_info(): commit = None if os.path.isdir(".git") and os.path.isdir("xarray"): try: - pipe = subprocess.Popen('git log --format="%H" -n 1'.split(" "), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + pipe = subprocess.Popen( + 'git log --format="%H" -n 1'.split(" "), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) so, serr = pipe.communicate() except Exception: pass @@ -28,30 +30,30 @@ def get_sys_info(): if pipe.returncode == 0: commit = so try: - commit = so.decode('utf-8') + commit = so.decode("utf-8") except ValueError: pass commit = commit.strip().strip('"') - blob.append(('commit', commit)) + blob.append(("commit", commit)) try: - (sysname, nodename, release, - version, machine, processor) = platform.uname() - blob.extend([ - ("python", sys.version), - ("python-bits", struct.calcsize("P") * 8), - ("OS", "%s" % (sysname)), - ("OS-release", "%s" % (release)), - # ("Version", "%s" % (version)), - ("machine", "%s" % (machine)), - ("processor", "%s" % (processor)), - ("byteorder", "%s" % sys.byteorder), - ("LC_ALL", "%s" % os.environ.get('LC_ALL', "None")), - ("LANG", "%s" % os.environ.get('LANG', "None")), - ("LOCALE", "%s.%s" % locale.getlocale()), - - ]) + (sysname, nodename, release, version, machine, processor) = platform.uname() + blob.extend( + [ + ("python", sys.version), + ("python-bits", struct.calcsize("P") * 8), + ("OS", "%s" % (sysname)), + ("OS-release", "%s" % (release)), + # ("Version", "%s" % (version)), + ("machine", "%s" % (machine)), + ("processor", "%s" % (processor)), + ("byteorder", "%s" % sys.byteorder), + ("LC_ALL", "%s" % os.environ.get("LC_ALL", "None")), + ("LANG", "%s" % os.environ.get("LANG", "None")), + ("LOCALE", "%s.%s" % locale.getlocale()), + ] + ) except Exception: pass @@ -63,15 +65,17 @@ def netcdf_and_hdf5_versions(): libnetcdf_version = None try: import 
netCDF4 + libhdf5_version = netCDF4.__hdf5libversion__ libnetcdf_version = netCDF4.__netcdf4libversion__ except ImportError: try: import h5py + libhdf5_version = h5py.version.hdf5_version except ImportError: pass - return [('libhdf5', libhdf5_version), ('libnetcdf', libnetcdf_version)] + return [("libhdf5", libhdf5_version), ("libnetcdf", libnetcdf_version)] def show_versions(file=sys.stdout): @@ -132,7 +136,7 @@ def show_versions(file=sys.stdout): ver = ver_f(mod) deps_blob.append((modname, ver)) except Exception: - deps_blob.append((modname, 'installed')) + deps_blob.append((modname, "installed")) print("\nINSTALLED VERSIONS", file=file) print("------------------", file=file) @@ -145,5 +149,5 @@ def show_versions(file=sys.stdout): print("%s: %s" % (k, stat), file=file) -if __name__ == '__main__': +if __name__ == "__main__": show_versions()