From b87b684b36cf5adbe4dca208aed0c69c44fc44c4 Mon Sep 17 00:00:00 2001 From: Brian Rose Date: Tue, 14 Aug 2018 14:44:21 -0400 Subject: [PATCH 01/51] Fix spelling -- change recieved to received (#2367) --- xarray/backends/api.py | 2 +- xarray/core/missing.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index b2c0df7b01b..2bf13011bd1 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -806,7 +806,7 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, for obj in datasets: if not isinstance(obj, Dataset): raise TypeError('save_mfdataset only supports writing Dataset ' - 'objects, recieved type %s' % type(obj)) + 'objects, received type %s' % type(obj)) if groups is None: groups = [None] * len(datasets) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index bec9e2e1931..232fa185c07 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -57,7 +57,7 @@ def __init__(self, xi, yi, method='linear', fill_value=None, **kwargs): if self.cons_kwargs: raise ValueError( - 'recieved invalid kwargs: %r' % self.cons_kwargs.keys()) + 'received invalid kwargs: %r' % self.cons_kwargs.keys()) if fill_value is None: self._left = np.nan From c27ca436321654a97e776aa0d055dfef357bc5a8 Mon Sep 17 00:00:00 2001 From: Maximilian Maahn Date: Tue, 14 Aug 2018 18:18:27 -0600 Subject: [PATCH 02/51] Faster unstack (#2364) * Make dataset.unstack faster by skipping reindex if not necessary. 
* Remove prints, add comment * added asv benchmark for unstacking * Added test * Simplified test * Added whats-new entry * PEP8 * Made asv test faster --- asv_bench/benchmarks/unstacking.py | 25 +++++++++++++++++++++++++ doc/whats-new.rst | 4 ++++ xarray/core/dataset.py | 9 +++++++-- xarray/tests/test_dataset.py | 15 ++++++++++++++- 4 files changed, 50 insertions(+), 3 deletions(-) create mode 100644 asv_bench/benchmarks/unstacking.py diff --git a/asv_bench/benchmarks/unstacking.py b/asv_bench/benchmarks/unstacking.py new file mode 100644 index 00000000000..aa304d4eb40 --- /dev/null +++ b/asv_bench/benchmarks/unstacking.py @@ -0,0 +1,25 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np +import xarray as xr + +from . import requires_dask + + +class Unstacking(object): + def setup(self): + data = np.random.RandomState(0).randn(1, 1000, 500) + self.ds = xr.DataArray(data).stack(flat_dim=['dim_1', 'dim_2']) + + def time_unstack_fast(self): + self.ds.unstack('flat_dim') + + def time_unstack_slow(self): + self.ds[:, ::-1].unstack('flat_dim') + + +class UnstackingDask(Unstacking): + def setup(self, *args, **kwargs): + requires_dask() + super(UnstackingDask, self).setup(**kwargs) + self.ds = self.ds.chunk({'flat_dim': 50}) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index bf9536a5fc7..4552a4ca546 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -52,6 +52,10 @@ Enhancements (:issue:`2331`) By `Maximilian Roos `_. +- Applying ``unstack`` to a large DataArray or Dataset is now much faster if the MultiIndex has not been modified after stacking the indices. + (:issue:`1560`) + By `Maximilian Maahn `_. 
+ Bug fixes ~~~~~~~~~ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4b52178ad0e..e6bc2f8aeaf 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2324,8 +2324,13 @@ def unstack(self, dim): 'a MultiIndex') full_idx = pd.MultiIndex.from_product(index.levels, names=index.names) - obj = self.reindex(copy=False, **{dim: full_idx}) - + + # take a shortcut in case the MultiIndex was not modified. + if index.equals(full_idx): + obj = self + else: + obj = self.reindex(copy=False, **{dim: full_idx}) + new_dim_names = index.names new_dim_sizes = [lev.size for lev in index.levels] diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 08d71d462d8..c67183db1ec 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2113,7 +2113,7 @@ def test_unstack_errors(self): with raises_regex(ValueError, 'does not have a MultiIndex'): ds.unstack('x') - def test_stack_unstack(self): + def test_stack_unstack_fast(self): ds = Dataset({'a': ('x', [0, 1]), 'b': (('x', 'y'), [[0, 1], [2, 3]]), 'x': [0, 1], @@ -2124,6 +2124,19 @@ def test_stack_unstack(self): actual = ds[['b']].stack(z=['x', 'y']).unstack('z') assert actual.identical(ds[['b']]) + def test_stack_unstack_slow(self): + ds = Dataset({'a': ('x', [0, 1]), + 'b': (('x', 'y'), [[0, 1], [2, 3]]), + 'x': [0, 1], + 'y': ['a', 'b']}) + stacked = ds.stack(z=['x', 'y']) + actual = stacked.isel(z=slice(None, None, -1)).unstack('z') + assert actual.broadcast_equals(ds) + + stacked = ds[['b']].stack(z=['x', 'y']) + actual = stacked.isel(z=slice(None, None, -1)).unstack('z') + assert actual.identical(ds[['b']]) + def test_update(self): data = create_test_data(seed=0) expected = data.copy() From cbb2aeb6492ad5364694396fb10e3b86abfe0aa6 Mon Sep 17 00:00:00 2001 From: Andrew Huang Date: Wed, 15 Aug 2018 01:11:28 -0700 Subject: [PATCH 03/51] Add option to not roll coords (#2360) * Add option to not roll coords * Rename keyword arg and add tests * Add what's new * Fix 
passing None and add more tests * Revise from comments * Revise with cleaner version * Revisions based on comments * Fix either_dict_or_kwargs * Revisions from comments --- doc/whats-new.rst | 5 ++++ xarray/core/dataarray.py | 15 ++++++++---- xarray/core/dataset.py | 44 ++++++++++++++++++++++++---------- xarray/tests/test_dataarray.py | 19 +++++++++++++-- xarray/tests/test_dataset.py | 30 ++++++++++++++++++++--- 5 files changed, 92 insertions(+), 21 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4552a4ca546..47d39b967e3 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -56,6 +56,11 @@ Enhancements (:issue:`1560`) By `Maximilian Maahn `_. +- You can now control whether or not to offset the coordinates when using + the ``roll`` method. The current behavior (coordinates rolled by default) + raises a deprecation warning unless the keyword argument is set explicitly. + (:issue:`1875`) + By `Andrew Huang `_. Bug fixes ~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index f215bc47df8..b1be994416e 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2015,14 +2015,20 @@ def shift(self, **shifts): variable = self.variable.shift(**shifts) return self._replace(variable) - def roll(self, **shifts): + def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): """Roll this array by an offset along one or more dimensions. - Unlike shift, roll rotates all variables, including coordinates. The - direction of rotation is consistent with :py:func:`numpy.roll`. + Unlike shift, roll may rotate all variables, including coordinates + if specified. The direction of rotation is consistent with + :py:func:`numpy.roll`. Parameters ---------- + roll_coords : bool + Indicates whether to roll the coordinates by the offset + The current default of roll_coords (None, equivalent to True) is + deprecated and will change to False in a future version. + Explicitly pass roll_coords to silence the warning. 
**shifts : keyword arguments of the form {dim: offset} Integer offset to rotate each of the given dimensions. Positive offsets roll to the right; negative offsets roll to the left. @@ -2046,7 +2052,8 @@ def roll(self, **shifts): Coordinates: * x (x) int64 2 0 1 """ - ds = self._to_temp_dataset().roll(**shifts) + ds = self._to_temp_dataset().roll( + shifts=shifts, roll_coords=roll_coords, **shifts_kwargs) return self._from_temp_dataset(ds) @property diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e6bc2f8aeaf..37544aca372 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2324,13 +2324,13 @@ def unstack(self, dim): 'a MultiIndex') full_idx = pd.MultiIndex.from_product(index.levels, names=index.names) - + # take a shortcut in case the MultiIndex was not modified. if index.equals(full_idx): obj = self else: obj = self.reindex(copy=False, **{dim: full_idx}) - + new_dim_names = index.names new_dim_sizes = [lev.size for lev in index.levels] @@ -3360,18 +3360,28 @@ def shift(self, **shifts): return self._replace_vars_and_dims(variables) - def roll(self, **shifts): + def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): """Roll this dataset by an offset along one or more dimensions. - Unlike shift, roll rotates all variables, including coordinates. The - direction of rotation is consistent with :py:func:`numpy.roll`. + Unlike shift, roll may rotate all variables, including coordinates + if specified. The direction of rotation is consistent with + :py:func:`numpy.roll`. Parameters ---------- - **shifts : keyword arguments of the form {dim: offset} - Integer offset to rotate each of the given dimensions. Positive - offsets roll to the right; negative offsets roll to the left. + shifts : dict, optional + A dict with keys matching dimensions and values given + by integers to rotate each of the given dimensions. Positive + offsets roll to the right; negative offsets roll to the left. 
+ roll_coords : bool + Indicates whether to roll the coordinates by the offset + The current default of roll_coords (None, equivalent to True) is + deprecated and will change to False in a future version. + Explicitly pass roll_coords to silence the warning. + **shifts_kwargs : {dim: offset, ...}, optional + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwargs must be provided. Returns ------- rolled : Dataset @@ -3394,15 +3404,25 @@ def roll(self, **shifts): Data variables: foo (x) object 'd' 'e' 'a' 'b' 'c' """ + shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'roll') invalid = [k for k in shifts if k not in self.dims] if invalid: raise ValueError("dimensions %r do not exist" % invalid) + if roll_coords is None: + warnings.warn("roll_coords will be set to False in the future." + " Explicitly set roll_coords to silence warning.", + FutureWarning, stacklevel=2) + roll_coords = True + + unrolled_vars = () if roll_coords else self.coords + variables = OrderedDict() - for name, var in iteritems(self.variables): - var_shifts = dict((k, v) for k, v in shifts.items() - if k in var.dims) - variables[name] = var.roll(**var_shifts) + for k, v in iteritems(self.variables): + if k not in unrolled_vars: + variables[k] = v.roll(**shifts) + else: + variables[k] = v return self._replace_vars_and_dims(variables) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 2950e97cc75..1a115192fb4 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3099,9 +3099,24 @@ def test_shift(self): actual = arr.shift(x=offset) assert_identical(expected, actual) - def test_roll(self): + def test_roll_coords(self): arr = DataArray([1, 2, 3], coords={'x': range(3)}, dims='x') - actual = arr.roll(x=1) + actual = arr.roll(x=1, roll_coords=True) + expected = DataArray([3, 1, 2], coords=[('x', [2, 0, 1])]) + assert_identical(expected, actual) + + def test_roll_no_coords(self): + arr = DataArray([1, 2, 3], 
coords={'x': range(3)}, dims='x') + actual = arr.roll(x=1, roll_coords=False) + expected = DataArray([3, 1, 2], coords=[('x', [0, 1, 2])]) + assert_identical(expected, actual) + + def test_roll_coords_none(self): + arr = DataArray([1, 2, 3], coords={'x': range(3)}, dims='x') + + with pytest.warns(FutureWarning): + actual = arr.roll(x=1, roll_coords=None) + expected = DataArray([3, 1, 2], coords=[('x', [2, 0, 1])]) assert_identical(expected, actual) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index c67183db1ec..e2a406b1e51 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -3862,18 +3862,42 @@ def test_shift(self): with raises_regex(ValueError, 'dimensions'): ds.shift(foo=123) - def test_roll(self): + def test_roll_coords(self): coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]} attrs = {'meta': 'data'} ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs) - actual = ds.roll(x=1) + actual = ds.roll(x=1, roll_coords=True) ex_coords = {'bar': ('x', list('cab')), 'x': [2, -4, 3]} expected = Dataset({'foo': ('x', [3, 1, 2])}, ex_coords, attrs) assert_identical(expected, actual) with raises_regex(ValueError, 'dimensions'): - ds.roll(foo=123) + ds.roll(foo=123, roll_coords=True) + + def test_roll_no_coords(self): + coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]} + attrs = {'meta': 'data'} + ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs) + actual = ds.roll(x=1, roll_coords=False) + + expected = Dataset({'foo': ('x', [3, 1, 2])}, coords, attrs) + assert_identical(expected, actual) + + with raises_regex(ValueError, 'dimensions'): + ds.roll(abc=321, roll_coords=False) + + def test_roll_coords_none(self): + coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]} + attrs = {'meta': 'data'} + ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs) + + with pytest.warns(FutureWarning): + actual = ds.roll(x=1, roll_coords=None) + + ex_coords = {'bar': ('x', list('cab')), 'x': [2, -4, 3]} + expected = 
Dataset({'foo': ('x', [3, 1, 2])}, ex_coords, attrs) + assert_identical(expected, actual) def test_real_and_imag(self): attrs = {'foo': 'bar'} From 5155ef9ed2be7b3b201925c4902e8d633fec87a8 Mon Sep 17 00:00:00 2001 From: tv3141 Date: Thu, 16 Aug 2018 00:05:40 +0100 Subject: [PATCH 04/51] uncomment test (#2369) --- .travis.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 951b151d829..0e51e946da0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -95,9 +95,7 @@ install: - python xarray/util/print_versions.py script: - # TODO: restore this check once the upstream pandas issue is fixed: - # https://github.com/pandas-dev/pandas/issues/21071 - # - python -OO -c "import xarray" + - python -OO -c "import xarray" - if [[ "$CONDA_ENV" == "docs" ]]; then conda install -c conda-forge sphinx sphinx_rtd_theme sphinx-gallery numpydoc; sphinx-build -n -j auto -b html -d _build/doctrees doc _build/html; From 0b9ab2d12ae866a27050724d94facae6e56f5927 Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Thu, 16 Aug 2018 15:59:32 +0900 Subject: [PATCH 05/51] Refactor nanops (#2236) * Inhouse nanops * Cleanup nanops * remove NAT_TYPES * flake8. * another flake8 * recover nat types * remove keep_dims option from nanops (to make them compatible with numpy==1.11). * Test aggregation over multiple dimensions * Remove print. * Docs. More cleanup. * flake8 * Bug fix. Better test coverage. * using isnull, where_method. Remove unnecessary conditional branching. * More refactoring based on the comments * remove dtype from nanmedian * Fix for nanmedian * Add tests for dataset * Add tests with resample. * lint * updated whatsnew * Revise from comments. 
* Use .any and .all method instead of np.any / np.all * Avoid using numpy methods * Avoid casting to int for dask array * Update whatsnew --- doc/whats-new.rst | 9 ++ xarray/core/common.py | 39 +++--- xarray/core/dtypes.py | 3 + xarray/core/duck_array_ops.py | 186 ++++++------------------- xarray/core/nanops.py | 208 ++++++++++++++++++++++++++++ xarray/core/nputils.py | 41 ++++++ xarray/core/ops.py | 14 +- xarray/tests/test_dataarray.py | 10 +- xarray/tests/test_dataset.py | 16 ++- xarray/tests/test_duck_array_ops.py | 172 +++++++++++++++++++++-- xarray/tests/test_variable.py | 4 +- 11 files changed, 519 insertions(+), 183 deletions(-) create mode 100644 xarray/core/nanops.py diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 47d39b967e3..4725fe74577 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -36,6 +36,12 @@ Documentation Enhancements ~~~~~~~~~~~~ +- The ``min_count`` option is now supported in :py:meth:`~xarray.DataArray.sum`, + :py:meth:`~xarray.DataArray.prod`, :py:meth:`~xarray.Dataset.sum` and + :py:meth:`~xarray.Dataset.prod`. + (:issue:`2230`) + By `Keisuke Fujii `_. + - :py:meth:`plot()` now accepts the kwargs ``xscale, yscale, xlim, ylim, xticks, yticks`` just like Pandas. Also ``xincrease=False, yincrease=False`` now use matplotlib's axis inverting methods instead of setting limits. By `Deepak Cherian `_. (:issue:`2224`) @@ -78,6 +84,9 @@ Bug fixes - Tests can be run in parallel with pytest-xdist By `Tony Tung `_. +- Follow the renaming in dask from ``dask.ghost`` to ``dask.overlap``. + By `Keisuke Fujii `_. + - Now raises a ValueError when there is a conflict between dimension names and level names of MultiIndex. (:issue:`2299`) By `Keisuke Fujii `_. 
diff --git a/xarray/core/common.py b/xarray/core/common.py index 3f934fcc769..55aca5f557f 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -2,6 +2,7 @@ import warnings from distutils.version import LooseVersion +from textwrap import dedent import numpy as np import pandas as pd @@ -27,20 +28,20 @@ def wrapped_func(self, dim=None, axis=None, keep_attrs=False, allow_lazy=True, **kwargs) return wrapped_func - _reduce_extra_args_docstring = \ - """dim : str or sequence of str, optional + _reduce_extra_args_docstring = dedent("""\ + dim : str or sequence of str, optional Dimension(s) over which to apply `{name}`. axis : int or sequence of int, optional Axis(es) over which to apply `{name}`. Only one of the 'dim' and 'axis' arguments can be supplied. If neither are supplied, then - `{name}` is calculated over axes.""" + `{name}` is calculated over axes.""") - _cum_extra_args_docstring = \ - """dim : str or sequence of str, optional + _cum_extra_args_docstring = dedent("""\ + dim : str or sequence of str, optional Dimension over which to apply `{name}`. axis : int or sequence of int, optional Axis over which to apply `{name}`. Only one of the 'dim' - and 'axis' arguments can be supplied.""" + and 'axis' arguments can be supplied.""") class ImplementsDatasetReduce(object): @@ -308,12 +309,12 @@ def assign_coords(self, **kwargs): assigned : same type as caller A new object with the new coordinates in addition to the existing data. - + Examples -------- - + Convert longitude coordinates from 0-359 to -180-179: - + >>> da = xr.DataArray(np.random.rand(4), ... coords=[np.array([358, 359, 0, 1])], ... dims='lon') @@ -445,11 +446,11 @@ def groupby(self, group, squeeze=True): grouped : GroupBy A `GroupBy` object patterned after `pandas.GroupBy` that can be iterated over in the form of `(unique_value, grouped_array)` pairs. - + Examples -------- Calculate daily anomalies for daily data: - + >>> da = xr.DataArray(np.linspace(0, 1826, num=1827), ... 
coords=[pd.date_range('1/1/2000', '31/12/2004', ... freq='D')], @@ -465,7 +466,7 @@ def groupby(self, group, squeeze=True): Coordinates: * time (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 ... dayofyear (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 ... - + See Also -------- core.groupby.DataArrayGroupBy @@ -589,7 +590,7 @@ def resample(self, freq=None, dim=None, how=None, skipna=None, closed=None, label=None, base=0, keep_attrs=False, **indexer): """Returns a Resample object for performing resampling operations. - Handles both downsampling and upsampling. If any intervals contain no + Handles both downsampling and upsampling. If any intervals contain no values from the original object, they will be given the value ``NaN``. Parameters @@ -616,11 +617,11 @@ def resample(self, freq=None, dim=None, how=None, skipna=None, ------- resampled : same type as caller This object resampled. - + Examples -------- Downsample monthly time-series data to seasonal data: - + >>> da = xr.DataArray(np.linspace(0, 11, num=12), ... coords=[pd.date_range('15/12/1999', ... periods=12, freq=pd.DateOffset(months=1))], @@ -637,13 +638,13 @@ def resample(self, freq=None, dim=None, how=None, skipna=None, * time (time) datetime64[ns] 1999-12-01 2000-03-01 2000-06-01 2000-09-01 Upsample monthly time-series data to daily data: - + >>> da.resample(time='1D').interpolate('linear') array([ 0. , 0.032258, 0.064516, ..., 10.935484, 10.967742, 11. ]) Coordinates: * time (time) datetime64[ns] 1999-12-15 1999-12-16 1999-12-17 ... 
- + References ---------- @@ -957,8 +958,8 @@ def contains_cftime_datetimes(var): sample = sample.item() return isinstance(sample, cftime_datetime) else: - return False - + return False + def _contains_datetime_like_objects(var): """Check if a variable contains datetime like objects (either diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index 7326b936e2e..7ad44472f06 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -98,6 +98,9 @@ def maybe_promote(dtype): return np.dtype(dtype), fill_value +NAT_TYPES = (np.datetime64('NaT'), np.timedelta64('NaT')) + + def get_fill_value(dtype): """Return an appropriate fill value for this dtype. diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 3bd105064da..17eb310f8db 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -17,14 +17,6 @@ from .nputils import nanfirst, nanlast from .pycompat import dask_array_type -try: - import bottleneck as bn - has_bottleneck = True -except ImportError: - # use numpy methods instead - bn = np - has_bottleneck = False - try: import dask.array as dask_array from . 
import dask_array_compat @@ -175,7 +167,7 @@ def array_notnull_equiv(arr1, arr2): def count(data, axis=None): """Count the number of non-NA in this array along the given axis or axes """ - return sum(~isnull(data), axis=axis) + return np.sum(~isnull(data), axis=axis) def where(condition, x, y): @@ -213,159 +205,69 @@ def _ignore_warnings_if(condition): yield -def _nansum_object(value, axis=None, **kwargs): - """ In house nansum for object array """ - value = fillna(value, 0) - return _dask_or_eager_func('sum')(value, axis=axis, **kwargs) - - -def _nan_minmax_object(func, get_fill_value, value, axis=None, **kwargs): - """ In house nanmin and nanmax for object array """ - fill_value = get_fill_value(value.dtype) - valid_count = count(value, axis=axis) - filled_value = fillna(value, fill_value) - data = _dask_or_eager_func(func)(filled_value, axis=axis, **kwargs) - if not hasattr(data, 'dtype'): # scalar case - data = dtypes.fill_value(value.dtype) if valid_count == 0 else data - return np.array(data, dtype=value.dtype) - return where_method(data, valid_count != 0) - - -def _nan_argminmax_object(func, get_fill_value, value, axis=None, **kwargs): - """ In house nanargmin, nanargmax for object arrays. Always return integer - type """ - fill_value = get_fill_value(value.dtype) - valid_count = count(value, axis=axis) - value = fillna(value, fill_value) - data = _dask_or_eager_func(func)(value, axis=axis, **kwargs) - # dask seems return non-integer type - if isinstance(value, dask_array_type): - data = data.astype(int) - - if (valid_count == 0).any(): - raise ValueError('All-NaN slice encountered') - - return np.array(data, dtype=int) - - -def _nanmean_ddof_object(ddof, value, axis=None, **kwargs): - """ In house nanmean. 
ddof argument will be used in _nanvar method """ - valid_count = count(value, axis=axis) - value = fillna(value, 0) - # As dtype inference is impossible for object dtype, we assume float - # https://github.com/dask/dask/issues/3162 - dtype = kwargs.pop('dtype', None) - if dtype is None and value.dtype.kind == 'O': - dtype = value.dtype if value.dtype.kind in ['cf'] else float - - data = _dask_or_eager_func('sum')(value, axis=axis, dtype=dtype, **kwargs) - data = data / (valid_count - ddof) - return where_method(data, valid_count != 0) - - -def _nanvar_object(value, axis=None, **kwargs): - ddof = kwargs.pop('ddof', 0) - kwargs_mean = kwargs.copy() - kwargs_mean.pop('keepdims', None) - value_mean = _nanmean_ddof_object(ddof=0, value=value, axis=axis, - keepdims=True, **kwargs_mean) - squared = (value.astype(value_mean.dtype) - value_mean)**2 - return _nanmean_ddof_object(ddof, squared, axis=axis, **kwargs) - - -_nan_object_funcs = { - 'sum': _nansum_object, - 'min': partial(_nan_minmax_object, 'min', dtypes.get_pos_infinity), - 'max': partial(_nan_minmax_object, 'max', dtypes.get_neg_infinity), - 'argmin': partial(_nan_argminmax_object, 'argmin', - dtypes.get_pos_infinity), - 'argmax': partial(_nan_argminmax_object, 'argmax', - dtypes.get_neg_infinity), - 'mean': partial(_nanmean_ddof_object, 0), - 'var': _nanvar_object, -} - - -def _create_nan_agg_method(name, numeric_only=False, np_compat=False, - no_bottleneck=False, coerce_strings=False): +def _create_nan_agg_method(name, coerce_strings=False): + from . import nanops + def f(values, axis=None, skipna=None, **kwargs): if kwargs.pop('out', None) is not None: raise TypeError('`out` is not valid for {}'.format(name)) - # If dtype is supplied, we use numpy's method. 
- dtype = kwargs.get('dtype', None) values = asarray(values) - # dask requires dtype argument for object dtype - if (values.dtype == 'object' and name in ['sum', ]): - kwargs['dtype'] = values.dtype if dtype is None else dtype - if coerce_strings and values.dtype.kind in 'SU': values = values.astype(object) + func = None if skipna or (skipna is None and values.dtype.kind in 'cfO'): - if values.dtype.kind not in ['u', 'i', 'f', 'c']: - func = _nan_object_funcs.get(name, None) - using_numpy_nan_func = True - if func is None or values.dtype.kind not in 'Ob': - raise NotImplementedError( - 'skipna=True not yet implemented for %s with dtype %s' - % (name, values.dtype)) - else: - nanname = 'nan' + name - if (isinstance(axis, tuple) or not values.dtype.isnative or - no_bottleneck or (dtype is not None and - np.dtype(dtype) != values.dtype)): - # bottleneck can't handle multiple axis arguments or - # non-native endianness - if np_compat: - eager_module = npcompat - else: - eager_module = np - else: - kwargs.pop('dtype', None) - eager_module = bn - func = _dask_or_eager_func(nanname, eager_module) - using_numpy_nan_func = (eager_module is np or - eager_module is npcompat) + nanname = 'nan' + name + func = getattr(nanops, nanname) else: func = _dask_or_eager_func(name) - using_numpy_nan_func = False - with _ignore_warnings_if(using_numpy_nan_func): - try: - return func(values, axis=axis, **kwargs) - except AttributeError: - if isinstance(values, dask_array_type): - try: # dask/dask#3133 dask sometimes needs dtype argument - return func(values, axis=axis, dtype=values.dtype, - **kwargs) - except AttributeError: - msg = '%s is not yet implemented on dask arrays' % name - else: - assert using_numpy_nan_func - msg = ('%s is not available with skipna=False with the ' - 'installed version of numpy; upgrade to numpy 1.12 ' - 'or newer to use skipna=True or skipna=None' % name) - raise NotImplementedError(msg) - f.numeric_only = numeric_only + + try: + return func(values, 
axis=axis, **kwargs) + except AttributeError: + if isinstance(values, dask_array_type): + try: # dask/dask#3133 dask sometimes needs dtype argument + # if func does not accept dtype, then raises TypeError + return func(values, axis=axis, dtype=values.dtype, + **kwargs) + except (AttributeError, TypeError): + msg = '%s is not yet implemented on dask arrays' % name + else: + msg = ('%s is not available with skipna=False with the ' + 'installed version of numpy; upgrade to numpy 1.12 ' + 'or newer to use skipna=True or skipna=None' % name) + raise NotImplementedError(msg) + f.__name__ = name return f +# Attributes `numeric_only`, `available_min_count` is used for docs. +# See ops.inject_reduce_methods argmax = _create_nan_agg_method('argmax', coerce_strings=True) argmin = _create_nan_agg_method('argmin', coerce_strings=True) max = _create_nan_agg_method('max', coerce_strings=True) min = _create_nan_agg_method('min', coerce_strings=True) -sum = _create_nan_agg_method('sum', numeric_only=True) -mean = _create_nan_agg_method('mean', numeric_only=True) -std = _create_nan_agg_method('std', numeric_only=True) -var = _create_nan_agg_method('var', numeric_only=True) -median = _create_nan_agg_method('median', numeric_only=True) -prod = _create_nan_agg_method('prod', numeric_only=True, no_bottleneck=True) -cumprod_1d = _create_nan_agg_method( - 'cumprod', numeric_only=True, no_bottleneck=True) -cumsum_1d = _create_nan_agg_method( - 'cumsum', numeric_only=True, no_bottleneck=True) +sum = _create_nan_agg_method('sum') +sum.numeric_only = True +sum.available_min_count = True +mean = _create_nan_agg_method('mean') +mean.numeric_only = True +std = _create_nan_agg_method('std') +std.numeric_only = True +var = _create_nan_agg_method('var') +var.numeric_only = True +median = _create_nan_agg_method('median') +median.numeric_only = True +prod = _create_nan_agg_method('prod') +prod.numeric_only = True +sum.available_min_count = True +cumprod_1d = _create_nan_agg_method('cumprod') 
+cumprod_1d.numeric_only = True +cumsum_1d = _create_nan_agg_method('cumsum') +cumsum_1d.numeric_only = True def _nd_cum_func(cum_func, array, axis, **kwargs): diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py new file mode 100644 index 00000000000..2309ed9619d --- /dev/null +++ b/xarray/core/nanops.py @@ -0,0 +1,208 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np + +from . import dtypes +from .pycompat import dask_array_type +from . duck_array_ops import (count, isnull, fillna, where_method, + _dask_or_eager_func) +from . import nputils + +try: + import dask.array as dask_array +except ImportError: + dask_array = None + + +def _replace_nan(a, val): + """ + replace nan in a by val, and returns the replaced array and the nan + position + """ + mask = isnull(a) + return where_method(val, mask, a), mask + + +def _maybe_null_out(result, axis, mask, min_count=1): + """ + xarray version of pandas.core.nanops._maybe_null_out + """ + if hasattr(axis, '__len__'): # if tuple or list + raise ValueError('min_count is not available for reduction ' + 'with more than one dimensions.') + + if axis is not None and getattr(result, 'ndim', False): + null_mask = (mask.shape[axis] - mask.sum(axis) - min_count) < 0 + if null_mask.any(): + dtype, fill_value = dtypes.maybe_promote(result.dtype) + result = result.astype(dtype) + result[null_mask] = fill_value + + elif getattr(result, 'dtype', None) not in dtypes.NAT_TYPES: + null_mask = mask.size - mask.sum() + if null_mask < min_count: + result = np.nan + + return result + + +def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs): + """ In house nanargmin, nanargmax for object arrays. Always return integer + type + """ + valid_count = count(value, axis=axis) + value = fillna(value, fill_value) + data = _dask_or_eager_func(func)(value, axis=axis, **kwargs) + + # TODO This will evaluate dask arrays and might be costly. 
+ if (valid_count == 0).any(): + raise ValueError('All-NaN slice encountered') + + return data + + +def _nan_minmax_object(func, fill_value, value, axis=None, **kwargs): + """ In house nanmin and nanmax for object array """ + valid_count = count(value, axis=axis) + filled_value = fillna(value, fill_value) + data = getattr(np, func)(filled_value, axis=axis, **kwargs) + if not hasattr(data, 'dtype'): # scalar case + data = dtypes.fill_value(value.dtype) if valid_count == 0 else data + return np.array(data, dtype=value.dtype) + return where_method(data, valid_count != 0) + + +def nanmin(a, axis=None, out=None): + if a.dtype.kind == 'O': + return _nan_minmax_object( + 'min', dtypes.get_pos_infinity(a.dtype), a, axis) + + module = dask_array if isinstance(a, dask_array_type) else nputils + return module.nanmin(a, axis=axis) + + +def nanmax(a, axis=None, out=None): + if a.dtype.kind == 'O': + return _nan_minmax_object( + 'max', dtypes.get_neg_infinity(a.dtype), a, axis) + + module = dask_array if isinstance(a, dask_array_type) else nputils + return module.nanmax(a, axis=axis) + + +def nanargmin(a, axis=None): + fill_value = dtypes.get_pos_infinity(a.dtype) + if a.dtype.kind == 'O': + return _nan_argminmax_object('argmin', fill_value, a, axis=axis) + a, mask = _replace_nan(a, fill_value) + if isinstance(a, dask_array_type): + res = dask_array.argmin(a, axis=axis) + else: + res = np.argmin(a, axis=axis) + + if mask is not None: + mask = mask.all(axis=axis) + if mask.any(): + raise ValueError("All-NaN slice encountered") + return res + + +def nanargmax(a, axis=None): + fill_value = dtypes.get_neg_infinity(a.dtype) + if a.dtype.kind == 'O': + return _nan_argminmax_object('argmax', fill_value, a, axis=axis) + + a, mask = _replace_nan(a, fill_value) + if isinstance(a, dask_array_type): + res = dask_array.argmax(a, axis=axis) + else: + res = np.argmax(a, axis=axis) + + if mask is not None: + mask = mask.all(axis=axis) + if mask.any(): + raise ValueError("All-NaN slice 
encountered") + return res + + +def nansum(a, axis=None, dtype=None, out=None, min_count=None): + a, mask = _replace_nan(a, 0) + result = _dask_or_eager_func('sum')(a, axis=axis, dtype=dtype) + if min_count is not None: + return _maybe_null_out(result, axis, mask, min_count) + else: + return result + + +def _nanmean_ddof_object(ddof, value, axis=None, **kwargs): + """ In house nanmean. ddof argument will be used in _nanvar method """ + from .duck_array_ops import (count, fillna, _dask_or_eager_func, + where_method) + + valid_count = count(value, axis=axis) + value = fillna(value, 0) + # As dtype inference is impossible for object dtype, we assume float + # https://github.com/dask/dask/issues/3162 + dtype = kwargs.pop('dtype', None) + if dtype is None and value.dtype.kind == 'O': + dtype = value.dtype if value.dtype.kind in ['cf'] else float + + data = _dask_or_eager_func('sum')(value, axis=axis, dtype=dtype, **kwargs) + data = data / (valid_count - ddof) + return where_method(data, valid_count != 0) + + +def nanmean(a, axis=None, dtype=None, out=None): + if a.dtype.kind == 'O': + return _nanmean_ddof_object(0, a, axis=axis, dtype=dtype) + + if isinstance(a, dask_array_type): + return dask_array.nanmean(a, axis=axis, dtype=dtype) + + return np.nanmean(a, axis=axis, dtype=dtype) + + +def nanmedian(a, axis=None, out=None): + return _dask_or_eager_func('nanmedian', eager_module=nputils)(a, axis=axis) + + +def _nanvar_object(value, axis=None, **kwargs): + ddof = kwargs.pop('ddof', 0) + kwargs_mean = kwargs.copy() + kwargs_mean.pop('keepdims', None) + value_mean = _nanmean_ddof_object(ddof=0, value=value, axis=axis, + keepdims=True, **kwargs_mean) + squared = (value.astype(value_mean.dtype) - value_mean)**2 + return _nanmean_ddof_object(ddof, squared, axis=axis, **kwargs) + + +def nanvar(a, axis=None, dtype=None, out=None, ddof=0): + if a.dtype.kind == 'O': + return _nanvar_object(a, axis=axis, dtype=dtype, ddof=ddof) + + return _dask_or_eager_func('nanvar', 
eager_module=nputils)( + a, axis=axis, dtype=dtype, ddof=ddof) + + +def nanstd(a, axis=None, dtype=None, out=None): + return _dask_or_eager_func('nanstd', eager_module=nputils)( + a, axis=axis, dtype=dtype) + + +def nanprod(a, axis=None, dtype=None, out=None, min_count=None): + a, mask = _replace_nan(a, 1) + result = _dask_or_eager_func('nanprod')(a, axis=axis, dtype=dtype, out=out) + if min_count is not None: + return _maybe_null_out(result, axis, mask, min_count) + else: + return result + + +def nancumsum(a, axis=None, dtype=None, out=None): + return _dask_or_eager_func('nancumsum', eager_module=nputils)( + a, axis=axis, dtype=dtype) + + +def nancumprod(a, axis=None, dtype=None, out=None): + return _dask_or_eager_func('nancumprod', eager_module=nputils)( + a, axis=axis, dtype=dtype) diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index 6df2d34bfe3..a8d596abd86 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -5,6 +5,14 @@ import numpy as np import pandas as pd +try: + import bottleneck as bn + _USE_BOTTLENECK = True +except ImportError: + # use numpy methods instead + bn = np + _USE_BOTTLENECK = False + def _validate_axis(data, axis): ndim = data.ndim @@ -195,3 +203,36 @@ def _rolling_window(a, window, axis=-1): rolling = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides, writeable=False) return np.swapaxes(rolling, -2, axis) + + +def _create_bottleneck_method(name, npmodule=np): + def f(values, axis=None, **kwds): + dtype = kwds.get('dtype', None) + bn_func = getattr(bn, name, None) + + if (_USE_BOTTLENECK and bn_func is not None and + not isinstance(axis, tuple) and + values.dtype.kind in 'uifc' and + values.dtype.isnative and + (dtype is None or np.dtype(dtype) == values.dtype)): + # bottleneck does not take care dtype, min_count + kwds.pop('dtype', None) + result = bn_func(values, axis=axis, **kwds) + else: + result = getattr(npmodule, name)(values, axis=axis, **kwds) + + return result + + f.__name__ = name + 
return f + + +nanmin = _create_bottleneck_method('nanmin') +nanmax = _create_bottleneck_method('nanmax') +nanmean = _create_bottleneck_method('nanmean') +nanmedian = _create_bottleneck_method('nanmedian') +nanvar = _create_bottleneck_method('nanvar') +nanstd = _create_bottleneck_method('nanstd') +nanprod = _create_bottleneck_method('nanprod') +nancumsum = _create_bottleneck_method('nancumsum') +nancumprod = _create_bottleneck_method('nancumprod') diff --git a/xarray/core/ops.py b/xarray/core/ops.py index d9e8ceb65d5..a0dd2212a8f 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -86,7 +86,7 @@ If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been - implemented (object, datetime64 or timedelta64). + implemented (object, datetime64 or timedelta64).{min_count_docs} keep_attrs : bool, optional If True, the attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be @@ -102,6 +102,12 @@ indicated dimension(s) removed. """ +_MINCOUNT_DOCSTRING = """ +min_count : int, default None + The required number of valid values to perform the operation. + If fewer than min_count non-NA values are present the result will + be NA. New in version 0.10.8: Added with the default being None.""" + _ROLLING_REDUCE_DOCSTRING_TEMPLATE = """\ Reduce this {da_or_ds}'s data windows by applying `{name}` along its dimension. 
@@ -236,11 +242,15 @@ def inject_reduce_methods(cls): [('count', duck_array_ops.count, False)]) for name, f, include_skipna in methods: numeric_only = getattr(f, 'numeric_only', False) + available_min_count = getattr(f, 'available_min_count', False) + min_count_docs = _MINCOUNT_DOCSTRING if available_min_count else '' + func = cls._reduce_method(f, include_skipna, numeric_only) func.__name__ = name func.__doc__ = _REDUCE_DOCSTRING_TEMPLATE.format( name=name, cls=cls.__name__, - extra_args=cls._reduce_extra_args_docstring.format(name=name)) + extra_args=cls._reduce_extra_args_docstring.format(name=name), + min_count_docs=min_count_docs) setattr(cls, name, func) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 1a115192fb4..3619688d091 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3344,7 +3344,9 @@ def test_isin(da): def test_rolling_iter(da): rolling_obj = da.rolling(time=7) - rolling_obj_mean = rolling_obj.mean() + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'Mean of empty slice') + rolling_obj_mean = rolling_obj.mean() assert len(rolling_obj.window_labels) == len(da['time']) assert_identical(rolling_obj.window_labels, da['time']) @@ -3352,8 +3354,10 @@ def test_rolling_iter(da): for i, (label, window_da) in enumerate(rolling_obj): assert label == da['time'].isel(time=i) - actual = rolling_obj_mean.isel(time=i) - expected = window_da.mean('time') + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'Mean of empty slice') + actual = rolling_obj_mean.isel(time=i) + expected = window_da.mean('time') # TODO add assert_allclose_with_nan, which compares nan position # as well as the closeness of the values. 
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index e2a406b1e51..d73632c10a7 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2740,6 +2740,20 @@ def test_resample_and_first(self): result = actual.reduce(method) assert_equal(expected, result) + def test_resample_min_count(self): + times = pd.date_range('2000-01-01', freq='6H', periods=10) + ds = Dataset({'foo': (['time', 'x', 'y'], np.random.randn(10, 5, 3)), + 'bar': ('time', np.random.randn(10), {'meta': 'data'}), + 'time': times}) + # inject nan + ds['foo'] = xr.where(ds['foo'] > 2.0, np.nan, ds['foo']) + + actual = ds.resample(time='1D').sum(min_count=1) + expected = xr.concat([ + ds.isel(time=slice(i * 4, (i + 1) * 4)).sum('time', min_count=1) + for i in range(3)], dim=actual['time']) + assert_equal(expected, actual) + def test_resample_by_mean_with_keep_attrs(self): times = pd.date_range('2000-01-01', freq='6H', periods=10) ds = Dataset({'foo': (['time', 'x', 'y'], np.random.randn(10, 5, 3)), @@ -3378,7 +3392,6 @@ def test_reduce(self): (('dim2', 'time'), ['dim1', 'dim3']), ((), ['dim1', 'dim2', 'dim3', 'time'])]: actual = data.min(dim=reduct).dims - print(reduct, actual, expected) self.assertItemsEqual(actual, expected) assert_equal(data.mean(dim=[]), data) @@ -3433,7 +3446,6 @@ def test_reduce_cumsum_test_dims(self): ('time', ['dim1', 'dim2', 'dim3']) ]: actual = getattr(data, cumfunc)(dim=reduct).dims - print(reduct, actual, expected) self.assertItemsEqual(actual, expected) def test_reduce_non_numeric(self): diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index f3f93491822..3f32fc49fd2 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -1,12 +1,16 @@ from __future__ import absolute_import, division, print_function +from distutils.version import LooseVersion + import numpy as np +import pandas as pd import pytest +from textwrap import dedent from numpy import array, nan 
import warnings -from xarray import DataArray, concat -from xarray.core import duck_array_ops +from xarray import DataArray, Dataset, concat +from xarray.core import duck_array_ops, dtypes from xarray.core.duck_array_ops import ( array_notnull_equiv, concatenate, count, first, last, mean, rolling_window, stack, where) @@ -100,7 +104,10 @@ def test_concatenate_type_promotion(self): assert_array_equal(result, np.array([1, 'b'], dtype=object)) def test_all_nan_arrays(self): - assert np.isnan(mean([np.nan, np.nan])) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'All-NaN slice') + warnings.filterwarnings('ignore', 'Mean of empty slice') + assert np.isnan(mean([np.nan, np.nan])) def test_cumsum_1d(): @@ -197,10 +204,15 @@ def construct_dataarray(dim_num, dtype, contains_nan, dask): array = rng.choice(['a', 'b', 'c', 'd'], size=shapes) else: raise ValueError - da = DataArray(array, dims=dims, coords={'x': np.arange(16)}, name='da') if contains_nan: - da = da.reindex(x=np.arange(20)) + inds = rng.choice(range(array.size), int(array.size * 0.2)) + dtype, fill_value = dtypes.maybe_promote(array.dtype) + array = array.astype(dtype) + array.flat[inds] = fill_value + + da = DataArray(array, dims=dims, coords={'x': np.arange(16)}, name='da') + if dask and has_dask: chunks = {d: 4 for d in dims} da = da.chunk(chunks) @@ -234,10 +246,16 @@ def series_reduce(da, func, dim, **kwargs): return concat(da1, dim=d) +def assert_dask_array(da, dask): + if dask and da.ndim > 0: + assert isinstance(da.data, dask_array_type) + + @pytest.mark.parametrize('dim_num', [1, 2]) @pytest.mark.parametrize('dtype', [float, int, np.float32, np.bool_]) @pytest.mark.parametrize('dask', [False, True]) @pytest.mark.parametrize('func', ['sum', 'min', 'max', 'mean', 'var']) +# TODO test cumsum, cumprod @pytest.mark.parametrize('skipna', [False, True]) @pytest.mark.parametrize('aggdim', [None, 'x']) def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): @@ -251,6 +269,9 @@ def 
test_reduce(dim_num, dtype, dask, func, skipna, aggdim): if dask and not has_dask: pytest.skip('requires dask') + if dask and skipna is False and dtype in [np.bool_]: + pytest.skip('dask does not compute object-typed array') + rtol = 1e-04 if dtype == np.float32 else 1e-05 da = construct_dataarray(dim_num, dtype, contains_nan=True, dask=dask) @@ -259,6 +280,7 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): # TODO: remove these after resolving # https://github.com/dask/dask/issues/3245 with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'Mean of empty slice') warnings.filterwarnings('ignore', 'All-NaN slice') warnings.filterwarnings('ignore', 'invalid value encountered in') @@ -272,6 +294,7 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): expected = getattr(np, func)(da.values, axis=axis) actual = getattr(da, func)(skipna=skipna, dim=aggdim) + assert_dask_array(actual, dask) assert np.allclose(actual.values, np.array(expected), rtol=1.0e-4, equal_nan=True) except (TypeError, AttributeError, ZeroDivisionError): @@ -279,14 +302,21 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): # nanmean for object dtype pass - # make sure the compatiblility with pandas' results. actual = getattr(da, func)(skipna=skipna, dim=aggdim) + + # for dask case, make sure the result is the same for numpy backend + expected = getattr(da.compute(), func)(skipna=skipna, dim=aggdim) + assert_allclose(actual, expected, rtol=rtol) + + # make sure the compatiblility with pandas' results. 
if func == 'var': expected = series_reduce(da, func, skipna=skipna, dim=aggdim, ddof=0) assert_allclose(actual, expected, rtol=rtol) # also check ddof!=0 case actual = getattr(da, func)(skipna=skipna, dim=aggdim, ddof=5) + if dask: + assert isinstance(da.data, dask_array_type) expected = series_reduce(da, func, skipna=skipna, dim=aggdim, ddof=5) assert_allclose(actual, expected, rtol=rtol) @@ -297,11 +327,14 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): # make sure the dtype argument if func not in ['max', 'min']: actual = getattr(da, func)(skipna=skipna, dim=aggdim, dtype=float) + assert_dask_array(actual, dask) assert actual.dtype == float # without nan da = construct_dataarray(dim_num, dtype, contains_nan=False, dask=dask) actual = getattr(da, func)(skipna=skipna) + if dask: + assert isinstance(da.data, dask_array_type) expected = getattr(np, 'nan{}'.format(func))(da.values) if actual.dtype == object: assert actual.values == np.array(expected) @@ -338,13 +371,6 @@ def test_argmin_max(dim_num, dtype, contains_nan, dask, func, skipna, aggdim): with warnings.catch_warnings(): warnings.filterwarnings('ignore', 'All-NaN slice') - if aggdim == 'y' and contains_nan and skipna: - with pytest.raises(ValueError): - actual = da.isel(**{ - aggdim: getattr(da, 'arg' + func)( - dim=aggdim, skipna=skipna).compute()}) - return - actual = da.isel(**{aggdim: getattr(da, 'arg' + func) (dim=aggdim, skipna=skipna).compute()}) expected = getattr(da, func)(dim=aggdim, skipna=skipna) @@ -354,6 +380,7 @@ def test_argmin_max(dim_num, dtype, contains_nan, dask, func, skipna, aggdim): def test_argmin_max_error(): da = construct_dataarray(2, np.bool_, contains_nan=True, dask=False) + da[0] = np.nan with pytest.raises(ValueError): da.argmin(dim='y') @@ -388,3 +415,122 @@ def test_dask_rolling(axis, window, center): with pytest.raises(ValueError): rolling_window(dx, axis=axis, window=100, center=center, fill_value=np.nan) + + +@pytest.mark.parametrize('dim_num', [1, 2]) 
+@pytest.mark.parametrize('dtype', [float, int, np.float32, np.bool_]) +@pytest.mark.parametrize('dask', [False, True]) +@pytest.mark.parametrize('func', ['sum', 'prod']) +@pytest.mark.parametrize('aggdim', [None, 'x']) +def test_min_count(dim_num, dtype, dask, func, aggdim): + if dask and not has_dask: + pytest.skip('requires dask') + + da = construct_dataarray(dim_num, dtype, contains_nan=True, dask=dask) + min_count = 3 + + actual = getattr(da, func)(dim=aggdim, skipna=True, min_count=min_count) + + if LooseVersion(pd.__version__) >= LooseVersion('0.22.0'): + # min_count is only implenented in pandas > 0.22 + expected = series_reduce(da, func, skipna=True, dim=aggdim, + min_count=min_count) + assert_allclose(actual, expected) + + assert_dask_array(actual, dask) + + +@pytest.mark.parametrize('func', ['sum', 'prod']) +def test_min_count_dataset(func): + da = construct_dataarray(2, dtype=float, contains_nan=True, dask=False) + ds = Dataset({'var1': da}, coords={'scalar': 0}) + actual = getattr(ds, func)(dim='x', skipna=True, min_count=3)['var1'] + expected = getattr(ds['var1'], func)(dim='x', skipna=True, min_count=3) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize('dtype', [float, int, np.float32, np.bool_]) +@pytest.mark.parametrize('dask', [False, True]) +@pytest.mark.parametrize('func', ['sum', 'prod']) +def test_multiple_dims(dtype, dask, func): + if dask and not has_dask: + pytest.skip('requires dask') + da = construct_dataarray(3, dtype, contains_nan=True, dask=dask) + + actual = getattr(da, func)(('x', 'y')) + expected = getattr(getattr(da, func)('x'), func)('y') + assert_allclose(actual, expected) + + +def test_docs(): + # with min_count + actual = DataArray.sum.__doc__ + expected = dedent("""\ + Reduce this DataArray's data by applying `sum` along some dimension(s). + + Parameters + ---------- + dim : str or sequence of str, optional + Dimension(s) over which to apply `sum`. 
+ axis : int or sequence of int, optional + Axis(es) over which to apply `sum`. Only one of the 'dim' + and 'axis' arguments can be supplied. If neither are supplied, then + `sum` is calculated over axes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + min_count : int, default None + The required number of valid values to perform the operation. + If fewer than min_count non-NA values are present the result will + be NA. New in version 0.10.8: Added with the default being None. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating `sum` on this object's data. + + Returns + ------- + reduced : DataArray + New DataArray object with `sum` applied to its data and the + indicated dimension(s) removed. + """) + assert actual == expected + + # without min_count + actual = DataArray.std.__doc__ + expected = dedent("""\ + Reduce this DataArray's data by applying `std` along some dimension(s). + + Parameters + ---------- + dim : str or sequence of str, optional + Dimension(s) over which to apply `std`. + axis : int or sequence of int, optional + Axis(es) over which to apply `std`. Only one of the 'dim' + and 'axis' arguments can be supplied. If neither are supplied, then + `std` is calculated over axes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). 
+ keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating `std` on this object's data. + + Returns + ------- + reduced : DataArray + New DataArray object with `std` applied to its data and the + indicated dimension(s) removed. + """) + assert actual == expected diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index cdb578aff6c..3db5e6adc4b 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1503,8 +1503,8 @@ def test_reduce_funcs(self): assert_identical(v.all(dim='x'), Variable([], False)) v = Variable('t', pd.date_range('2000-01-01', periods=3)) - with pytest.raises(NotImplementedError): - v.argmax(skipna=True) + assert v.argmax(skipna=True) == 2 + assert_identical( v.max(), Variable([], pd.Timestamp('2000-01-03'))) From 725bd57ffa64d7e391ceef2b056fa8122ec09e8d Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Mon, 20 Aug 2018 10:12:36 +0900 Subject: [PATCH 06/51] More support of non-string dimension names (#2373) * More support for non-string dimension. * Avoid using kwargs in rolling. * Restore assign_coords, fixes typo --- xarray/core/common.py | 24 +++++++------ xarray/core/dataarray.py | 55 ++++++++++++++++++---------- xarray/core/dataset.py | 78 ++++++++++++++++++++++++++-------------- xarray/core/rolling.py | 53 ++++++++++++++------------- xarray/core/variable.py | 47 ++++++++++++++++++------ 5 files changed, 162 insertions(+), 95 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 55aca5f557f..280034a30dd 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -10,7 +10,7 @@ from . 
import duck_array_ops, dtypes, formatting, ops from .arithmetic import SupportsArithmetic from .pycompat import OrderedDict, basestring, dask_array_type, suppress -from .utils import Frozen, SortedKeysDict +from .utils import either_dict_or_kwargs, Frozen, SortedKeysDict class ImplementsArrayReduce(object): @@ -526,24 +526,24 @@ def groupby_bins(self, group, bins, right=True, labels=None, precision=3, 'precision': precision, 'include_lowest': include_lowest}) - def rolling(self, min_periods=None, center=False, **windows): + def rolling(self, dim=None, min_periods=None, center=False, **dim_kwargs): """ Rolling window object. Parameters ---------- + dim: dict, optional + Mapping from the dimension name to create the rolling iterator + along (e.g. `time`) to its moving window size. min_periods : int, default None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. center : boolean, default False Set the labels at the center of the window. - **windows : dim=window - dim : str - Name of the dimension to create the rolling iterator - along (e.g., `time`). - window : int - Size of the moving window. + **dim_kwargs : optional + The keyword arguments form of ``dim``. + One of dim or dim_kwarg must be provided. Returns ------- @@ -582,9 +582,9 @@ def rolling(self, min_periods=None, center=False, **windows): core.rolling.DataArrayRolling core.rolling.DatasetRolling """ - - return self._rolling_cls(self, min_periods=min_periods, - center=center, **windows) + dim = either_dict_or_kwargs(dim, dim_kwargs, 'rolling') + return self._rolling_cls(self, dim, min_periods=min_periods, + center=center) def resample(self, freq=None, dim=None, how=None, skipna=None, closed=None, label=None, base=0, keep_attrs=False, **indexer): @@ -650,6 +650,8 @@ def resample(self, freq=None, dim=None, how=None, skipna=None, .. 
[1] http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases """ + # TODO support non-string indexer after removing the old API. + from .dataarray import DataArray from .resample import RESAMPLE_DIM diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index b1be994416e..359812f2cc3 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -779,9 +779,9 @@ def sel(self, indexers=None, method=None, tolerance=None, drop=False, DataArray.isel """ - indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'sel') ds = self._to_temp_dataset().sel( - indexers=indexers, drop=drop, method=method, tolerance=tolerance) + indexers=indexers, drop=drop, method=method, tolerance=tolerance, + **indexers_kwargs) return self._from_temp_dataset(ds) def isel_points(self, dim='points', **indexers): @@ -1092,22 +1092,26 @@ def expand_dims(self, dim, axis=None): ds = self._to_temp_dataset().expand_dims(dim, axis) return self._from_temp_dataset(ds) - def set_index(self, append=False, inplace=False, **indexes): + def set_index(self, indexes=None, append=False, inplace=False, + **indexes_kwargs): """Set DataArray (multi-)indexes using one or more existing coordinates. Parameters ---------- + indexes : {dim: index, ...} + Mapping from names matching dimensions and values given + by (lists of) the names of existing coordinates or variables to set + as new (multi-)index. append : bool, optional If True, append the supplied index(es) to the existing index(es). Otherwise replace the existing index(es) (default). inplace : bool, optional If True, set new index(es) in-place. Otherwise, return a new DataArray object. - **indexes : {dim: index, ...} - Keyword arguments with names matching dimensions and values given - by (lists of) the names of existing coordinates or variables to set - as new (multi-)index. + **indexes_kwargs: optional + The keyword arguments form of ``indexes``. + One of indexes or indexes_kwargs must be provided. 
Returns ------- @@ -1118,6 +1122,7 @@ def set_index(self, append=False, inplace=False, **indexes): -------- DataArray.reset_index """ + indexes = either_dict_or_kwargs(indexes, indexes_kwargs, 'set_index') coords, _ = merge_indexes(indexes, self._coords, set(), append=append) if inplace: self._coords = coords @@ -1156,18 +1161,22 @@ def reset_index(self, dims_or_levels, drop=False, inplace=False): else: return self._replace(coords=coords) - def reorder_levels(self, inplace=False, **dim_order): + def reorder_levels(self, dim_order=None, inplace=False, + **dim_order_kwargs): """Rearrange index levels using input order. Parameters ---------- + dim_order : optional + Mapping from names matching dimensions and values given + by lists representing new level orders. Every given dimension + must have a multi-index. inplace : bool, optional If True, modify the dataarray in-place. Otherwise, return a new DataArray object. - **dim_order : optional - Keyword arguments with names matching dimensions and values given - by lists representing new level orders. Every given dimension - must have a multi-index. + **dim_order_kwargs: optional + The keyword arguments form of ``dim_order``. + One of dim_order or dim_order_kwargs must be provided. Returns ------- @@ -1175,6 +1184,8 @@ def reorder_levels(self, inplace=False, **dim_order): Another dataarray, with this dataarray's data but replaced coordinates. """ + dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, + 'reorder_levels') replace_coords = {} for dim, order in dim_order.items(): coord = self._coords[dim] @@ -1190,7 +1201,7 @@ def reorder_levels(self, inplace=False, **dim_order): else: return self._replace(coords=coords) - def stack(self, **dimensions): + def stack(self, dimensions=None, **dimensions_kwargs): """ Stack any number of existing dimensions into a single new dimension. 
@@ -1199,9 +1210,12 @@ def stack(self, **dimensions): Parameters ---------- - **dimensions : keyword arguments of the form new_name=(dim1, dim2, ...) + dimensions : Mapping of the form new_name=(dim1, dim2, ...) Names of new dimensions, and the existing dimensions that they replace. + **dimensions_kwargs: + The keyword arguments form of ``dimensions``. + One of dimensions or dimensions_kwargs must be provided. Returns ------- @@ -1230,7 +1244,7 @@ def stack(self, **dimensions): -------- DataArray.unstack """ - ds = self._to_temp_dataset().stack(**dimensions) + ds = self._to_temp_dataset().stack(dimensions, **dimensions_kwargs) return self._from_temp_dataset(ds) def unstack(self, dim): @@ -1978,7 +1992,7 @@ def diff(self, dim, n=1, label='upper'): ds = self._to_temp_dataset().diff(n=n, dim=dim, label=label) return self._from_temp_dataset(ds) - def shift(self, **shifts): + def shift(self, shifts=None, **shifts_kwargs): """Shift this array by an offset along one or more dimensions. Only the data is moved; coordinates stay in place. Values shifted from @@ -1987,10 +2001,13 @@ def shift(self, **shifts): Parameters ---------- - **shifts : keyword arguments of the form {dim: offset} + shifts : Mapping with the form of {dim: offset} Integer offset to shift along each of the given dimensions. Positive offsets shift to the right; negative offsets shift to the left. + **shifts_kwargs: + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwarg must be provided. Returns ------- @@ -2012,8 +2029,8 @@ def shift(self, **shifts): Coordinates: * x (x) int64 0 1 2 """ - variable = self.variable.shift(**shifts) - return self._replace(variable) + ds = self._to_temp_dataset().shift(shifts=shifts, **shifts_kwargs) + return self._from_temp_dataset(ds) def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): """Roll this array by an offset along one or more dimensions. 
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 37544aca372..597b681bf65 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1759,8 +1759,8 @@ def reindex_like(self, other, method=None, tolerance=None, copy=True): align """ indexers = alignment.reindex_like_indexers(self, other) - return self.reindex(method=method, copy=copy, tolerance=tolerance, - **indexers) + return self.reindex(indexers=indexers, method=method, copy=copy, + tolerance=tolerance) def reindex(self, indexers=None, method=None, tolerance=None, copy=True, **indexers_kwargs): @@ -1809,7 +1809,7 @@ def reindex(self, indexers=None, method=None, tolerance=None, copy=True, pandas.Index.get_indexer """ indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, - 'reindex') + 'reindex') bad_dims = [d for d in indexers if d not in self.dims] if bad_dims: @@ -2144,22 +2144,26 @@ def expand_dims(self, dim, axis=None): return self._replace_vars_and_dims(variables, self._coord_names) - def set_index(self, append=False, inplace=False, **indexes): + def set_index(self, indexes=None, append=False, inplace=False, + **indexes_kwargs): """Set Dataset (multi-)indexes using one or more existing coordinates or variables. Parameters ---------- + indexes : {dim: index, ...} + Mapping from names matching dimensions and values given + by (lists of) the names of existing coordinates or variables to set + as new (multi-)index. append : bool, optional If True, append the supplied index(es) to the existing index(es). Otherwise replace the existing index(es) (default). inplace : bool, optional If True, set new index(es) in-place. Otherwise, return a new Dataset object. - **indexes : {dim: index, ...} - Keyword arguments with names matching dimensions and values given - by (lists of) the names of existing coordinates or variables to set - as new (multi-)index. + **indexes_kwargs: optional + The keyword arguments form of ``indexes``. + One of indexes or indexes_kwargs must be provided. 
Returns ------- @@ -2170,6 +2174,7 @@ def set_index(self, append=False, inplace=False, **indexes): -------- Dataset.reset_index """ + indexes = either_dict_or_kwargs(indexes, indexes_kwargs, 'set_index') variables, coord_names = merge_indexes(indexes, self._variables, self._coord_names, append=append) @@ -2206,18 +2211,22 @@ def reset_index(self, dims_or_levels, drop=False, inplace=False): return self._replace_vars_and_dims(variables, coord_names=coord_names, inplace=inplace) - def reorder_levels(self, inplace=False, **dim_order): + def reorder_levels(self, dim_order=None, inplace=False, + **dim_order_kwargs): """Rearrange index levels using input order. Parameters ---------- + dim_order : optional + Mapping from names matching dimensions and values given + by lists representing new level orders. Every given dimension + must have a multi-index. inplace : bool, optional If True, modify the dataset in-place. Otherwise, return a new DataArray object. - **dim_order : optional - Keyword arguments with names matching dimensions and values given - by lists representing new level orders. Every given dimension - must have a multi-index. + **dim_order_kwargs: optional + The keyword arguments form of ``dim_order``. + One of dim_order or dim_order_kwargs must be provided. Returns ------- @@ -2225,6 +2234,8 @@ def reorder_levels(self, inplace=False, **dim_order): Another dataset, with this dataset's data but replaced coordinates. """ + dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, + 'reorder_levels') replace_variables = {} for dim, order in dim_order.items(): coord = self._variables[dim] @@ -2267,7 +2278,7 @@ def _stack_once(self, dims, new_dim): return self._replace_vars_and_dims(variables, coord_names) - def stack(self, **dimensions): + def stack(self, dimensions=None, **dimensions_kwargs): """ Stack any number of existing dimensions into a single new dimension. 
@@ -2276,9 +2287,12 @@ def stack(self, **dimensions): Parameters ---------- - **dimensions : keyword arguments of the form new_name=(dim1, dim2, ...) + dimensions : Mapping of the form new_name=(dim1, dim2, ...) Names of new dimensions, and the existing dimensions that they replace. + **dimensions_kwargs: + The keyword arguments form of ``dimensions``. + One of dimensions or dimensions_kwargs must be provided. Returns ------- @@ -2289,6 +2303,8 @@ def stack(self, **dimensions): -------- Dataset.unstack """ + dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, + 'stack') result = self for new_dim, dims in dimensions.items(): result = result._stack_once(dims, new_dim) @@ -2329,7 +2345,7 @@ def unstack(self, dim): if index.equals(full_idx): obj = self else: - obj = self.reindex(copy=False, **{dim: full_idx}) + obj = self.reindex({dim: full_idx}, copy=False) new_dim_names = index.names new_dim_sizes = [lev.size for lev in index.levels] @@ -2339,7 +2355,7 @@ def unstack(self, dim): if name != dim: if dim in var.dims: new_dims = OrderedDict(zip(new_dim_names, new_dim_sizes)) - variables[name] = var.unstack(**{dim: new_dims}) + variables[name] = var.unstack({dim: new_dims}) else: variables[name] = var @@ -2578,7 +2594,7 @@ def dropna(self, dim, how='any', thresh=None, subset=None): else: raise TypeError('must specify how or thresh') - return self.isel(**{dim: mask}) + return self.isel({dim: mask}) def fillna(self, value): """Fill missing values in this object. @@ -2843,17 +2859,20 @@ def apply(self, func, keep_attrs=False, args=(), **kwargs): attrs = self.attrs if keep_attrs else None return type(self)(variables, attrs=attrs) - def assign(self, **kwargs): + def assign(self, variables=None, **variables_kwargs): """Assign new data variables to a Dataset, returning a new object with all the original variables in addition to the new ones. Parameters ---------- - kwargs : keyword, value pairs - keywords are the variables names. 
If the values are callable, they - are computed on the Dataset and assigned to new data variables. If - the values are not callable, (e.g. a DataArray, scalar, or array), - they are simply assigned. + variables : mapping, value pairs + Mapping from variables names to the new values. If the new values + are callable, they are computed on the Dataset and assigned to new + data variables. If the values are not callable, (e.g. a DataArray, + scalar, or array), they are simply assigned. + **variables_kwargs: + The keyword arguments form of ``variables``. + One of variables or variables_kwarg must be provided. Returns ------- @@ -2873,9 +2892,10 @@ def assign(self, **kwargs): -------- pandas.DataFrame.assign """ + variables = either_dict_or_kwargs(variables, variables_kwargs, 'assign') data = self.copy() # do all calculations first... - results = data._calc_assign_results(kwargs) + results = data._calc_assign_results(variables) # ... and then assign data.update(results) return data @@ -3310,7 +3330,7 @@ def diff(self, dim, n=1, label='upper'): else: return difference - def shift(self, **shifts): + def shift(self, shifts=None, **shifts_kwargs): """Shift this dataset by an offset along one or more dimensions. Only data variables are moved; coordinates stay in place. This is @@ -3318,10 +3338,13 @@ def shift(self, **shifts): Parameters ---------- - **shifts : keyword arguments of the form {dim: offset} + shifts : Mapping with the form of {dim: offset} Integer offset to shift along each of the given dimensions. Positive offsets shift to the right; negative offsets shift to the left. + **shifts_kwargs: + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwarg must be provided. 
Returns ------- @@ -3345,6 +3368,7 @@ def shift(self, **shifts): Data variables: foo (x) object nan nan 'a' 'b' 'c' """ + shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'shift') invalid = [k for k in shifts if k not in self.dims] if invalid: raise ValueError("dimensions %r do not exist" % invalid) diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 24ed280b19e..883dbb34dff 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -44,7 +44,7 @@ class Rolling(object): _attributes = ['window', 'min_periods', 'center', 'dim'] - def __init__(self, obj, min_periods=None, center=False, **windows): + def __init__(self, obj, windows, min_periods=None, center=False): """ Moving window object. @@ -52,18 +52,18 @@ def __init__(self, obj, min_periods=None, center=False, **windows): ---------- obj : Dataset or DataArray Object to window. + windows : A mapping from a dimension name to window size + dim : str + Name of the dimension to create the rolling iterator + along (e.g., `time`). + window : int + Size of the moving window. min_periods : int, default None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. center : boolean, default False Set the labels at the center of the window. - **windows : dim=window - dim : str - Name of the dimension to create the rolling iterator - along (e.g., `time`). - window : int - Size of the moving window. Returns ------- @@ -115,7 +115,7 @@ def __len__(self): class DataArrayRolling(Rolling): - def __init__(self, obj, min_periods=None, center=False, **windows): + def __init__(self, obj, windows, min_periods=None, center=False): """ Moving window object for DataArray. You should use DataArray.rolling() method to construct this object @@ -125,18 +125,18 @@ def __init__(self, obj, min_periods=None, center=False, **windows): ---------- obj : DataArray Object to window. 
+ windows : A mapping from a dimension name to window size + dim : str + Name of the dimension to create the rolling iterator + along (e.g., `time`). + window : int + Size of the moving window. min_periods : int, default None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. center : boolean, default False Set the labels at the center of the window. - **windows : dim=window - dim : str - Name of the dimension to create the rolling iterator - along (e.g., `time`). - window : int - Size of the moving window. Returns ------- @@ -149,8 +149,8 @@ def __init__(self, obj, min_periods=None, center=False, **windows): Dataset.rolling Dataset.groupby """ - super(DataArrayRolling, self).__init__(obj, min_periods=min_periods, - center=center, **windows) + super(DataArrayRolling, self).__init__( + obj, windows, min_periods=min_periods, center=center) self.window_labels = self.obj[self.dim] @@ -321,7 +321,7 @@ def wrapped_func(self, **kwargs): class DatasetRolling(Rolling): - def __init__(self, obj, min_periods=None, center=False, **windows): + def __init__(self, obj, windows, min_periods=None, center=False): """ Moving window object for Dataset. You should use Dataset.rolling() method to construct this object @@ -331,18 +331,18 @@ def __init__(self, obj, min_periods=None, center=False, **windows): ---------- obj : Dataset Object to window. + windows : A mapping from a dimension name to window size + dim : str + Name of the dimension to create the rolling iterator + along (e.g., `time`). + window : int + Size of the moving window. min_periods : int, default None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. center : boolean, default False Set the labels at the center of the window. 
- **windows : dim=window - dim : str - Name of the dimension to create the rolling iterator - along (e.g., `time`). - window : int - Size of the moving window. Returns ------- @@ -355,8 +355,7 @@ def __init__(self, obj, min_periods=None, center=False, **windows): Dataset.groupby DataArray.groupby """ - super(DatasetRolling, self).__init__(obj, - min_periods, center, **windows) + super(DatasetRolling, self).__init__(obj, windows, min_periods, center) if self.dim not in self.obj.dims: raise KeyError(self.dim) # Keep each Rolling object as an OrderedDict @@ -364,8 +363,8 @@ def __init__(self, obj, min_periods=None, center=False, **windows): for key, da in self.obj.data_vars.items(): # keeps rollings only for the dataset depending on slf.dim if self.dim in da.dims: - self.rollings[key] = DataArrayRolling(da, min_periods, - center, **windows) + self.rollings[key] = DataArrayRolling( + da, windows, min_periods, center) def reduce(self, func, **kwargs): """Reduce the items in this group by applying `func` along some diff --git a/xarray/core/variable.py b/xarray/core/variable.py index d9772407b82..d82fd6fb7ea 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -877,7 +877,7 @@ def squeeze(self, dim=None): numpy.squeeze """ dims = common.get_squeeze_dims(self, dim) - return self.isel(**{d: 0 for d in dims}) + return self.isel({d: 0 for d in dims}) def _shift_one_dim(self, dim, count): axis = self.get_axis_num(dim) @@ -919,36 +919,46 @@ def _shift_one_dim(self, dim, count): return type(self)(self.dims, data, self._attrs, fastpath=True) - def shift(self, **shifts): + def shift(self, shifts=None, **shifts_kwargs): """ Return a new Variable with shifted data. Parameters ---------- - **shifts : keyword arguments of the form {dim: offset} + shifts : mapping of the form {dim: offset} Integer offset to shift along each of the given dimensions. Positive offsets shift to the right; negative offsets shift to the left. 
+ **shifts_kwargs: + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwarg must be provided. Returns ------- shifted : Variable Variable with the same dimensions and attributes but shifted data. """ + shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'shift') result = self for dim, count in shifts.items(): result = result._shift_one_dim(dim, count) return result - def pad_with_fill_value(self, fill_value=dtypes.NA, **pad_widths): + def pad_with_fill_value(self, pad_widths=None, fill_value=dtypes.NA, + **pad_widths_kwargs): """ Return a new Variable with paddings. Parameters ---------- - **pad_width: keyword arguments of the form {dim: (before, after)} + pad_width: Mapping of the form {dim: (before, after)} Number of values padded to the edges of each dimension. + **pad_widths_kwargs: + Keyword argument for pad_widths """ + pad_widths = either_dict_or_kwargs(pad_widths, pad_widths_kwargs, + 'pad') + if fill_value is dtypes.NA: # np.nan is passed dtype, fill_value = dtypes.maybe_promote(self.dtype) else: @@ -1009,22 +1019,27 @@ def _roll_one_dim(self, dim, count): return type(self)(self.dims, data, self._attrs, fastpath=True) - def roll(self, **shifts): + def roll(self, shifts=None, **shifts_kwargs): """ Return a new Variable with rolld data. Parameters ---------- - **shifts : keyword arguments of the form {dim: offset} + shifts : mapping of the form {dim: offset} Integer offset to roll along each of the given dimensions. Positive offsets roll to the right; negative offsets roll to the left. + **shifts_kwargs: + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwarg must be provided. Returns ------- shifted : Variable Variable with the same dimensions and attributes but rolled data. 
""" + shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'roll') + result = self for dim, count in shifts.items(): result = result._roll_one_dim(dim, count) @@ -1142,7 +1157,7 @@ def _stack_once(self, dims, new_dim): return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True) - def stack(self, **dimensions): + def stack(self, dimensions=None, **dimensions_kwargs): """ Stack any number of existing dimensions into a single new dimension. @@ -1151,9 +1166,12 @@ def stack(self, **dimensions): Parameters ---------- - **dimensions : keyword arguments of the form new_name=(dim1, dim2, ...) + dimensions : Mapping of form new_name=(dim1, dim2, ...) Names of new dimensions, and the existing dimensions that they replace. + **dimensions_kwargs: + The keyword arguments form of ``dimensions``. + One of dimensions or dimensions_kwargs must be provided. Returns ------- @@ -1164,6 +1182,8 @@ def stack(self, **dimensions): -------- Variable.unstack """ + dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, + 'stack') result = self for new_dim, dims in dimensions.items(): result = result._stack_once(dims, new_dim) @@ -1195,7 +1215,7 @@ def _unstack_once(self, dims, old_dim): return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True) - def unstack(self, **dimensions): + def unstack(self, dimensions=None, **dimensions_kwargs): """ Unstack an existing dimension into multiple new dimensions. @@ -1204,9 +1224,12 @@ def unstack(self, **dimensions): Parameters ---------- - **dimensions : keyword arguments of the form old_dim={dim1: size1, ...} + dimensions : mapping of the form old_dim={dim1: size1, ...} Names of existing dimensions, and the new dimensions and sizes that they map to. + **dimensions_kwargs: + The keyword arguments form of ``dimensions``. + One of dimensions or dimensions_kwargs must be provided. 
Returns ------- @@ -1217,6 +1240,8 @@ def unstack(self, **dimensions): -------- Variable.stack """ + dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, + 'unstack') result = self for old_dim, dims in dimensions.items(): result = result._unstack_once(dims, old_dim) From 8378d3af259d7d1907359fc087dd0a6ca7e5ef17 Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Mon, 20 Aug 2018 10:13:15 +0900 Subject: [PATCH 07/51] [MAINT] Avoid using duck typing (#2372) * avoiding duck typing * flake8 * Restore unintended change * Restore original form of _call_possibly_missing_method --- xarray/core/alignment.py | 5 ++++- xarray/core/combine.py | 7 ++++--- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 10 +++++++--- xarray/core/merge.py | 15 ++++++++++----- xarray/core/variable.py | 13 +++---------- xarray/tests/test_backends.py | 1 - xarray/tests/test_variable.py | 16 ---------------- 8 files changed, 29 insertions(+), 40 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index b0d2a49c29f..f82ddef25ba 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -174,11 +174,14 @@ def deep_align(objects, join='inner', copy=True, indexes=None, This function is not public API. """ + from .dataarray import DataArray + from .dataset import Dataset + if indexes is None: indexes = {} def is_alignable(obj): - return hasattr(obj, 'indexes') and hasattr(obj, 'reindex') + return isinstance(obj, (DataArray, Dataset)) positions = [] keys = [] diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 430f0e564d6..f0cc025dc7e 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -125,16 +125,17 @@ def _calc_concat_dim_coord(dim): Infer the dimension name and 1d coordinate variable (if appropriate) for concatenating along the new dimension. 
""" + from .dataarray import DataArray + if isinstance(dim, basestring): coord = None - elif not hasattr(dim, 'dims'): - # dim is not a DataArray or IndexVariable + elif not isinstance(dim, (DataArray, Variable)): dim_name = getattr(dim, 'name', None) if dim_name is None: dim_name = 'concat_dim' coord = IndexVariable(dim_name, dim) dim = dim_name - elif not hasattr(dim, 'name'): + elif not isinstance(dim, DataArray): coord = as_variable(dim).to_index_variable() dim, = coord.dims else: diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 359812f2cc3..373a6a4cc9e 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1870,7 +1870,7 @@ def _binary_op(f, reflexive=False, join=None, **ignored_kwargs): def func(self, other): if isinstance(other, (Dataset, groupby.GroupBy)): return NotImplemented - if hasattr(other, 'indexes'): + if isinstance(other, DataArray): align_type = (OPTIONS['arithmetic_join'] if join is None else join) self, other = align(self, other, join=align_type, copy=False) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 597b681bf65..3d02b382921 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2264,7 +2264,7 @@ def _stack_once(self, dims, new_dim): # consider dropping levels that are unused? 
levels = [self.get_index(dim) for dim in dims] - if hasattr(pd, 'RangeIndex'): + if LooseVersion(pd.__version__) < LooseVersion('0.19.0'): # RangeIndex levels in a MultiIndex are broken for appending in # pandas before v0.19.0 levels = [pd.Int64Index(level) @@ -3175,10 +3175,12 @@ def func(self, *args, **kwargs): def _binary_op(f, reflexive=False, join=None): @functools.wraps(f) def func(self, other): + from .dataarray import DataArray + if isinstance(other, groupby.GroupBy): return NotImplemented align_type = OPTIONS['arithmetic_join'] if join is None else join - if hasattr(other, 'indexes'): + if isinstance(other, (DataArray, Dataset)): self, other = align(self, other, join=align_type, copy=False) g = f if not reflexive else lambda x, y: f(y, x) ds = self._calculate_binary_op(g, other, join=align_type) @@ -3190,12 +3192,14 @@ def func(self, other): def _inplace_binary_op(f): @functools.wraps(f) def func(self, other): + from .dataarray import DataArray + if isinstance(other, groupby.GroupBy): raise TypeError('in-place operations between a Dataset and ' 'a grouped object are not permitted') # we don't actually modify arrays in-place with in-place Dataset # arithmetic -- this lets us automatically align things - if hasattr(other, 'indexes'): + if isinstance(other, (DataArray, Dataset)): other = other.reindex_like(self, copy=False) g = ops.inplace_to_noninplace_op(f) ds = self._calculate_binary_op(g, other, inplace=True) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index f823717a8af..984dd2fa204 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -190,10 +190,13 @@ def expand_variable_dicts(list_of_variable_dicts): an input's values. The values of each ordered dictionary are all xarray.Variable objects. 
""" + from .dataarray import DataArray + from .dataset import Dataset + var_dicts = [] for variables in list_of_variable_dicts: - if hasattr(variables, 'variables'): # duck-type Dataset + if isinstance(variables, Dataset): sanitized_vars = variables.variables else: # append coords to var_dicts before appending sanitized_vars, @@ -201,7 +204,7 @@ def expand_variable_dicts(list_of_variable_dicts): sanitized_vars = OrderedDict() for name, var in variables.items(): - if hasattr(var, '_coords'): # duck-type DataArray + if isinstance(var, DataArray): # use private API for speed coords = var._coords.copy() # explicitly overwritten variables should take precedence @@ -232,17 +235,19 @@ def determine_coords(list_of_variable_dicts): All variable found in the input should appear in either the set of coordinate or non-coordinate names. """ + from .dataarray import DataArray + from .dataset import Dataset + coord_names = set() noncoord_names = set() for variables in list_of_variable_dicts: - if hasattr(variables, 'coords') and hasattr(variables, 'data_vars'): - # duck-type Dataset + if isinstance(variables, Dataset): coord_names.update(variables.coords) noncoord_names.update(variables.data_vars) else: for name, var in variables.items(): - if hasattr(var, '_coords'): # duck-type DataArray + if isinstance(var, DataArray): coords = set(var._coords) # use private API for speed # explicitly overwritten variables should take precedence coords.discard(name) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index d82fd6fb7ea..33a093d0496 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -64,22 +64,15 @@ def as_variable(obj, name=None): The newly created variable. """ + from .dataarray import DataArray + # TODO: consider extending this method to automatically handle Iris and - # pandas objects. 
- if hasattr(obj, 'variable'): + if isinstance(obj, DataArray): # extract the primary Variable from DataArrays obj = obj.variable if isinstance(obj, Variable): obj = obj.copy(deep=False) - elif hasattr(obj, 'dims') and (hasattr(obj, 'data') or - hasattr(obj, 'values')): - obj_data = getattr(obj, 'data', None) - if obj_data is None: - obj_data = getattr(obj, 'values') - obj = Variable(obj.dims, obj_data, - getattr(obj, 'attrs', None), - getattr(obj, 'encoding', None)) elif isinstance(obj, tuple): try: obj = Variable(*obj) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index e6de50b9dd2..3801225299f 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1404,7 +1404,6 @@ def test_chunk_encoding_with_dask(self): with self.roundtrip(ds_chunk4) as actual: self.assertEqual((4,), actual['var1'].encoding['chunks']) - # TODO: remove this failure once syncronized overlapping writes are # supported by xarray ds_chunk4['var1'].encoding.update({'chunks': 5}) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 3db5e6adc4b..a08f7262577 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1,7 +1,6 @@ from __future__ import absolute_import, division, print_function -from collections import namedtuple from copy import copy, deepcopy from datetime import datetime, timedelta from distutils.version import LooseVersion @@ -938,21 +937,6 @@ def test_as_variable(self): assert not isinstance(ds['x'], Variable) assert isinstance(as_variable(ds['x']), Variable) - FakeVariable = namedtuple('FakeVariable', 'values dims') - fake_xarray = FakeVariable(expected.values, expected.dims) - assert_identical(expected, as_variable(fake_xarray)) - - FakeVariable = namedtuple('FakeVariable', 'data dims') - fake_xarray = FakeVariable(expected.data, expected.dims) - assert_identical(expected, as_variable(fake_xarray)) - - FakeVariable = namedtuple('FakeVariable', - 'data values dims attrs 
encoding') - fake_xarray = FakeVariable(expected_extra.data, expected_extra.values, - expected_extra.dims, expected_extra.attrs, - expected_extra.encoding) - assert_identical(expected_extra, as_variable(fake_xarray)) - xarray_tuple = (expected_extra.dims, expected_extra.values, expected_extra.attrs, expected_extra.encoding) assert_identical(expected_extra, as_variable(xarray_tuple)) From 69086b332c6c950587830b266df4e624c2106d89 Mon Sep 17 00:00:00 2001 From: NotSqrt Date: Mon, 20 Aug 2018 18:31:15 +0200 Subject: [PATCH 08/51] Fix maybe_promote (#1953) * Fix maybe_promote With tests for every possible dtype: (numpy docs say `biufcmMOSUV` only) ``` for letter in string.ascii_letters: try: print(letter, np.dtype(letter)) except TypeError as exc: pass ``` * Check issubdtype of floating before timedelta64 In order to hit this branch more often * Improve maybe_promote test --- xarray/core/dtypes.py | 7 +++++-- xarray/tests/test_dtypes.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index 7ad44472f06..a2f11728b4d 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -80,6 +80,11 @@ def maybe_promote(dtype): # N.B. 
these casting rules should match pandas if np.issubdtype(dtype, np.floating): fill_value = np.nan + elif np.issubdtype(dtype, np.timedelta64): + # See https://github.com/numpy/numpy/issues/10685 + # np.timedelta64 is a subclass of np.integer + # Check np.timedelta64 before np.integer + fill_value = np.timedelta64('NaT') elif np.issubdtype(dtype, np.integer): if dtype.itemsize <= 2: dtype = np.float32 @@ -90,8 +95,6 @@ def maybe_promote(dtype): fill_value = np.nan + np.nan * 1j elif np.issubdtype(dtype, np.datetime64): fill_value = np.datetime64('NaT') - elif np.issubdtype(dtype, np.timedelta64): - fill_value = np.timedelta64('NaT') else: dtype = object fill_value = np.nan diff --git a/xarray/tests/test_dtypes.py b/xarray/tests/test_dtypes.py index 833df85f8af..292c60b4d05 100644 --- a/xarray/tests/test_dtypes.py +++ b/xarray/tests/test_dtypes.py @@ -50,3 +50,39 @@ def error(): def test_inf(obj): assert dtypes.INF > obj assert dtypes.NINF < obj + + +@pytest.mark.parametrize("kind, expected", [ + ('a', (np.dtype('O'), 'nan')), # dtype('S') + ('b', (np.float32, 'nan')), # dtype('int8') + ('B', (np.float32, 'nan')), # dtype('uint8') + ('c', (np.dtype('O'), 'nan')), # dtype('S1') + ('D', (np.complex128, '(nan+nanj)')), # dtype('complex128') + ('d', (np.float64, 'nan')), # dtype('float64') + ('e', (np.float16, 'nan')), # dtype('float16') + ('F', (np.complex64, '(nan+nanj)')), # dtype('complex64') + ('f', (np.float32, 'nan')), # dtype('float32') + ('h', (np.float32, 'nan')), # dtype('int16') + ('H', (np.float32, 'nan')), # dtype('uint16') + ('i', (np.float64, 'nan')), # dtype('int32') + ('I', (np.float64, 'nan')), # dtype('uint32') + ('l', (np.float64, 'nan')), # dtype('int64') + ('L', (np.float64, 'nan')), # dtype('uint64') + ('m', (np.timedelta64, 'NaT')), # dtype(' Date: Mon, 27 Aug 2018 18:21:19 -0700 Subject: [PATCH 09/51] =?UTF-8?q?BUG:=20modify=20behavior=20of=20Dataset.f?= =?UTF-8?q?ilter=5Fby=5Fattrs=20to=20match=20netCDF4.Data=E2=80=A6=20(#232?= =?UTF-8?q?2)?= 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * BUG: modify behavior of Dataset.filter_by_attrs to match netCDF4.Dataset.get_variables_by_attributes * fix style and add more test for Dataset.filter_by_attrs * update what-new doc with fix for gh2315 --- doc/whats-new.rst | 6 ++++++ xarray/core/dataset.py | 19 ++++++++++++++----- xarray/tests/test_dataset.py | 20 ++++++++++++++++++++ 3 files changed, 40 insertions(+), 5 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4725fe74577..1ccf4dee00e 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -99,6 +99,12 @@ Bug fixes (:issue:`2341`) By `Keisuke Fujii `_. +- Fixed ``Dataset.filter_by_attrs()`` behavior not matching ``netCDF4.Dataset.get_variables_by_attributes()``. + When more than one ``key=value`` is passed into ``Dataset.filter_by_attrs()`` it will now return a Dataset with variables which pass + all the filters. + (:issue:`2315`) + By `Andrew Barna Date: Mon, 27 Aug 2018 18:48:59 -0700 Subject: [PATCH 10/51] fix typo in uri in the docs (#2386) --- doc/whats-new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 1ccf4dee00e..dd9eb9e48fe 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -103,7 +103,7 @@ Bug fixes When more than one ``key=value`` is passed into ``Dataset.filter_by_attrs()`` it will now return a Dataset with variables which pass all the filters. (:issue:`2315`) - By `Andrew Barna `_. .. 
_whats-new.0.10.8: From ecee9a0fe01db13bce1e234519614aeed53a7f07 Mon Sep 17 00:00:00 2001 From: Ray Bell Date: Tue, 28 Aug 2018 14:11:55 -0400 Subject: [PATCH 11/51] DOC: move xskillscore to 'Extend xarray capabilities' (#2387) --- doc/related-projects.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/related-projects.rst b/doc/related-projects.rst index 9b75d0e1b3e..714fbf98d7c 100644 --- a/doc/related-projects.rst +++ b/doc/related-projects.rst @@ -35,7 +35,6 @@ Geosciences - `xgcm `_: Extends the xarray data model to understand finite volume grid cells (common in General Circulation Models) and provides interpolation and difference operations for such grids. - `xmitgcm `_: a python package for reading `MITgcm `_ binary MDS files into xarray data structures. - `xshape `_: Tools for working with shapefiles, topographies, and polygons in xarray. -- `xskillscore `_: Metrics for verifying forecasts. Machine Learning ~~~~~~~~~~~~~~~~ @@ -52,6 +51,7 @@ Extend xarray capabilities - `xrft `_: Fourier transforms for xarray data. - `xr-scipy `_: A lightweight scipy wrapper for xarray. - `X-regression `_: Multiple linear regression from Statsmodels library coupled with Xarray library. +- `xskillscore `_: Metrics for verifying forecasts. - `xyzpy `_: Easily generate high dimensional data, including parallelization. 
Visualization From e5ae4088f3512eb805b13ea138087350b8180d69 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Fri, 31 Aug 2018 09:23:18 -0700 Subject: [PATCH 12/51] Mark test_equals_all_dtypes as xfail again (#2393) --- xarray/tests/test_variable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index a08f7262577..904940cbbf6 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1651,7 +1651,7 @@ def test_getitem_1d_fancy(self): def test_equals_all_dtypes(self): import dask - if '0.18.2' <= LooseVersion(dask.__version__) < '0.18.3': + if '0.18.2' <= LooseVersion(dask.__version__) < '0.19.1': pytest.xfail('https://github.com/pydata/xarray/issues/2318') super(TestVariableWithDask, self).test_equals_all_dtypes() From a3ca579c3c6996a44440c7b0f5f68932b5a1c46d Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 4 Sep 2018 21:09:23 +0530 Subject: [PATCH 13/51] Silence some warnings. (#2328) * Make sure dask tests work with dask=0.16 * Silence some pnetcdf warnings. * fix sel_points, isel_points fancy indexing tests * Revert to using xr.ufuncs * Fix overflow/underflow warnings in interpolate_na These were being triggered by casting datetime64[ns] to float32. We now rescale the co-ordinate before interpolating, except for nearest-neighbour interpolation. The rescaling can change the nearest neighbour, and so is avoided in this case to maintain pandas compatibility. * Rescale datetime for interp() too. * Better rescaling. * Revert "Better rescaling." This reverts commit 76f988f594aea23d3acde1d603db5460c9010c1e. * Revert "Rescale datetime for interp() too." This reverts commit 9ac15ef677c4d21230b7aab40a65c1d7b0530ece. * Revert "Fix overflow/underflow warnings in interpolate_na" This reverts commit 1f1ec52707f8b2349461e41b68a7bc3918deb9f1. * Silence overflow/underflow/invalid value warnings. * Silence a bottleneck warning. * Revert "Silence a bottleneck warning." 
This reverts commit b9851275fdccd4c1cf8e662bffd5b1353b4ea048. * Dask: change from attribute check to version check. * Maybe this fixes python 2 failure? --- xarray/coding/times.py | 7 +++-- xarray/core/formatting.py | 2 +- xarray/core/missing.py | 19 +++++++----- xarray/plot/plot.py | 7 ++++- xarray/plot/utils.py | 4 ++- xarray/tests/test_backends.py | 49 ++++++++++++++----------------- xarray/tests/test_coding_times.py | 3 +- xarray/tests/test_dask.py | 14 +++++++-- xarray/tests/test_dataarray.py | 5 +++- xarray/tests/test_dataset.py | 4 +++ xarray/tests/test_missing.py | 16 +++++----- xarray/tests/test_plot.py | 10 +++++++ 12 files changed, 89 insertions(+), 51 deletions(-) diff --git a/xarray/coding/times.py b/xarray/coding/times.py index d946e2ed378..6edbedce54c 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -183,8 +183,11 @@ def decode_cf_datetime(num_dates, units, calendar=None, # fixes: https://github.com/pydata/pandas/issues/14068 # these lines check if the the lowest or the highest value in dates # cause an OutOfBoundsDatetime (Overflow) error - pd.to_timedelta(flat_num_dates.min(), delta) + ref_date - pd.to_timedelta(flat_num_dates.max(), delta) + ref_date + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'invalid value encountered', + RuntimeWarning) + pd.to_timedelta(flat_num_dates.min(), delta) + ref_date + pd.to_timedelta(flat_num_dates.max(), delta) + ref_date # Cast input dates to integers of nanoseconds because `pd.to_datetime` # works much faster when dealing with integers diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 65f3c91ca26..042c8c5324d 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -183,7 +183,7 @@ def format_items(x): day_part = (x[~pd.isnull(x)] .astype('timedelta64[D]') .astype('timedelta64[ns]')) - time_needed = x != day_part + time_needed = x[~pd.isnull(x)] != day_part day_needed = day_part != np.timedelta64(0, 'ns') if 
np.logical_not(day_needed).all(): timedelta_format = 'time' diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 232fa185c07..90aa4ffaeda 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -3,6 +3,8 @@ from collections import Iterable from functools import partial +import warnings + import numpy as np import pandas as pd @@ -207,13 +209,16 @@ def interp_na(self, dim=None, use_coordinate=True, method='linear', limit=None, interp_class, kwargs = _get_interpolator(method, **kwargs) interpolator = partial(func_interpolate_na, interp_class, **kwargs) - arr = apply_ufunc(interpolator, index, self, - input_core_dims=[[dim], [dim]], - output_core_dims=[[dim]], - output_dtypes=[self.dtype], - dask='parallelized', - vectorize=True, - keep_attrs=True).transpose(*self.dims) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'overflow', RuntimeWarning) + warnings.filterwarnings('ignore', 'invalid value', RuntimeWarning) + arr = apply_ufunc(interpolator, index, self, + input_core_dims=[[dim], [dim]], + output_core_dims=[[dim]], + output_dtypes=[self.dtype], + dask='parallelized', + vectorize=True, + keep_attrs=True).transpose(*self.dims) if limit is not None: arr = arr.where(valids) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 179f41e9e42..0b3ab6f1bde 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -479,9 +479,11 @@ def line(self, *args, **kwargs): def _rescale_imshow_rgb(darray, vmin, vmax, robust): assert robust or vmin is not None or vmax is not None + # TODO: remove when min numpy version is bumped to 1.13 # There's a cyclic dependency via DataArray, so we can't import from # xarray.ufuncs in global scope. from xarray.ufuncs import maximum, minimum + # Calculate vmin and vmax automatically for `robust=True` if robust: if vmax is None: @@ -507,7 +509,10 @@ def _rescale_imshow_rgb(darray, vmin, vmax, robust): # After scaling, downcast to 32-bit float. 
This substantially reduces # memory usage after we hand `darray` off to matplotlib. darray = ((darray.astype('f8') - vmin) / (vmax - vmin)).astype('f4') - return minimum(maximum(darray, 0), 1) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'xarray.ufuncs', + PendingDeprecationWarning) + return minimum(maximum(darray, 0), 1) def _plot2d(plotfunc): diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 1ddb02352be..6221bfe9153 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -213,8 +213,10 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, # Handle discrete levels if levels is not None: if is_scalar(levels): - if user_minmax or levels == 1: + if user_minmax: levels = np.linspace(vmin, vmax, levels) + elif levels == 1: + levels = np.asarray([(vmin + vmax) / 2]) else: # N in MaxNLocator refers to bins, not ticks ticker = mpl.ticker.MaxNLocator(levels - 1) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 3801225299f..8b469761ccd 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1789,6 +1789,7 @@ def create_store(self): with create_tmp_file() as tmp_file: yield backends.H5NetCDFStore(tmp_file, 'w') + @pytest.mark.filterwarnings('ignore:complex dtypes are supported by h5py') def test_complex(self): expected = Dataset({'x': ('y', np.ones(5) + 1j * np.ones(5))}) with self.roundtrip(expected) as actual: @@ -2527,6 +2528,7 @@ class PyNioTestAutocloseTrue(PyNioTest): @requires_pseudonetcdf +@pytest.mark.filterwarnings('ignore:IOAPI_ISPH is assumed to be 6370000') class PseudoNetCDFFormatTest(TestCase): autoclose = True @@ -2658,14 +2660,11 @@ def test_uamiv_format_read(self): """ Open a CAMx file and test data variables """ - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=UserWarning, - message=('IOAPI_ISPH is assumed to be ' + - '6370000.; consistent with WRF')) - camxfile = open_example_dataset('example.uamiv', - 
engine='pseudonetcdf', - autoclose=True, - backend_kwargs={'format': 'uamiv'}) + + camxfile = open_example_dataset('example.uamiv', + engine='pseudonetcdf', + autoclose=True, + backend_kwargs={'format': 'uamiv'}) data = np.arange(20, dtype='f').reshape(1, 1, 4, 5) expected = xr.Variable(('TSTEP', 'LAY', 'ROW', 'COL'), data, dict(units='ppm', long_name='O3'.ljust(16), @@ -2687,17 +2686,14 @@ def test_uamiv_format_mfread(self): """ Open a CAMx file and test data variables """ - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=UserWarning, - message=('IOAPI_ISPH is assumed to be ' + - '6370000.; consistent with WRF')) - camxfile = open_example_mfdataset( - ['example.uamiv', - 'example.uamiv'], - engine='pseudonetcdf', - autoclose=True, - concat_dim='TSTEP', - backend_kwargs={'format': 'uamiv'}) + + camxfile = open_example_mfdataset( + ['example.uamiv', + 'example.uamiv'], + engine='pseudonetcdf', + autoclose=True, + concat_dim='TSTEP', + backend_kwargs={'format': 'uamiv'}) data1 = np.arange(20, dtype='f').reshape(1, 1, 4, 5) data = np.concatenate([data1] * 2, axis=0) @@ -2720,19 +2716,18 @@ def test_uamiv_format_mfread(self): def test_uamiv_format_write(self): fmtkw = {'format': 'uamiv'} - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=UserWarning, - message=('IOAPI_ISPH is assumed to be ' + - '6370000.; consistent with WRF')) - expected = open_example_dataset('example.uamiv', - engine='pseudonetcdf', - autoclose=False, - backend_kwargs=fmtkw) + + expected = open_example_dataset('example.uamiv', + engine='pseudonetcdf', + autoclose=False, + backend_kwargs=fmtkw) with self.roundtrip(expected, save_kwargs=fmtkw, open_kwargs={'backend_kwargs': fmtkw}) as actual: assert_identical(expected, actual) + expected.close() + def save(self, dataset, path, **save_kwargs): import PseudoNetCDF as pnc pncf = pnc.PseudoNetCDFFile() diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 
e763af4984c..7d3a4930b44 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -538,7 +538,8 @@ def test_cf_datetime_nan(num_dates, units, expected_list): with warnings.catch_warnings(): warnings.filterwarnings('ignore', 'All-NaN') actual = coding.times.decode_cf_datetime(num_dates, units) - expected = np.array(expected_list, dtype='datetime64[ns]') + # use pandas because numpy will deprecate timezone-aware conversions + expected = pd.to_datetime(expected_list) assert_array_equal(expected, actual) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index f6c47cce8d8..6ca83ab73ab 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -2,6 +2,7 @@ import pickle from textwrap import dedent +from distutils.version import LooseVersion import numpy as np import pandas as pd @@ -24,8 +25,12 @@ class DaskTestCase(TestCase): def assertLazyAnd(self, expected, actual, test): - with dask.set_options(get=dask.get): + + with (dask.config.set(get=dask.get) + if LooseVersion(dask.__version__) >= LooseVersion('0.18.0') + else dask.set_options(get=dask.get)): test(actual, expected) + if isinstance(actual, Dataset): for k, v in actual.variables.items(): if k in actual.dims: @@ -196,11 +201,13 @@ def test_missing_methods(self): except NotImplementedError as err: assert 'dask' in str(err) + @pytest.mark.filterwarnings('ignore::PendingDeprecationWarning') def test_univariate_ufunc(self): u = self.eager_var v = self.lazy_var self.assertLazyAndAllClose(np.sin(u), xu.sin(v)) + @pytest.mark.filterwarnings('ignore::PendingDeprecationWarning') def test_bivariate_ufunc(self): u = self.eager_var v = self.lazy_var @@ -421,6 +428,7 @@ def duplicate_and_merge(array): actual = duplicate_and_merge(self.lazy_array) self.assertLazyAndEqual(expected, actual) + @pytest.mark.filterwarnings('ignore::PendingDeprecationWarning') def test_ufuncs(self): u = self.eager_array v = self.lazy_array @@ -821,7 +829,9 @@ def test_basic_compute(): 
dask.multiprocessing.get, dask.local.get_sync, None]: - with dask.set_options(get=get): + with (dask.config.set(get=get) + if LooseVersion(dask.__version__) >= LooseVersion('0.18.0') + else dask.set_options(get=get)): ds.compute() ds.foo.compute() ds.foo.variable.compute() diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 3619688d091..5d20a6cfec3 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -672,6 +672,7 @@ def test_isel_types(self): assert_identical(da.isel(x=np.array([0], dtype="int64")), da.isel(x=np.array([0]))) + @pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_isel_fancy(self): shape = (10, 7, 6) np_array = np.random.random(shape) @@ -845,6 +846,7 @@ def test_isel_drop(self): selected = data.isel(x=0, drop=False) assert_identical(expected, selected) + @pytest.mark.filterwarnings("ignore:Dataset.isel_points") def test_isel_points(self): shape = (10, 5, 6) np_array = np.random.random(shape) @@ -1237,6 +1239,7 @@ def test_reindex_like_no_index(self): ValueError, 'different size for unlabeled'): foo.reindex_like(bar) + @pytest.mark.filterwarnings('ignore:Indexer has dimensions') def test_reindex_regressions(self): # regression test for #279 expected = DataArray(np.random.randn(5), coords=[("time", range(5))]) @@ -1286,7 +1289,7 @@ def test_swap_dims(self): def test_expand_dims_error(self): array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'], - coords={'x': np.linspace(0.0, 1.0, 3.0)}, + coords={'x': np.linspace(0.0, 1.0, 3)}, attrs={'key': 'entry'}) with raises_regex(ValueError, 'dim should be str or'): diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 101de9fe8c7..068b445c69f 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1240,6 +1240,7 @@ def test_isel_drop(self): selected = data.isel(x=0, drop=False) assert_identical(expected, selected) + @pytest.mark.filterwarnings("ignore:Dataset.isel_points") def 
test_isel_points(self): data = create_test_data() @@ -1317,6 +1318,8 @@ def test_isel_points(self): dim2=stations['dim2s'], dim=np.array([4, 5, 6])) + @pytest.mark.filterwarnings("ignore:Dataset.sel_points") + @pytest.mark.filterwarnings("ignore:Dataset.isel_points") def test_sel_points(self): data = create_test_data() @@ -1347,6 +1350,7 @@ def test_sel_points(self): with pytest.raises(KeyError): data.sel_points(x=[2.5], y=[2.0], method='pad', tolerance=1e-3) + @pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_sel_fancy(self): data = create_test_data() diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index 5c7e384c789..47224e55473 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -93,14 +93,14 @@ def test_interpolate_pd_compat(): @requires_scipy -def test_scipy_methods_function(): - for method in ['barycentric', 'krog', 'pchip', 'spline', 'akima']: - kwargs = {} - # Note: Pandas does some wacky things with these methods and the full - # integration tests wont work. - da, _ = make_interpolate_example_data((25, 25), 0.4, non_uniform=True) - actual = da.interpolate_na(method=method, dim='time', **kwargs) - assert (da.count('time') <= actual.count('time')).all() +@pytest.mark.parametrize('method', ['barycentric', 'krog', + 'pchip', 'spline', 'akima']) +def test_scipy_methods_function(method): + # Note: Pandas does some wacky things with these methods and the full + # integration tests wont work. 
+ da, _ = make_interpolate_example_data((25, 25), 0.4, non_uniform=True) + actual = da.interpolate_na(method=method, dim='time') + assert (da.count('time') <= actual.count('time')).all() @requires_scipy diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 4e5ea8fc623..e7caf3d6ca2 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -267,6 +267,7 @@ def test_datetime_dimension(self): assert ax.has_data() @pytest.mark.slow + @pytest.mark.filterwarnings('ignore:tight_layout cannot') def test_convenient_facetgrid(self): a = easy_array((10, 15, 4)) d = DataArray(a, dims=['y', 'x', 'z']) @@ -328,6 +329,7 @@ def test_plot_size(self): self.darray.plot(aspect=1) @pytest.mark.slow + @pytest.mark.filterwarnings('ignore:tight_layout cannot') def test_convenient_facetgrid_4d(self): a = easy_array((10, 15, 2, 3)) d = DataArray(a, dims=['y', 'x', 'columns', 'rows']) @@ -775,10 +777,13 @@ def test_plot_nans(self): clim2 = self.plotfunc(x2).get_clim() assert clim1 == clim2 + @pytest.mark.filterwarnings('ignore::UserWarning') + @pytest.mark.filterwarnings('ignore:invalid value encountered') def test_can_plot_all_nans(self): # regression test for issue #1780 self.plotfunc(DataArray(np.full((2, 2), np.nan))) + @pytest.mark.filterwarnings('ignore: Attempting to set') def test_can_plot_axis_size_one(self): if self.plotfunc.__name__ not in ('contour', 'contourf'): self.plotfunc(DataArray(np.ones((1, 1)))) @@ -970,6 +975,7 @@ def test_2d_function_and_method_signature_same(self): del func_sig['darray'] assert func_sig == method_sig + @pytest.mark.filterwarnings('ignore:tight_layout cannot') def test_convenient_facetgrid(self): a = easy_array((10, 15, 4)) d = DataArray(a, dims=['y', 'x', 'z']) @@ -1001,6 +1007,7 @@ def test_convenient_facetgrid(self): else: assert '' == ax.get_xlabel() + @pytest.mark.filterwarnings('ignore:tight_layout cannot') def test_convenient_facetgrid_4d(self): a = easy_array((10, 15, 2, 3)) d = DataArray(a, dims=['y', 'x', 
'columns', 'rows']) @@ -1279,6 +1286,7 @@ def test_imshow_rgb_values_in_valid_range(self): assert out.dtype == np.uint8 assert (out[..., :3] == da.values).all() # Compare without added alpha + @pytest.mark.filterwarnings('ignore:Several dimensions of this array') def test_regression_rgb_imshow_dim_size_one(self): # Regression: https://github.com/pydata/xarray/issues/1966 da = DataArray(easy_array((1, 3, 3), start=0.0, stop=1.0)) @@ -1511,6 +1519,7 @@ def test_facetgrid_polar(self): sharey=False) +@pytest.mark.filterwarnings('ignore:tight_layout cannot') class TestFacetGrid4d(PlotTestCase): def setUp(self): a = easy_array((10, 15, 3, 2)) @@ -1538,6 +1547,7 @@ def test_default_labels(self): assert substring_in_axes(label, ax) +@pytest.mark.filterwarnings('ignore:tight_layout cannot') class TestFacetedLinePlots(PlotTestCase): def setUp(self): self.darray = DataArray(np.random.randn(10, 6, 3, 4), From fc9ef81dbda163348316a9014bc44e7dae93a5ed Mon Sep 17 00:00:00 2001 From: Julius Busecke Date: Wed, 5 Sep 2018 17:17:22 +0200 Subject: [PATCH 14/51] add options for nondivergent and divergent cmap (#2397) * add options for nondivergent and divergent cmap * Update test_plot.py * renamed cmap options * Stickler fix * Another sticker fix * Additional tests and credits * Fix merge error in whats-new.rst * Fixed explicit cmap test * Update docstring Specify that colormap need not be matplotlib colormap * update docstring * keep stickler happy --- doc/whats-new.rst | 7 ++++++- xarray/core/options.py | 11 ++++++++++- xarray/plot/utils.py | 5 +++-- xarray/tests/test_plot.py | 16 ++++++++++++++++ 4 files changed, 35 insertions(+), 4 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index dd9eb9e48fe..97794125665 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -36,6 +36,11 @@ Documentation Enhancements ~~~~~~~~~~~~ +- Default colormap for sequential and divergent data can now be set via + :py:func:`~xarray.set_options()` + (:issue:`2394`) + By `Julius Busecke 
`_. + - min_count option is newly supported in :py:meth:`~xarray.DataArray.sum`, :py:meth:`~xarray.DataArray.prod` and :py:meth:`~xarray.Dataset.sum`, and :py:meth:`~xarray.Dataset.prod`. @@ -84,7 +89,7 @@ Bug fixes - Tests can be run in parallel with pytest-xdist By `Tony Tung `_. -- Follow up the renamings in dask; from dask.ghost to dask.overlap +- Follow up the renamings in dask; from dask.ghost to dask.overlap By `Keisuke Fujii `_. - Now raises a ValueError when there is a conflict between dimension names and diff --git a/xarray/core/options.py b/xarray/core/options.py index 48d4567fc99..a6118f02ed3 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -3,7 +3,9 @@ OPTIONS = { 'display_width': 80, 'arithmetic_join': 'inner', - 'enable_cftimeindex': False + 'enable_cftimeindex': False, + 'cmap_sequential': 'viridis', + 'cmap_divergent': 'RdBu_r', } @@ -19,6 +21,13 @@ class set_options(object): - ``enable_cftimeindex``: flag to enable using a ``CFTimeIndex`` for time indexes with non-standard calendars or dates outside the Timestamp-valid range. Default: ``False``. + - ``cmap_sequential``: colormap to use for nondivergent data plots. + Default: ``viridis``. If string, must be matplotlib built-in colormap. + Can also be a Colormap object (e.g. mpl.cm.magma) + - ``cmap_divergent``: colormap to use for divergent data plots. + Default: ``RdBu_r``. If string, must be matplotlib built-in colormap. + Can also be a Colormap object (e.g. 
mpl.cm.magma) + You can use ``set_options`` either as a context manager: diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 6221bfe9153..9af0624dbfc 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -7,6 +7,7 @@ from ..core.pycompat import basestring from ..core.utils import is_scalar +from ..core.options import OPTIONS ROBUST_PERCENTILE = 2.0 @@ -206,9 +207,9 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, # Choose default colormaps if not provided if cmap is None: if divergent: - cmap = "RdBu_r" + cmap = OPTIONS['cmap_divergent'] else: - cmap = "viridis" + cmap = OPTIONS['cmap_sequential'] # Handle discrete levels if levels is not None: diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index e7caf3d6ca2..c38ffeff884 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd +import xarray as xr import pytest import xarray.plot as xplt @@ -472,6 +473,21 @@ def test_center(self): assert cmap_params['levels'] is None assert cmap_params['norm'] is None + def test_cmap_sequential_option(self): + with xr.set_options(cmap_sequential='magma'): + cmap_params = _determine_cmap_params(self.data) + assert cmap_params['cmap'] == 'magma' + + def test_cmap_sequential_explicit_option(self): + with xr.set_options(cmap_sequential=mpl.cm.magma): + cmap_params = _determine_cmap_params(self.data) + assert cmap_params['cmap'] == mpl.cm.magma + + def test_cmap_divergent_option(self): + with xr.set_options(cmap_divergent='magma'): + cmap_params = _determine_cmap_params(self.data, center=0.5) + assert cmap_params['cmap'] == 'magma' + def test_nan_inf_are_ignored(self): cmap_params1 = _determine_cmap_params(self.data) data = self.data From 795a7bf26b6b4a6558c13b64864c4b5e0ea79016 Mon Sep 17 00:00:00 2001 From: Stephane Raynaud Date: Wed, 5 Sep 2018 17:18:45 +0200 Subject: [PATCH 15/51] BUG: Fix cdms2 related tests (#2332) (#2400) * Add curvilinear grid 
support to to_cdms2 and fix mask bug * fix indentation in to_cdms2 * Add generic unstructured grid support to _from_cdms2 and to_cdms2 * Fix indentation in from_cdms2 * Fix indentation in from_cdms2 * Split cdms2 unit tests and use OrderedDict * BUG: Fix #2332 about cdms2 tests * BUG: Fix #2332 about cdms2 tests --- xarray/tests/test_dataarray.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 5d20a6cfec3..29ddd40ce25 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2918,7 +2918,6 @@ def test_to_masked_array(self): ma = da.to_masked_array() assert len(ma.mask) == N - @pytest.mark.xfail # GH:2332 TODO fix this in upstream? def test_to_and_from_cdms2_classic(self): """Classic with 1D axes""" pytest.importorskip('cdms2') @@ -2931,7 +2930,7 @@ def test_to_and_from_cdms2_classic(self): expected_coords = [IndexVariable('distance', [-2, 2]), IndexVariable('time', [0, 1, 2])] actual = original.to_cdms2() - assert_array_equal(actual, original) + assert_array_equal(actual.asma(), original) assert actual.id == original.name self.assertItemsEqual(actual.getAxisIds(), original.dims) for axis, coord in zip(actual.getAxisList(), expected_coords): @@ -2953,7 +2952,6 @@ def test_to_and_from_cdms2_classic(self): assert_array_equal(original.coords[coord_name], back.coords[coord_name]) - @pytest.mark.xfail # GH:2332 TODO fix this in upstream? 
def test_to_and_from_cdms2_sgrid(self): """Curvilinear (structured) grid @@ -2971,8 +2969,10 @@ def test_to_and_from_cdms2_sgrid(self): name='sst') actual = original.to_cdms2() self.assertItemsEqual(actual.getAxisIds(), original.dims) - assert_array_equal(original.coords['lon'], actual.getLongitude()) - assert_array_equal(original.coords['lat'], actual.getLatitude()) + assert_array_equal(original.coords['lon'], + actual.getLongitude().asma()) + assert_array_equal(original.coords['lat'], + actual.getLatitude().asma()) back = from_cdms2(actual) self.assertItemsEqual(original.dims, back.dims) @@ -2980,7 +2980,6 @@ def test_to_and_from_cdms2_sgrid(self): assert_array_equal(original.coords['lat'], back.coords['lat']) assert_array_equal(original.coords['lon'], back.coords['lon']) - @pytest.mark.xfail # GH:2332 TODO fix this in upstream? def test_to_and_from_cdms2_ugrid(self): """Unstructured grid""" pytest.importorskip('cdms2') @@ -2992,8 +2991,10 @@ def test_to_and_from_cdms2_ugrid(self): coords={'lon': lon, 'lat': lat, 'cell': cell}) actual = original.to_cdms2() self.assertItemsEqual(actual.getAxisIds(), original.dims) - assert_array_equal(original.coords['lon'], actual.getLongitude()) - assert_array_equal(original.coords['lat'], actual.getLatitude()) + assert_array_equal(original.coords['lon'], + actual.getLongitude().getValue()) + assert_array_equal(original.coords['lat'], + actual.getLatitude().getValue()) back = from_cdms2(actual) self.assertItemsEqual(original.dims, back.dims) From 73f5b02a42a4003815d2bfc91e658195f5050be1 Mon Sep 17 00:00:00 2001 From: Julia Signell Date: Wed, 5 Sep 2018 11:19:06 -0400 Subject: [PATCH 16/51] Make `dim` optional on unstack (#2375) * Making dim an optional input for unstack method * Adding tests where dim is not explicitly set * Switched to using first MultiIndex in dim * Fixing too long line * Unstack along all MultiIndexes and accept *dims * Making dim accept interable or string or None * Responding to comments - if no 
multi-index, return object as is * Added section to whats-new * Pep8 and returing a copy rather than mutating original * linting * Adding back in error if non-MultiIndex is passed as arg * Reworking logic --- doc/whats-new.rst | 3 ++ xarray/core/dataarray.py | 30 ++++++++++++-- xarray/core/dataset.py | 74 +++++++++++++++++++++------------- xarray/tests/test_dataarray.py | 14 +++++++ xarray/tests/test_dataset.py | 9 +++-- 5 files changed, 94 insertions(+), 36 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 97794125665..98fb955ac48 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -73,6 +73,9 @@ Enhancements (:issue:`1875`) By `Andrew Huang `_. +- You can now call ``unstack`` without arguments to unstack every MultiIndex in a DataArray or Dataset. + By `Julia Signell `_. + Bug fixes ~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 373a6a4cc9e..ae3758f0bbd 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1247,23 +1247,45 @@ def stack(self, dimensions=None, **dimensions_kwargs): ds = self._to_temp_dataset().stack(dimensions, **dimensions_kwargs) return self._from_temp_dataset(ds) - def unstack(self, dim): + def unstack(self, dim=None): """ - Unstack an existing dimension corresponding to a MultiIndex into + Unstack existing dimensions corresponding to MultiIndexes into multiple new dimensions. New dimensions will be added at the end. Parameters ---------- - dim : str - Name of the existing dimension to unstack. + dim : str or sequence of str, optional + Dimension(s) over which to unstack. By default unstacks all + MultiIndexes. Returns ------- unstacked : DataArray Array with unstacked data. + Examples + -------- + + >>> arr = DataArray(np.arange(6).reshape(2, 3), + ... 
coords=[('x', ['a', 'b']), ('y', [0, 1, 2])]) + >>> arr + + array([[0, 1, 2], + [3, 4, 5]]) + Coordinates: + * x (x) |S1 'a' 'b' + * y (y) int64 0 1 2 + >>> stacked = arr.stack(z=('x', 'y')) + >>> stacked.indexes['z'] + MultiIndex(levels=[[u'a', u'b'], [0, 1, 2]], + labels=[[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]], + names=[u'x', u'y']) + >>> roundtripped = stacked.unstack() + >>> arr.identical(roundtripped) + True + See also -------- DataArray.stack diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 06b258ae261..e98495e71fb 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2310,35 +2310,8 @@ def stack(self, dimensions=None, **dimensions_kwargs): result = result._stack_once(dims, new_dim) return result - def unstack(self, dim): - """ - Unstack an existing dimension corresponding to a MultiIndex into - multiple new dimensions. - - New dimensions will be added at the end. - - Parameters - ---------- - dim : str - Name of the existing dimension to unstack. - - Returns - ------- - unstacked : Dataset - Dataset with unstacked data. - - See also - -------- - Dataset.stack - """ - if dim not in self.dims: - raise ValueError('invalid dimension: %s' % dim) - + def _unstack_once(self, dim): index = self.get_index(dim) - if not isinstance(index, pd.MultiIndex): - raise ValueError('cannot unstack a dimension that does not have ' - 'a MultiIndex') - full_idx = pd.MultiIndex.from_product(index.levels, names=index.names) # take a shortcut in case the MultiIndex was not modified. @@ -2366,6 +2339,51 @@ def unstack(self, dim): return self._replace_vars_and_dims(variables, coord_names) + def unstack(self, dim=None): + """ + Unstack existing dimensions corresponding to MultiIndexes into + multiple new dimensions. + + New dimensions will be added at the end. + + Parameters + ---------- + dim : str or sequence of str, optional + Dimension(s) over which to unstack. By default unstacks all + MultiIndexes. 
+ + Returns + ------- + unstacked : Dataset + Dataset with unstacked data. + + See also + -------- + Dataset.stack + """ + + if dim is None: + dims = [d for d in self.dims if isinstance(self.get_index(d), + pd.MultiIndex)] + else: + dims = [dim] if isinstance(dim, basestring) else dim + + missing_dims = [d for d in dims if d not in self.dims] + if missing_dims: + raise ValueError('Dataset does not contain the dimensions: %s' + % missing_dims) + + non_multi_dims = [d for d in dims if not + isinstance(self.get_index(d), pd.MultiIndex)] + if non_multi_dims: + raise ValueError('cannot unstack dimensions that do not ' + 'have a MultiIndex: %s' % non_multi_dims) + + result = self.copy(deep=False) + for dim in dims: + result = result._unstack_once(dim) + return result + def update(self, other, inplace=True): """Update this dataset's variables with those from another dataset. diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 29ddd40ce25..a4562894583 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1663,9 +1663,23 @@ def test_dataset_math(self): def test_stack_unstack(self): orig = DataArray([[0, 1], [2, 3]], dims=['x', 'y'], attrs={'foo': 2}) + assert_identical(orig, orig.unstack()) + actual = orig.stack(z=['x', 'y']).unstack('z').drop(['x', 'y']) assert_identical(orig, actual) + dims = ['a', 'b', 'c', 'd', 'e'] + orig = xr.DataArray(np.random.rand(1, 2, 3, 2, 1), dims=dims) + stacked = orig.stack(ab=['a', 'b'], cd=['c', 'd']) + + unstacked = stacked.unstack(['ab', 'cd']) + roundtripped = unstacked.drop(['a', 'b', 'c', 'd']).transpose(*dims) + assert_identical(orig, roundtripped) + + unstacked = stacked.unstack() + roundtripped = unstacked.drop(['a', 'b', 'c', 'd']).transpose(*dims) + assert_identical(orig, roundtripped) + def test_stack_unstack_decreasing_coordinate(self): # regression test for GH980 orig = DataArray(np.random.rand(3, 4), dims=('y', 'x'), diff --git a/xarray/tests/test_dataset.py 
b/xarray/tests/test_dataset.py index 068b445c69f..d22d8470dc6 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2107,14 +2107,15 @@ def test_unstack(self): expected = Dataset({'b': (('x', 'y'), [[0, 1], [2, 3]]), 'x': [0, 1], 'y': ['a', 'b']}) - actual = ds.unstack('z') - assert_identical(actual, expected) + for dim in ['z', ['z'], None]: + actual = ds.unstack(dim) + assert_identical(actual, expected) def test_unstack_errors(self): ds = Dataset({'x': [1, 2, 3]}) - with raises_regex(ValueError, 'invalid dimension'): + with raises_regex(ValueError, 'does not contain the dimensions'): ds.unstack('foo') - with raises_regex(ValueError, 'does not have a MultiIndex'): + with raises_regex(ValueError, 'do not have a MultiIndex'): ds.unstack('x') def test_stack_unstack_fast(self): From 66a8f8dd7f5a2997ff614f3966d1951587915e7e Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 6 Sep 2018 09:20:37 +0530 Subject: [PATCH 17/51] plot.imshow now obeys 'origin' kwarg. (#2396) Fixes #2379 --- doc/whats-new.rst | 4 ++++ xarray/plot/plot.py | 15 +++++++++------ xarray/tests/test_plot.py | 11 +++++++++++ 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 98fb955ac48..881ae52cdeb 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -79,6 +79,10 @@ Enhancements Bug fixes ~~~~~~~~~ +- ``xarray.plot.imshow()`` correctly uses the ``origin`` argument. + (:issue:`2379`) + By `Deepak Cherian `_. + - Fixed ``DataArray.to_iris()`` failure while creating ``DimCoord`` by falling back to creating ``AuxCoord``. Fixed dependency on ``var_name`` attribute being set. 
diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 0b3ab6f1bde..10fca44b417 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -624,7 +624,7 @@ def _plot2d(plotfunc): @functools.wraps(plotfunc) def newplotfunc(darray, x=None, y=None, figsize=None, size=None, aspect=None, ax=None, row=None, col=None, - col_wrap=None, xincrease=True, yincrease=True, + col_wrap=None, xincrease=None, yincrease=None, add_colorbar=None, add_labels=True, vmin=None, vmax=None, cmap=None, center=None, robust=False, extend=None, levels=None, infer_intervals=None, colors=None, @@ -794,7 +794,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, @functools.wraps(newplotfunc) def plotmethod(_PlotMethods_obj, x=None, y=None, figsize=None, size=None, aspect=None, ax=None, row=None, col=None, col_wrap=None, - xincrease=True, yincrease=True, add_colorbar=None, + xincrease=None, yincrease=None, add_colorbar=None, add_labels=True, vmin=None, vmax=None, cmap=None, colors=None, center=None, robust=False, extend=None, levels=None, infer_intervals=None, subplot_kws=None, @@ -862,10 +862,8 @@ def imshow(x, y, z, ax, **kwargs): left, right = x[0] - xstep, x[-1] + xstep bottom, top = y[-1] + ystep, y[0] - ystep - defaults = {'extent': [left, right, bottom, top], - 'origin': 'upper', - 'interpolation': 'nearest', - } + defaults = {'origin': 'upper', + 'interpolation': 'nearest'} if not hasattr(ax, 'projection'): # not for cartopy geoaxes @@ -874,6 +872,11 @@ def imshow(x, y, z, ax, **kwargs): # Allow user to override these defaults defaults.update(kwargs) + if defaults['origin'] == 'upper': + defaults['extent'] = [left, right, bottom, top] + else: + defaults['extent'] = [left, right, top, bottom] + if z.ndim == 3: # matplotlib imshow uses black for missing data, but Xarray makes # missing data transparent. 
We therefore add an alpha channel if diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index c38ffeff884..15cb6af5fb1 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1308,6 +1308,17 @@ def test_regression_rgb_imshow_dim_size_one(self): da = DataArray(easy_array((1, 3, 3), start=0.0, stop=1.0)) da.plot.imshow() + def test_imshow_origin_kwarg(self): + da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4)) + da.plot.imshow(origin='upper') + assert plt.xlim()[0] < 0 + assert plt.ylim()[1] < 0 + + plt.clf() + da.plot.imshow(origin='lower') + assert plt.xlim()[0] < 0 + assert plt.ylim()[0] < 0 + class TestFacetGrid(PlotTestCase): def setUp(self): From 59bf7a7c13d8b01fd9600cb76c82b35b465f5707 Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Mon, 10 Sep 2018 19:14:17 -0700 Subject: [PATCH 18/51] add some blurbs about numfocus sponsorship to docs (#2403) * add some blurbs about numfocus sponsorship to docs * add numfocus to history blurb --- README.rst | 20 ++++++++++++++++++-- doc/index.rst | 24 +++++++++++++++++++----- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/README.rst b/README.rst index 94beea1dba4..7355848eca4 100644 --- a/README.rst +++ b/README.rst @@ -15,6 +15,8 @@ xarray: N-D labeled arrays and datasets :target: https://zenodo.org/badge/latestdoi/13221727 .. image:: http://img.shields.io/badge/benchmarked%20by-asv-green.svg?style=flat :target: http://pandas.pydata.org/speed/xarray/ +.. image:: https://img.shields.io/badge/powered%20by-NumFOCUS-orange.svg?style=flat&colorA=E1523D&colorB=007D8A + :target: http://numfocus.org **xarray** (formerly **xray**) is an open source project and Python package that aims to bring the labeled data power of pandas_ to the physical sciences, by providing @@ -103,20 +105,34 @@ Get in touch .. _mailing list: https://groups.google.com/forum/#!forum/xarray .. _on GitHub: http://github.com/pydata/xarray +NumFOCUS +-------- + +.. 
image:: https://numfocus.org/wp-content/uploads/2017/07/NumFocus_LRG.png + :scale: 50 % + +Xarray is a fiscally sponsored project of NumFOCUS, a nonprofit dedicated +to supporting the open source scientific computing community. If you like +Xarray and want to support our mission, please consider making a +[donation](https://www.flipcause.com/secure/cause_pdetails/MjE3OQ==) +to support our efforts. + History ------- xarray is an evolution of an internal tool developed at `The Climate Corporation`__. It was originally written by Climate Corp researchers Stephan Hoyer, Alex Kleeman and Eugene Brevdo and was released as open source in -May 2014. The project was renamed from "xray" in January 2016. +May 2014. The project was renamed from "xray" in January 2016. Xarray became a +fiscally sponsored project of NumFOCUS_ in August 2018. __ http://climate.com/ +.. _NumFOCUS: https://numfocus.org License ------- -Copyright 2014-2017, xarray Developers +Copyright 2014-2018, xarray Developers Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/doc/index.rst b/doc/index.rst index e66c448f780..6c1a8519507 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -120,12 +120,17 @@ Get in touch .. _mailing list: https://groups.google.com/forum/#!forum/xarray .. _on GitHub: http://github.com/pydata/xarray -License -------- +NumFOCUS +-------- -xarray is available under the open source `Apache License`__. +.. image:: _static/numfocus_logo.png + :scale: 50 % -__ http://www.apache.org/licenses/LICENSE-2.0.html +Xarray is a fiscally sponsored project of NumFOCUS, a nonprofit dedicated +to supporting the open source scientific computing community. If you like +Xarray and want to support our mission, please consider making a +[donation](https://www.flipcause.com/secure/cause_pdetails/MjE3OQ==) +to support our efforts. 
History ------- @@ -133,6 +138,15 @@ History xarray is an evolution of an internal tool developed at `The Climate Corporation`__. It was originally written by Climate Corp researchers Stephan Hoyer, Alex Kleeman and Eugene Brevdo and was released as open source in -May 2014. The project was renamed from "xray" in January 2016. +May 2014. The project was renamed from "xray" in January 2016. Xarray became a +fiscally sponsored project of NumFOCUS_ in August 2018. __ http://climate.com/ +.. _NumFOCUS: https://numfocus.org + +License +------- + +xarray is available under the open source `Apache License`__. + +__ http://www.apache.org/licenses/LICENSE-2.0.html From 4de8dbc3b1de461c0c9d3b002e55d60b46d2e6d2 Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Mon, 10 Sep 2018 22:13:50 -0700 Subject: [PATCH 19/51] Numfocus (#2409) * add some blurbs about numfocus sponsorship to docs * add numfocus to history blurb * add missing logo file * markdown to rst in readme --- README.rst | 7 ++++--- doc/_static/numfocus_logo.png | Bin 0 -> 24992 bytes 2 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 doc/_static/numfocus_logo.png diff --git a/README.rst b/README.rst index 7355848eca4..12650b1db1b 100644 --- a/README.rst +++ b/README.rst @@ -109,14 +109,15 @@ NumFOCUS -------- .. image:: https://numfocus.org/wp-content/uploads/2017/07/NumFocus_LRG.png - :scale: 50 % + :scale: 25 % Xarray is a fiscally sponsored project of NumFOCUS, a nonprofit dedicated to supporting the open source scientific computing community. If you like -Xarray and want to support our mission, please consider making a -[donation](https://www.flipcause.com/secure/cause_pdetails/MjE3OQ==) +Xarray and want to support our mission, please consider making a donation_ to support our efforts. +.. 
_donation: https://www.flipcause.com/secure/cause_pdetails/MjE3OQ== + History ------- diff --git a/doc/_static/numfocus_logo.png b/doc/_static/numfocus_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..af3c84209e06d9d4b106be987ca0c0013ff697a2 GIT binary patch literal 24992 zcmdSAq{|T!lB1NC96jmLAq~=? zAaZY?@9%zi{{irYwltaY%b&fH05?Bmth|5p>U(`Ek=W;vbWT$}OJJWr}J=2y3NL?%p~5SU}|msfRR z7+d7Sh$f2qmDhFFdX`zg2&yCh&kheMDv8TmBTZOlt+)J;q7mGT{tvhKkTGDRh|kWl z##>6r#7Jvqf6c8|NCR*)d4)vXCw-<4mj5%L18vp%vbkhkHHIAdmSTbTv=So;u8lOm zGi6C?MYwtu@sCf|&`Bz4%cGiPqVs>2H&*BvIqui?~6bAe?yadCaTrRwk|piB^NCX;LI}&X_%T zlZqxcnF+D(yCdu+_@I3!=Mj~6V#(hqa;x&MbSaxZz_1h950tM?gWQ{~9&Hgm#2Uf> zCJ)%@1_I|GTW`Umlqyy${`yWw{AwmEp#obLdRykmJP+b3MV%E)xK8A(X;!q!+6y7K ztuBs}24+ReEit5MAr$%in{*M6Q5&gkAl$ z=G#Vgop;H({N^o2WXKn(2X&T)2WteC`w=Y6>>_p1n0~Nbh{?VKTQBU3>Mcojbx(oC*<1PQ+^lWEgjj9m#TYLSP2iO77H{I!&@J{4 zG6*i9Z2wwLFyjuZ!5^69MESQ-9x%tx{sOKs4(%`Gb}9qc_DBh-FtOpq2Ni&7=^tK5 zUUypE%-_>zx6OhpI;5|Swzx@W8&WtQe+e}u4D*KDaYxz&b|4Y}dFhD1HuKJwUcAOTR$pw+C341VpL(Xdy?x(cOa!#WJ$Z1Xu%x zustH~iwyO9sTzpuA8R>ntmb0*$8)+)34N`VLRG#3(X3B+S-u?dtDCY&3Ib3(#uVr= z+iI|7c(Q23rzj2b%GGi}igg2VT_a`1aJYL<1m|7b0#@!=c~vt-I$_ zy(Vpp=r;R5+0SWJq^VKbWM&lK_!`G?Df?O7LWA+Rp;ItioD zA}lt&2zl7bxD4xZ9vI8a*3V+^lNyLO2pHu@1l1C%pFbY-ZBhwsdr4q@{8n&-;=}O9 zNKCehJ%2P5dZGj(#(v`@%Y;{nsWq&srFwC}?M0KAO*$biKx>Nw`GE1A)@P<} zmQkjmrk#{k=FzpU(_#(VF`RjGe6iA`=tp!YO3A%}zg#0jTJ3*Ct5ArGuYxk(E(9oI z4l*bg&cn}zVta1uN4_$MHEAg=y{Mc2i#SA)iqUb}hT2Cd#KSasD!{RC;t0^|@! 
z8!Gt*1<2O?6)C0j9wWs4>;oLx@Vir?>L^~VKOL|I7R@UzEMF6#*oYvkTNuf1bec>!a~O`SV8a zK>Z$hh1W}YwU65*oc<6(`A8v`Bj)W@BiXUAM+sMY<{3_CQMALNC>5f@3?X3kWbiPB z-WnLmOwD-73Cxm2c&oeUA>;Vwt^NvavI_WF>F&bfG0EE>nv;|thr`Ka;FrO1TU^2{ zhmVjCQKYpJ?^N#fm5zNrJ|PQ(gt*88cR4%tFF!ZAnT^hKUgT6M{c-VR=gXq1DknzS z;glsqEscFHc}gycXsNl84)hIid89T)QWNL-YvOL|ig5jycMCG#GZW4!2RBr5+v>Jw zaCq^FYgQCW;$Y?7>>9D0n;J|-Sqs3vQ)~Dcv){cm(!$v&6g#-|BPNmgAVtoth5-94 zz+M2MLIJdJ2hMe7<$LU>q_osrN2n5`H^`t?pJe#+zxL<-?#?krrPmrHmjtotj{UPI zKDepG9M-*Q$(FVXC0!7bQPp$GYSIUV$H}z5!3?#{efM^HB@K9n@ z!GFPH^W#H1{Cwc);tZ4cA_tJ8)$uze)QY~@Cq(n8Y}leM`K__;?#FG*FYJ=y>fcEi z)Rxk+E5+o}Q&|H(jPZLr!}4;<}5M^#KhPke|rKH zGaagM{5UgK`?yzPsPDlh7gWedNYYE@06H8;4rG<%%EkrWVwWD6n*0{e#rPddjB)nfAw=#W_Je$$1okiJkJEGSS7#U5neUcC+NC?2aY( z%vng$QN-A>Zm0g+f{$w9$G^E7i!JWEtbot|<_D1f4{UQeHa=ncNq19`|qzT zN-jw+jthIitmg+FZ3pF%w=@B$jMD)^=odn$*3C$ksR_=a{HG2fz9VG?z06j}fZJ*ng7DD=}d&wrXIGT_BYl#rx|gpV4q!SC ztu+iGE*tn-R7QYxKr$V#;uEPg(wg$H4NHoprZr~>W8>=kt>5u|>-FN+5B+}nFGnL& z$I=^Q1cSedIzKeH!^8MTvC_PTvBMh+^Km^ZFjK#lv=zRX@Xvu)S)p?$?+ij*?jceH zA9rc^+4WOh(Fw`&_TCEQc0LFF#1Ii4S!WzHQmZ#8;L3N1%VdbywhFQheuz)CAdMDD zH4yxrK-Iuk`k1b06Z5<3fCI6fV-@CV>H`nu4^lLRNqyH#yukB$X-Su7L~X0|39rFF z`CAiTina=p1F#`Li&sXhB;{F>Nz@fb>G$fzyN?|~g&l-YikC}_nYm>6idd$WI=i<5 zxxjDv20Tg{Xd2?8Nblp-@wHnB>uGXiZPw*s`4K7m#O3I19LL8yVvqYhr-W=p>Kzu9 zCwWU_V!u2Kh%jr)UFC%D)mTy^>d7FbI^vve0%1G$yU)il-x8}qGujB&M6sNtGpGM6 z{E&78>#S_OerhW!N!~{unE}93cbWd#sEjAtcpMbrTVg8S{i@@O{K0A-E6~L4qGau~ zusRIp_idpeH}566?u6aC$as-Pp%42yvo!!g13(pl>g_{7lS9$(CyWEy8CcP|9UVu$ zQ?38w8G4a|mm$B9kTblNm0-QAh=ahNNAU#E&$8mV#yJD>BycAQq<1!uk59ulqNizk8DqZZLTVsyR~Ms>uQNi+ zF_UQ(x86=C)BUIxUuuqP-q%R25G`MB1&n&6)P_n9^LcCEn+rctAnV5}Jt12KU^{?w zxf*dUyKCt~j$~*Bbi(R)KRqNCa*oS+QCcX^E%rO7MSV*{8#>Q3j`0fXQvbBxKYEIc zS9$q^0M!q-6DHg88uajs)H@{T&P+K1gXt_%k2n(`v??>&xwR_<3AERf{o5I#{;2~e zhXCte?QshB%uBdvyoWwtVyq7`jFvSE%-yWc3 zKgKR=#HFWtu!<4Q&3u*=`1A|8_DQLG!-2;=HIN6wO0?tWmO;B1WA&C=!*WVait+yu 
zI^~Rbu=b^&z2lKX*)u=f)-D0afWX<+Gq5*!mk4SU8S0NbA@bX7Faqx_aRV&9**oyG|CVG7Zv|iO+1^0}j}M4X4v#7b-G9i(`pIS^&z90QDif z9F|UIe>B4i%o0N<(VmCU9lzhtO?9LOT6_S`rS5)gLD|6niWfD^3czqwMQQ@{OZKq0 zfdKvaS3HadIJ%uPvcX59KP-NxQP+P^OfNkB#Fbq9;dEA(+JigwOFiHX6~Sv>FLHv` zE380z(Xe+KB=y>JX6hvdwjj)h=|F_5n~o!uGzUPXE(zf_3-x(xL1u zf!c~soZd|R*OHJz#wfu{e?DG-obf@-0*APQ4tqg_P$RMs-N1LD!*33PT7a>0Ap5y= zP6le-PxM**mkhx^uhM)Up6*S2mLF4Tmd%*yWE5NSIDSKV2z3uw5>Xqw^sN0#JY;|W z$rdfLFOoa?^lq_Ib5GGxj}J8>o)Icp-vX-tQ%`R~{{U_M{N0btw!zED&?Y)bxcvaW zUiu-5h9y~lPwZ>_C)Jdq>Y~|Y<4UC~B`_;j-Y3fbi8`5N$ZA7t1ICr%JFtEmL=h$i zn`jcqcP!RW1x<&!ID>!Kvb|KRiF{0aG~lto)ID79a_|`HPGL6NGPUTOwf^dW^Jvbz z?ad?sPF2V;nl|z=h!@A zwad={RDXb-(y_DU=L2n`<;>v2TSRE?tmp5`W!aR_%kQ=@z}<<*!rivj7$KIkqcYe~ z+4GA&VzgbeHcMzY*kX%=TFlgmS`BmyJ?o(QXAlyv_TEJrXtDLKaDDt?#Nye}fREmo z3@M`G8NwSbfs>-HzsxELrgteZbu07dkxToy;g`QU?hpVpnG5rT!+ax9xjDV4!wwU3 z8(R8yJNdaq>a086`M^jV2h@uDxIj#~$V+yY8#H)rxR3|MPB~J*OvJT2@aTHQGx%EG zbAuKt{)7=)yN5?NL9~bo(O=xA7XHcnbKDO&)~Q-MJ2>V{NSQ)hI*H6hWepZ*UW>{5 z(MT%<_0dBg(6etMYrxU0yr6z%R1xRb1qJ!v>)fMSH|Jm{ggE?Z09C7tf#gqv1KS

lNDwKhd1%oqYm=8aa-2eeik$=;4_I#)^{qLFKVyXH z`ByT33k+Qv5qr;Zv|Cx5*$jZ_9enRtd=?D+dRlsayuim6RG4^9N5yp${EM|TrS ziyTvXL!8t=Uro=IjS^Fz-)s^<2~biYdv1fn1%<-8iuc1Oxy>SgyM(blk*Y_(uWPm7 zcunb0jQ2se#nzT?H=3GU(xzS<_5MR1_Yy*xxQEPJgLAM{FMT5lk$v5y&|0@b4>W}( z6u5glz}=fc@oZ-^2pRcohYU!Ch-ZV6SyY~)k1ph9PE+oon-o#Wwaa*4lA)g>#lQUV zv){&qBHj0Y`*1zhP9WL$Cn$G63XqxPRmYa|jly zZgQL;zP^?cyMRX=ZbZr~SI>a;y3x{F{`r7;`6Gn9k7(7xCICxJS~He(Mvv42)uc7J zm;5Cbg;N9x5pskYoA z1VhBsKytV|wU@U)c%`3S)PJ4~-eu8a;y)1M5~ikuDdU`_pOcn~cZL<-8C^jws3rH< za(y;xONEb#NB6NK@*oyq)^FgH)s|eB?^P#Ofd|M%s|F%Oh!YUh51P{K?s=BU;5QOS zq_m*8;4dc`4U@lBNc8VFkn)mAg_YNZ&mDLb%;58Xk_@j{ZKrO&W$|q5uF-`>);t|B zs{=Tt_*Kl6fhQnMKMPpu20*Q-o0Q3n2}F56irJdJ_}}_I2=V%w;mCjQ0S*@TH+*FC zo|ut+il=~D^+j8MjKZrB1l+A_9W!jtGfnNZPNinT^<^Q6r+?lE()(pTn*pGt=xhZJ z78_WaWAw=V|0$?}?jXb`=W?U^g2z1ekk-~-cs20b%n8Qn3mOFtmTpXK4kP4@d3@*v z&`lvP-zA!YQpei>s7VSGqiw^=cuUL9zoAvO^Hi8#X1O@5xgc zR~$x8*!x?;N&u`AE+YpjyZihyCl>6jndtF*y+L?a2vV<3xQR3D(jh5Ex+~6PYof}| z#vak281)gR=cqzF^WPzXIQPGez#MQ-|9x-d0O`ZC;qb_fOFz5j%uZ=F&SW4Ysm= zZ4MlisMU}P6---QXpI+{;v}2uw0P?6QLJ~J2UUtms$A0EGR_MPaLmH-P$sqxr;{F| zsDm@Z^b0~fH;M;eKz{z)Cqa;)KU1Jat|eM)0JUDT(+r$l3ycL06hR$Q>shHq2W8Nb zcp|7c9o)Xn9;(W`WR6wyA=4H>*AqNg-RRRa7~`bu4;6MP*Kbvs$DcK4b9&~0}nY)Tx5c365f z1`Zloe76J-yzi=k_<$C%a`U-qx;i6YmYy~m(BaTchaGBlPrK%_IpTgN$VFLv-MmfT zNoS-Cag*G&N(>P;nL7Bbe>~um(20XO67)VyC#e`DoJ4$Rm6xGiw0&!ROPnMu~$7kaOA%Mh7=}0bpfu z3n<#yrDy@${ z;}<@V*)P=p(<^aAP(TvYLBc@-vOr8!T+-u2pVJYUh?Dqm2QdH|$C@klJ>Ov@qxRR` z@X7ER_$36Q=zQoml9Xh!^pp%V|DCBP%90S44crY2@uq1(A?sx=TxG;=J*txiu-%eWCh?Fg_0EguF8X&>W@ zH-s2MGL&R;tEu3k$dH!kPMWhoZ+hyszgti!RGmyGrGbU;Iq3XFqYj z*?+ozz6C(iemffJr9sgSrMkBLkr^?A;Pw4v>ccb1AN%5KIOf_Nj%$8mLiY?IP1#LpzPJOGl-t*YjadbQCT@lxyPByM7~+3Cc|Eg2oj(e>FG}wxE#yH za9Odtl9KbwJ#RCjNh$fR)RBiy2PkHo0~bit&_vu!l^x1jeSe1u7g@t3=u{OOI9k0G zgW?oCDZ?xYpEg%XJ$K}ANG60#Z)DBjx(J*)YPcw#xFoCRy1d$MF1Pt@Wq~|$_s5BP zh!ZV-r&_X+2&YNI9d(m6WW@?o(^mT2YXz8TQU-0?%7iaA5?iP?l~SS#!gaVf>DE%t z2z!c!$;578>UlUI!cQGy(ug)Ozlysyg+Wu|Ax*CLY}A 
z_xIWosloS*gA9QK9qsK#AE_4BbrCLFyvJx!x4$M5;dD5mMX6fPHh@JBwO;56xtZS% z3YcSl0n+B^NFs*xD?KWG&z;cqUv}Xdb6OQI^;$*X6AG%jJzMkpNCxVU!M{0gT zD5aj@(1c=x?*a!=1SkeS`vZrP?2i%hpiQ7s+REYRZMmFQ>bOW^Y)YhHv{fc4T96tg za5&#}2yI}aJ$)s32OkebA?1l_sbIRPTTT2?ves+avM(SM|E;W6Q$VM07J$!BEO`s+ zEqQCkNBr7X5_yv zxBp@ZKMwaA!<9eWL}8F(RcGz@<@OFOF5D}ih_k~nMoeX|Q1mr3!BE@QPoI%6QxFZp zo1mvMyGw?#YR|GZff;B~1l-NkRAB%8UL-MuQdCKq$r_@8XuY>XJKZVznpsrT z%J%`zLsO27^>jQ=wnV4{DslEA`Gz}cptYP9Vo6qupGj6yPGmHQczURX49jWbed-k- zNmoUj(iSkF#VIB=nzQ__EHD31TP{UaIWIbaxh%e^D<$OR#obm^-~#SA3INIL*OiRY z;Beh@HweI`X}{O%d?-(6^X&EF>x5)qkS%_BQfQIC#ppe|(9Ke7y!7>9<-JPU#q06% zdrZdpUQx}_SUm(kdhyxzyZm2nR%%CvxEdNPgclylSYH)s3Sx1^WhauhWwN%B;NX=! zyzwMlB8vdZM~G$g;d5169HU>Cg=%P zRXh2lAHehufXfbXA%mN=&v86(3^w78T=7N)lVRh;zgt3dN*4PJ;pdww*%27`MzEShDw>!aiU0mI> z6e@3ize&EIFyxP=26Avii>kvOMykq{*EUA@`%mJdiwco*z#$rO@Rt(+6&73Rc-}=d zMhaoo-#Z|ZyFi#ujq;J=3N7j3K-*Rsb0iq!Mj@*O{io3VtNzj55khIhvce(dG<9=7I|F_nPnn{`HtnoGSh1HLv zF3~@e%>U^<8SGj(P870N6=}{xbi7^=2wGL+zlwk5pN;_ zRPK^hbO@P1_Q%<%DzVUV;#+Zt67HjBkQ^M3#5IDr5d$0C?)oK)0QH9eQd-N+e;QBH z;o`i8<0zzLkbapumf~?~t69rID2f2*Lqcd!d%O%`$lm=}=PG}EJjS$G%+LM@^-<}X z{WYkXK3v!s?UoC-EDOe66pp)QK5IpRfLWSXPn`&F;~bd?&%p?=oM~_uCI~Xo;Jl41 zaaY$cA=yFH`(NH5+}DF4`FWUV_5amUwECYsaNe!f8mXvTIGxWF_iR5*#V7~wX8GIG zsQb9^x-!0oW`8#WF~*>jz{`7O7K6Y zon4bglk3!r6*%Vf^Q(c_|0}Z`9*mPVPC~CvwBUmPRH+}A|H-G>bU$75D}fcXFMsIoX?23IJv`O(Ub^@20oyx`H_}{P5dJBJ<=j z`A@?!<6yqz8Ml&j%!er^f3q#6V8Mlxe{AefZExHG`31|FM}{YP*o(i?|LG;iPG#LF zR&Wv7L;#)S+TK=GQ(SVk!DH2OL*o*9)uXJoQ1qX?S_7ea1F?)YWf1JBjZQIK34SBd zDkX-_+~Gh@CTE&yB}DMHtVr#hT!rBV*f2h?_=iE5<4ygn1zeQPsHXx-l+-{MXa5M* zKw|nIQyWMrf?jyevIb1zLw^3Y8faytDUj_79CuNaKHJJke{438a2u+XHYTFZxD@&^ z#`c-!=r1}H?u@_z{MKxu4+hFQv;+pIs!cU!&r+!{*~WbHjSSuUE%N1p6p5SIGNjiH zy~cH6={dN$C<$PM5HVoH0V*1V*ad9Nw2*%_!+Qzm>nMi=jT zxiaEBqv5)E@1;3iJ5~JQNkTx#UTG>UM^K8djM@ngZ4li&O7ToB1T5!X#j6if6`w@$ zP@pC#?k+BTnu1H^!t#(km)ZNo|3$ydjB_Kc%NaP$yzOFLmr{~O(>;iGb!P0Y?V&{2 zAhCLb^;Yk@+-qOti!#SFU_&aMY`S9>BzO=n6LM`sc4de*bszKEqZCEl@xja($*=+e4 
z`_tDvR^Df8aoYm<5HIX}xJ$oMV|EV#MWJ(GV3~Bc+78E?H~{hX)r@51ZEVg=7W?@h zA#!ieMvLh`zc|dfY$sBd|k?XnxsD}d&*m! zQ+zCYQ#2InvTd=HkodeFUOfJV8rn?N>Hv9E^%=whWrK31i}wsj2+987uX5=G&OO@l zNxahk_A759VXeA%h4agmSWGZ*0zZ)(RzatBO*p3>{F5R&4CEzcq^8}lx={~BT-Iwx zC>A5?L(~Gj-rcv!>~06v#Zq<_eVut2XBK(^)s4X*0o2X!d zUSBZ=EQ;^vG@EaopvN?WN-)UO(clOSNOr}zh&FkkrPV%_aLg1uIM>@NGJDmu2Wk9v z;jLSS16&shOhe=8J*ApFQeZ&Rx?C2Qt8cIJG;c@tqGOtf68#G}(nf4>+a%MH5PD7U z8lIOfgB9gaLXXHTEd&^Fk33IXbkxf`KfEuXMTWx+Q39yd;`ar(Z__Z)H`sl%h038- zw^!QOMDLvP@-Jiem{gsR8F!?D>6Y0z(^TN>RZ?o!-zTFA`zjiXk0>y@)XUDC#U$9I zVqYE!->hA0#)gEKm%Vefo7Fwi@10sjl=6K<@FK zms-t7yi5&)X)2xmS6^Etfu{-U*X>N~*#fS7?b=o$W+r-wlx5cl4E5lys?ktmcBn^E z+d*h5yujkOTY>=Y5`%VfesXw(4FEda8u^9t`e*WYTMcCE8YzULx^dI*-a_jl^G*g6 zqGGc^rgh-g-rdzRJPy70w|9o@Qw@5g4=m2Jca$@9zvGU`qC>w`P#@ibmMqNS;U zlOA&)Ad2oo-ZZunM)cGjBI*4RF!5q_f}I^lBs0oTm+{lS(>Fpjf6*^w%H0#GPDINZ z1cu|V53}6@l0Z@T1B$xd%aAD@P8${_w8x){om$?PlrTTE{{~^yF*n%NzeQUH-dy0v ziYs%V0&B=vTPiQ9cs4Xhw4+?IyOLv>)7N~K+xXN!no~@mT0s=>*TB*l#*#W*1Zp+` z9DufQVKXP!;h;Rm0HKgXpcE}RmQ~(piJ!~lG}Kv!5b@pBbj+C8QIrJk;)*s=rwES6 zX=R^(R^UbKi=Nid{oT)6A0ipkd_j}$8C}EuO6b+(K+46kwRKhyXxT@SMcJ@Kg z;@Zhw51a%9D1|MIu=vCcNd%TFZc)Qs#kliNEScHUF+F=o*bD;!?>V1UW-WP(UU;OS zx$3BLLX7U@+}&0U)5VA>M&NV1cdSlNfz7TV@5#rr*Eenr_SX-1ID|;&`_2H9i~&ah z$>DbmqpsTsR(rM5iYTw(M7X(+PdY}DH_7WU1**^}@Yh8oW)*UZ9|k@|I$F{r?&fP0Vx9)wwH6||DB?{h5y|Au-=U)UNTfK z)^n6#tCuRDQ?{mnRr#01%yBq8e)+GrB4qOvf%HxP4_@5|=bku$r@fH*j zR?D1F?dQ09(y`f&p;NWQDu>NGvgCbA^(%>g{)p@Th};u4tCenz=&dEop#INbo44&2W zEgC&x`dP`9V#LXDf``0txp2Y;60U4B(&Swpj zYtCMilN)1Y-AkN17%B15@?9(`F?GQmD?O&B=dFWq+yo+SCMr){4tG>2D+#)fa(_Fi zCC_%)JGQNo9F%**rYIsKl8C2FS@3(JG^_hFanh2uj(VF2*XJzq7V3+TG}hEQQrB~5 ziV}{Qx!b$I(X+YRopz5A4KEB48)1pVgc295piM7%VGD}iX_WMOmY6y47eNDEiFXU4 zwxH-v`8m-;DYXB)6#2Oyf=wjJznCmXr}VZ$L+mxx+PrtNR*orAV@ay(@2^#KaMS-x z4i9dNs;i`FA6phMHUFp&EAlEZw>cv*_x)akG#)xmLQL9Tlv2?ttH1C;VcK)}U{pl@ zBcUb))mAf}@>0azi(2k6T4K#lS4irFo&7sDuQTn4cjHE%alrH6Z|1aT%TpnC&CY#z 
z-VpH^UY30UPl)y~(>C4xV2`37e(9qa|R_L-{uFj;@eaes46&vDH%e-%W=5xmWHKx_VP|n!@X()5EVmH#cm@T#IWoudd^=CMdZ^8dZBd zftO}MNiRqx-WB=Zj4VThy=M#!0uXq1h`W9$d)Ld3K8MTW-lz2yquAVZYN854TUO}7 zG8scim$>~3F-2QDvGd0lE4_#VLqDE@v;a9*Qi~`R1kDM2ZbJXj;=W*1g15p+PaJ{# z@V7I4Ot=956AOWk5A{vI2TUV$Jw56h{8zJnUk?sC1y{XR312j{?7gAdw)%M1WOSYy zB3NJGpD>(r%On5l)&)(L+YJVrPcwPAWLPpfrtJJR?CL^+;Uj}R6bq6mB{`u+NmT8z zfXByiBh^|Lhb`XEezVEK{K&10cxR&9R%w-zMG|w`TV&ym*`i)9oYgcqX1>Abq|L4T zbl21W-OSSq!;Mnqt%)8sQAe7p6S}0h6Wzl`uzwNN;x}#VTDtY>-{Zgn8%<(O61uI# z7WXY5U97kA=Iup|u?G>fYqxi5z=;$=Lml*!waMKri&RAmexQ!utG=p7osL?%-PGV#LJ(u;-iyd@9;DNLI0qOHqG8txLtQ>Uj<3K>TdpZA|tqRqhjz5xpjZqS>OS4TV{eHTcoO!u3 zSt6d12%wAC|5}KaSVFJOp2(+_j#J{vur-!!oEWkVdGgvluvGZIVeA*ZNC{}k_QQ2+ zfVsc*gKx_ztKQSM6pgoFv`qJcSQP)-PT%)j#KRLaY3;2lUsLMb52e`~8!jW*ubR9c zntb@)l@6@b6S+39_rIWJx<4ZM>Qg~E&6gCFx1=N4zVO)q533C>Wv4$k9zwaY`6wTF zl`sse$E>d_OTooBYiAYq_aQj@YpU_}zzZh13<}{AeuOswRTtw-hHU>YqHhhHUjeqi z-wi(3LE=Djna?e{y03vw<5j>2!CbTk>0RX)5f}p8(-v#SOS$D%I!QGzOW{sDr=^U* zO6)z%;T*S>PPXI9TQ%$CyI;IDkgf7bh(~C)oBKXQ=dbGmo0;j)c>ZWgt0DIco1;=a zcV(rV2=9d2QqeBwKK|RtVUvrE+sYemtv>Nh<;=||LYrT+RmBHgIr6hDlX)+De^mnC zsmNBY;=Rlv?nSzYMSA~?o0A~qXr!Q2 zw|w;1Q6B1=8b>4iXi}8aVbJ$9mDSMlafpQP&Ym_{J=`ds$L zChJH&He26v!DK{t=Y27|C!CXYgNXf@#o>@FU{a~qr`uIw=ohDfV0hu#XAGu9$Ycy! 
zV*1W5vo=3Zfl9;5!PA-}Of2+;-P8;4zw?JO&`@}M)Du!mgi`g~2gZ8?@X{nNR zt(v-zMJ0;`zE+G5xhd$Z6Dj-mU?rI0X=$5*jqok9f2;nS*>SFoNoppPlE8glSy33y z87_{<@Celb@(x3ESG&SUgJ8d%Cm!%)Be8$TY6KJPgZMWG-v?**U|D=}T{X;Ub;V^$ z9v?4T4IZ^Y2A&?GFuz(}NFM2UNuN04am!B?(zcH5wx54KRd^ZPfc9vSW3S#FnwRvWPmZRA3mZy&67 zF@^b+R#R^z8bAyaH&-K$?$8_K2|-V0X>x zb>Cxj!a5>cfdW1x=#e|drJ(&#Cso;z{gVnp$kmAPeQ43l>D=$TlW!nSS+OYRpO(XY zCmR?W!*?&*6NNreZJb}N6L`Jm0r-pP;YHe^6}YcZaq)xGOwGMjBT|>XDNm*Tt=UGf z_hw;EZ#28|z&Gf0u~F0NeRVO-xw*wYzH2k)&^_`IU8HE7DI)+fN4zOX2-Tp0>SP<8 zxr_)RqN~HP$}8UxcD8e-iXE2|ygh&m)>-4NSJs9%a`}Oso^aHmM4#y|zB=fNTX;FA zy~|3ZWju19K$X(ZKO#hLg`1~3&GUikuWvdmbDuZplXN-t-oC@=%=0{pIrPIiB#aiR z{-af8gbU;%*01o|ct;hJ-NFqSy=MD!h0+l%sJSCv3c9BUnm0jxVl8?EudJfkWGTcu zZJUcR!eCR11<_MZqG68gs9hc1u7@vfT~Ju@c|~FVQCD6oEV&3S8~-?;TD|bsA3xXa zALkc=LiD+pl~e{+Im-rbt2ZzS>AijOp-HPK8vfWnAjE~Ah)tZ%mCvgP6U`Pr3&}?9 zI*>Lzj8)2zp)`z{d;wRP@@R-PR?zj54Is(eEWENo6}wTs7P;B<{p)(*@s1R3Ir{T; zxkR0s)U^S)kGXGtYR)6qIM!4pThEYhbXj!%yq)Dz-ALEQ4bs*j%I+w&T-9UU6Q!Jz z{F!UgQ^`EWROw&OHR0K}sv}L5{v)Fzk2^u5ut%( zhuT9h**hh#_~|pV`0%Ex&-5)DUA=N@srTTj_vDJuzkeIabrX}LSvAlrI-YgOg>146@DYKXLU$GNZ zPr7B)JKkgK5bCH9a~Ak_LwcxtdEVnO@^xD7j+1B!joxR9M`V=ZYX-WlqV?Iz10@Mp z^CO&nWg-PBSfd(w?$-?%z*FZJ#IiQf^mTQ=8$ zNxIsvWItg#O*b+mO}Jg5PiONYd!nX08=dzU1ec|?CTNIVYqj}W@qr<>{z~sxoKzcKYuilC->zha;OVOzn#Afw&9`?P`72b~irNKBY-PW+{T9G@r z*$AncdkWL8nU~pU(ygb9$$vjqlH3wIXZR|}dg%H|{LU%hv*oG1!q3_W<@;N!rQ9s_$vQ z(jzeu)4xICm`rdM6>Nj=GAlmPWZm(V>^GjJ;@r269z2slrnkz9UJF;YB6eM)@4pqz z?O(pWB8Q2nuwHKhFM{rn#2q*71M8A7+~8u$`1Ph%lPybRK_rOL9P{~&%B-O;5(&QY zocq@4-Chs@V%X8PcfzK=icFO+w-Y(>j!0$R`M%$^bNxG)iaefrrt(SaTG{TiQ-KP5 zO>b1>3%EK@Ym))CAS5U3$m8{rz?K}m$=NTk0hn?UiWSM_4p3wAj^uQlJ(XdBu7LgYm?em zp7O|8p2*>nQQRZ5bibAVg7O zTm73-Wgiy*=T6c~Hl~fIfd>>hck0|~%CDDEE#f~qd!NkOM@+OZm5quta(huxg$?vGB3VF+|`5x7*N#X`xpjeyijO7B>s|SpR@P< z?6n2Y$!`m@r-ay|!UMZEPYl?UyO+UCU37%!kIQxqm*&3gf1^fivIreaM{S+teq5O` z?j@#r%l>iD1-TR{NZEQvR@X@h`R21X`wz26`U;tS@fqAfqtr7xi%LSYv*Zg$}V`{6n1!`a9I6cxz8L 
zuHc6@&%dxe;~OP41?4_#?gL*@SM8r40_Mq5y)Vt2vhsbWn#CuNMAR{gg1_M@Lab*| zhb<2~JjCjfUjeZ9yurXz!i&dpQOR9ziCtccD{J|L?|q~nJ(%L7Mo~#Uf9Sm+F9!^; zdiPotcR`Uu>Xc-^Hg6;4ebas0smU8^yZiOmpgtug-}c~7UXmgQL2R?I!JlORU^(g! zrek*Jg5r-zfmeC-`zsoB2(F=!yEoU9wC^tyLy02WjNdn2`U+Pq4bS{|d8PUmmDZ4> zp236;=E~n2Aw~e zV{u>=8dZd*n#tv%(y>RT&4Pc*ADLio$6)y}fJEHAi>p$Abd#x{tlX3*QdN{UtN+!JILpMI@*^!{&MbHZS}5 z;y`0xz2W&r7Z^1lM(cPYdQuGx0I7i>u43YADqsGPm)5qVuj!qy{qFxVvtGGNoEgae zju!f0b3NyJrdiAVQSp)VkR$t{!cwTt^mU+`k9422A@z79yVX$6^9mC5$f|&DI~_k8 zLfm`r$|<$Egs4s1=~kaIPp|!P!sl|m^T5WbgecOnL0Q-u?k+{*y5j#+-I@PG_5T0g zBtzMfC5&a#D{Ha~! z`(J#2p4*w*IoEmKu5+E|^|;>;GH!Vt>!hqT+@-nxGSe6=UiBKuPOR#iv6lU``K+As zT_#~W6?{l6WnP2zSCf==4jiEut9K)Xo;pEYd6AK|*~6;H&20)td>E@XCt9_3^)lXt z*U_AXT@gDyy{vcq$)R6Yfl7j$So?CR82U`;9`fdfa9Zzkr)6`2aVhlP6_)*t9Lrx7Wpmp4FrOPmTZh?R z$+j!(@s9hNkN1%8)^`P2WDHX&j6K{d6TO0>tK8lMlgG;0; zw^^yrO{+CN=dlYr9d7p-!=0^;^gO<>E;|yjqK%w26Rpz&LpaPzbbar$BKkFBE9NVt zq~voSO-jq5H$PDLxi{S5;D|ITvEQ61>QiR@o8X_GX>8RKqArLJ>3TJiNr|iZ9@>#B zfU*2~zxJR$Ocp+N@l9pH;+gzp3Gsg)SUSIm#<=IJ#A>E~;>$^RY6{iZ_pte+1gNx8 zk2X+91DPZhE5t|tsFgb%a;jmL&LWje)j8x?3fn>D_TqCaJP6gNXnGxfOk{`M3L{p| zhR!iC|CjD8UGpGfCC<(UV;`O+-g6I{?S5bS0G)qb+x%YBJn{2 z?;4*;I)x3nZ~0*K=hU9J-$e#M*S|8Shqt$zBhNYr_gw|ACi?6{p%9+m#7tznI*OKq z0u7j}qdQM#&Mai7SFwyR9>N)_4>(- z9Q(B~BDxBiG(5(`bHfkf?2QTcYC_L6Stm)ku`)L^q*JQbg+m2W2l8(|!SnM0Y9KLq z{?(C;Gy_R-V@`2MXUzyX)Ag)7)@FnT9(3*KaD`b(SCrR;QL0W9_fAJzS8%o=meQCD z^N&LfGvCWO9@=7gxWDf17e=E6hn2DKLM-}2YPpMAQNyqsVd8=1yVx=GOo7q_Ee0X$ zUuaty)l02c8NV8VHlw_$S9If_kz*T(y7A`q3EIrGBqz`!WoB#6r8EMqn5i;->>uSC z9;7*B&J*GK6=}c<6skzJ$tRcuOZ}}GSByp)e7Q1}FZ8gp$S+56gk=9R2nfS#+f+vj zEADY~HrELE$4G?kN0+4A(~&(_zcFIealbXxp->39kxR9(jwDxt$&gW(ci z>h@x}m5$Pf=}B`;7~LWbT6qfDOJ*maT4K!y{zC0w2TLzL+Iw99UgYc%hLEs*ygB}P za=)9C%J*WBsTdAU7u=>3+}!6eBY*oe_d`ydn#FZxzIHZcrMSYt=8}Cg*=7@c851x4 zTg+;KW}H)P;xHryZdpr##KzwhL^wSIXo2jpSQ($~jx}LAz?yr@v@ZVxQ$HbY=%mVU zpvE}=he?}9!hW~p?@@Ab0a-G%xZ`rF0!e3oxf3>K;?)B_T`xqrF2b7hAlWboW}6P^ z$7#hd*Yo27g+kjmV~+{h-M1buEf1%fInK#`j5nNBNcFBL1Z;|3rQJ)4UM62(U(?}) 
z0_)0;O?LlXNVV}_T6aG_{!<6!T{aQj-Fi7xl`NECFON_i**>3WL|+zT=P1rBI@WO> z2ohVHO1gI(K%IpVXL;Gob)1h0w@VfG+9}z#LuZR-EcZy^Q=%zcI z;4mWy83X>s2oQROkX72WZil-*Mpy0PCgeA>y~0;!(~w0KTu&{JYa{qB`-wRgS)|V_ zUeZ&dHA5xRuCY?!Qmm8y{QLx>n_OO4qs`E^&p(?n?W6rs3gwkp>-1f@0oo?FbX}YaN+>R4^_vDsG|I`p&Kl6s4qo zaYtdt)(7i{qN$cKf822a1Ngv>)1j)>x{K(lGMn7$m6ucUUTD;wDD#aE;l+i_U~w>D z&S2aaUG3Zd_|Dk!`+4G_`?WAtBO3D4w(?T-E6KO70LaQlZ6kQy^SM?YW4zk!3a>5! zO!w$(7EI{|X&wYDEYrq+3BSBbaQd!w+m}A}KW4ya9B|$ibK|q%_@@y`UGmk`yV-l& z{uxffWoqe)zMO^4^TBfuIb|hDTQ+L7;^#>vMGNZrRwD&8cxp@Pu%}{>kxT4*$#yF? z4vYcQ1)y-~#{>bjeT_+9-bY-Of;O^d|*GrPrV zaWKQL-;}dZ!I)n8R5R^+yxlXLo2tIzF>$~Ldn3?Vsu5YvNxxXFyXkEG$)+I{H_2KO zzGG=1)|lw{zZ$%*y=^~yEZ#odBe|0*JfXwd?*YNIM7oOTl{9H+tC&LYG3S3=TmCTk zuws0P?U84uP|1!ZQdJ`~u|C{YGAh$jfyP3C#CyFn;TexjVU~mjUzWdKBphMr1J`)+ zx;84aL4Rh8B@q4{if-E0IZwuiq;p}^j#h>g@@x##w_e!mD1R+(KQ`OAWYja<=KBc> z(qknypY^yg^etrM3Q%EYKmpQe#R%Nn+b#=^bTuzXQrKDvpb6tE z>{(3yGxaPK7#Ry@a*)KGxDW9#5ybI1S?1ifPLyU`9~3wm{0t+zgf&afa_HzHFk(B zyja}!g$MD86nXT*vID)8zLm!kn!Pg~!UnENXLrlzDgUtVPIK5S^J7o5g}abT+Zli+ z4pugXNdXgh(5)DihU^2RO3pdpTbC6*aLem(kX{V&a5ieCm+a zo?_qBrdNay3ZDnM34fkJU2^Sy)ZiP}^(KFl^%F`f&Xbi1$!4Y^XETm`M&Q~;Vn46e z&$H^=qLT{f>%g6;{|rBf=oS8Wg-94czQ2F&OT!kefmWdZQ%178->huW<`|xrjud^6 zo(Uy4@u*-aP=~OG|Iv^pyIAt0tgNIP%7EAMq-aqeaJ4f1D{rzH#^M5^ZH$0o+7B72 zRaGOoMJT=cuc==$G@!jfaIy9x97gvKe%>AM0@3>TL|n5$QV^>uG?2Y}+QYBqd&LVH z3-wYfgQmT4L(OTx3Mdima#sP@5~4F$FUL zCQ0j{CqnU`4~v|vB;}&6SGRPTopiGWmZ$Z|*?4l=`h)G!=Wi!Jztp4iCBw#AxLqB= z2|Is#;#obIf7HcR!uAEJ^3FD3O_nDQJfh6+n(}uzgQBezvkCu-k-3#$KCGwL{xFF> zK;i<#%O3d&1usD=*!_LH;thhj5Twg}rKN|8M?MHA_RE-ij^e=e;`V(|5;ZyVD`RK= z+WT^HQ7p)`&HHjT4RS$sji6Dt_qqE$h*Mj>tABXk&2NO)Dg1+V_L@`rJ*>a4E=uWk zKV3&>24LnTvWz-&{rL8g=F@8ZffXSDS8?KwW!sLWn!QdFHd`X+krcBp@%=E(`c}{7 zKJ`^d1x_|a;!~Xl$`6o3-pL5UqW~>Nirzrf2D7qR7vF-rWHkG8OPF-^KV_ASEHI)(F=V=S0%{7zaG!6 za>mflUZ--g+@T>s5t)jgzkTDvaPrrRGykx@c&YPG1jCrOW7xIjr}{F{$qj(d9k=u? 
z(o$l+WG^HzGZ#q@J!ipZ2s>l@GXHqCq~kM3@-Bri{#nLjM?9KC)NIr06V#gft7!p$ z*q;_Y-yjccwQx}1)g4oZKyPEPMfmDeSC5!H)j?V75{);G)#N^GCC%4$bR{~X`Upt;- z#i5ukbKx;*rP`Sh$LkCLql;Y@j-F%&C8b~Z2blEM>C8S+8E%&2q9+v^5q^iSIHeIA z_TbBMo`>NpIRfcPN~paFtMFpV>-X_Oaq|}y`l`>h>%>A{Lyta#nKv48zt8x`MRgUe z)l}?mb^Hp7yFJGFhO!L>P0_#Lf8}o({ZK*D|G<}=mH@jseVjtE5z4GwRylIOBOpE^ z6dt5;7f(fPj#-0&O9;?G4-$!JdUz3y#oHLC@OA{-Ez;if34RYg_NaQ^$#3^jR57J_ zmKJbqKM$-bKeb5J{^~P~R-B>gyo(5Z`RASEqufa-`~wUxIru7w9C}&M=#qA4HMGY_ zCS7jeU98%I8#k z)!74nf^pz*1i@hBim$51Rw$N*O{rs_Lxq!M)Ku-=7QS-wNmzmK*-?uDZ>krUpHQVB zx{Z!~XXV|zdJg|N4&oTpcHcK4d3o?mo__*3g|CM z{+zG9O!NNUrCt?ZtQ1XKV@|Ut7~bPRHb(DC#!Ms?iUl zinEnh|0?wk_&!n(wO$g`X>7QJs#>{SIFOzFiXF2oFMaf4J87w&<8!Yr!eHI}WKT8G z;Z%T^b>2-M*p+GQ3OW|6Kd$_%IG81A5mn5wRUR#zGJQ9LkV8WT%a{P%7c8U*xu74T zs^oc)8ryC+c>)NoF2{zF>S1j#n<4NLI)HROqrXt{gu;4N%{eWn{)W;$WKO7~haPpO z;0_)eGIsVrspmf0&j}j3`|hvxoQp*LlyPwV6b+e=SxJpk&agj9e|{;F-~8w`vcAP~ z*Xzp0VVVs-*CI|j2{JG<`A)lM;lk>l;rl4GLHSjUDfh@jEFNx z6jtbiZ<7`>@;7yCKUl@z1CJdlmykJn*1a{wi_AO12qn6@qlc|UW>oVqL1{r79BFFLK!-AhwokV}8jWZf%+ zce1{4w7O5&3m-3*Sgakikn@Ax?Hu6iJYJD@(mD~`d4keBLp{MUo%N`V04tlG^)x?%PZwOX>1a6=UXyy4VB%i@R%iA@#K_=N>QLX`m(OYQ| z)iK&sfufh%N8c7VOg9U2j#slcOyqs;TSRqGtNM$XtTJ_39cTdBuBET4lD)NQa?^r} z40Q}R!2CtV(H~l{>oRzc-=IEMVzk$=oVeofP=u2PXwcXTXa`+?0ks>s9~u%`naT4- z!WtBxh1Z2=pJb0q9fl*d^vt3&l;yQXdOsH~CY~)0G|teF(O1W!WJtC3K#W+UDLZ~B zXX>bE>cU9Mp%9<`{ON>3p|mnBz5zlFRz&;J={5cmRBp9(O24-|m5ht_9M5vL?^)JY zVfRQh8-*VW(1-2YCRXZOt8Tp=jc_d}9vMg(b4n7dFu&-eWaFit7CzohJk~I>VETKX z%UB`G0G7eutah0HWTKoCG?p7>nUnQBQc2wUwX{b$3S5VwQ&_-mq z$yc2i_>PZE@WE5@?3=iKvH!UX)EKX#ss`t&mK8b!QvFKgernYp{C}n|zfF&Ka2`FM zJ$(4rDYi*{BbR;t|OfRftOT9 za1FU?C_>2Sv=C->nZKRI8WQcx&out#`kbz&fqket2x7;VPly41*S1+q;RX@+<&ww_ zZP=YsZ|}1ou@BqNF|4fcSbku}6Ti{6x(H9E66U9*@Lc=f^W@&&hF)wGaiS@3D>;G*W!{>*GAgy)xTTqy=-is-K=$Zho0@;(W6PSx#?R|OTTfb!{V_>Kp=N;I5tWdFEUFUu~4MG}SkP6hIxj29FiWfYh$c|QsmfX)R^RkP~fUz#_M=uC<|N5Rg zb0iCOD5145wx-4Fc~6Ojq{!nl{w53S(D4gy80o-DAsuu;VY12NSO`Xt@z*8F1`T<_ 
zq(dq&(sl9L+RqQ>e<|r~+F3Zdde&LKf-uXJZ`uyn0=0GG^=A{tC7vmvs{yJUt<(;zi018^k0iW4JM4eeMb zYSI9SPTxNUT12ieU+w&(3vO~fWIJ1H#qmBjA(p;Eufhfwb^%B zp@saP1emjFajph3eQMv%sF5F|B4YM2mCMN7z*yW{zroF#KCj=)2+1bD<#tw7vYQ)9jk5_4r|Mdbu*l<#OCPP16id$Gb5sIG%T$G>QI}2x zu2pjFvnhsjG9Zgvs`M|q*(u+~yAAGGXGOImgLX^}yaHQpf@JbT<^UB`UGVFc02JGnW4xdqNx(q(?yUX?;7y|absgT`C7%0dUF`T*d~**L!a}FZiY%3HloWZ@ zeKSF(56TwV@Guy3YwR!6VeSYK?4ow$F%1voc%LP8)JP}^S7qI!>WJ7+ za+Iws93yz)AKc%e(W^?*K@96vH=Zf$_YSs0&soKPji9HqU^dLbB&`{jeoZXk4NlO) zo;9q?zW;OIbB3bzyXH!IlL$_@=q?uVe7U~hH_Hw9@JuN4O2|Vqucpv(L(2&%TmD

olsn%7+vELnH!%>9;n)5c=Bb2GiRr+5h5CP%>N0EHokP-l zw-5IqGNFvaRD5o>|N0>T+p>%0XY}fRS$ZR> Thiu@N7ijKj>fb5Ru#Nmb3T@G6 literal 0 HcmV?d00001 From 8385fec954612bade1cb947f7b72bbda37eb5652 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Tue, 18 Sep 2018 07:47:36 -0700 Subject: [PATCH 20/51] Fix small typo in docs (#2420) --- doc/related-projects.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/related-projects.rst b/doc/related-projects.rst index 714fbf98d7c..524ea3b9d8d 100644 --- a/doc/related-projects.rst +++ b/doc/related-projects.rst @@ -47,7 +47,7 @@ Extend xarray capabilities ~~~~~~~~~~~~~~~~~~~~~~~~~~ - `Collocate `_: Collocate xarray trajectories in arbitrary physical dimensions - `eofs `_: EOF analysis in Python. -- `xarray_extras `_: Advanced algorithms for xarray objects (e.g. intergrations/interpolations). +- `xarray_extras `_: Advanced algorithms for xarray objects (e.g. integrations/interpolations). - `xrft `_: Fourier transforms for xarray data. - `xr-scipy `_: A lightweight scipy wrapper for xarray. - `X-regression `_: Multiple linear regression from Statsmodels library coupled with Xarray library. 
From b679f4a9edb2437d11f28bf6516fc7aa8d673acb Mon Sep 17 00:00:00 2001 From: Julia Signell Date: Tue, 18 Sep 2018 21:19:07 -0400 Subject: [PATCH 21/51] Adding data kwarg to copy to create new objects with same structure as original (#2384) * Added label_like for Variable and DataArray and tests * linting * Responding to comments - fixing up tests, more flexible input * label_like --> structured_like * Made docs changes, added example * Responding to comments * Moving from structured_like to .copy(data) * Making dataset copy mandate all data_vars and minor tweaks * Stop ignoring data in IndexVariable.copy --- doc/whats-new.rst | 4 ++ xarray/core/dataarray.py | 71 ++++++++++++++++++-- xarray/core/dataset.py | 114 +++++++++++++++++++++++++++++++-- xarray/core/variable.py | 114 ++++++++++++++++++++++++++++----- xarray/tests/test_dataarray.py | 12 ++++ xarray/tests/test_dataset.py | 21 ++++++ xarray/tests/test_variable.py | 28 ++++++++ 7 files changed, 338 insertions(+), 26 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 881ae52cdeb..b16533ad33b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -76,6 +76,10 @@ Enhancements - You can now call ``unstack`` without arguments to unstack every MultiIndex in a DataArray or Dataset. By `Julia Signell `_. +- Added the ability to pass a data kwarg to ``copy`` to create a new object with the + same metadata as the original object but using new values. + By `Julia Signell `_. + Bug fixes ~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index ae3758f0bbd..937d38d30fa 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -677,14 +677,77 @@ def persist(self, **kwargs): ds = self._to_temp_dataset().persist(**kwargs) return self._from_temp_dataset(ds) - def copy(self, deep=True): + def copy(self, deep=True, data=None): """Returns a copy of this array. - If `deep=True`, a deep copy is made of all variables in the underlying - dataset. 
Otherwise, a shallow copy is made, so each variable in the new + If `deep=True`, a deep copy is made of the data array. + Otherwise, a shallow copy is made, so each variable in the new array's dataset is also a variable in this array's dataset. + + Use `data` to create a new object with the same structure as + original but entirely new data. + + Parameters + ---------- + deep : bool, optional + Whether the data array and its coordinates are loaded into memory + and copied onto the new object. Default is True. + data : array_like, optional + Data to use in the new object. Must have same shape as original. + When `data` is used, `deep` is ignored for all data variables, + and only used for coords. + + Returns + ------- + object : DataArray + New object with dimensions, attributes, coordinates, name, + encoding, and optionally data copied from original. + + Examples + -------- + + Shallow versus deep copy + + >>> array = xr.DataArray([1, 2, 3], dims='x', + ... coords={'x': ['a', 'b', 'c']}) + >>> array.copy() + + array([1, 2, 3]) + Coordinates: + * x (x) >> array_0 = array.copy(deep=False) + >>> array_0[0] = 7 + >>> array_0 + + array([7, 2, 3]) + Coordinates: + * x (x) >> array + + array([7, 2, 3]) + Coordinates: + * x (x) >> array.copy(data=[0.1, 0.2, 0.3]) + + array([ 0.1, 0.2, 0.3]) + Coordinates: + * x (x) >> array + + array([1, 2, 3]) + Coordinates: + * x (x) >> da = xr.DataArray(np.random.randn(2, 3)) + >>> ds = xr.Dataset({'foo': da, 'bar': ('x', [-1, 2])}, + coords={'x': ['one', 'two']}) + >>> ds.copy() + + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) >> ds_0 = ds.copy(deep=False) + >>> ds_0['foo'][0, 0] = 7 + >>> ds_0 + + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) >> ds + + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) >> ds.copy(data={'foo': np.arange(6).reshape(2, 3), 'bar': ['a', 'b']}) + + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) >> ds + + Dimensions: (dim_0: 2, dim_1: 3, x: 
2) + Coordinates: + * x (x) >> var = xr.Variable(data=[1, 2, 3], dims='x') + >>> var.copy() + + array([1, 2, 3]) + >>> var_0 = var.copy(deep=False) + >>> var_0[0] = 7 + >>> var_0 + + array([7, 2, 3]) + >>> var + + array([7, 2, 3]) + + Changing the data using the ``data`` argument maintains the + structure of the original object, but with the new data. Original + object is unaffected. + + >>> var.copy(data=[0.1, 0.2, 0.3]) + + array([ 0.1, 0.2, 0.3]) + >>> var + + array([7, 2, 3]) - if deep: - if isinstance(data, dask_array_type): - data = data.copy() - elif not isinstance(data, PandasIndexAdapter): - # pandas.Index is immutable - data = np.array(data) + See Also + -------- + pandas.DataFrame.copy + """ + if data is None: + data = self._data + + if isinstance(data, indexing.MemoryCachedArray): + # don't share caching between copies + data = indexing.MemoryCachedArray(data.array) + + if deep: + if isinstance(data, dask_array_type): + data = data.copy() + elif not isinstance(data, PandasIndexAdapter): + # pandas.Index is immutable + data = np.array(data) + else: + data = as_compatible_data(data) + if self.shape != data.shape: + raise ValueError("Data shape {} must match shape of object {}" + .format(data.shape, self.shape)) # note: # dims is already an immutable tuple @@ -1709,14 +1766,37 @@ def concat(cls, variables, dim='concat_dim', positions=None, return cls(first_var.dims, data, attrs) - def copy(self, deep=True): + def copy(self, deep=True, data=None): """Returns a copy of this object. - `deep` is ignored since data is stored in the form of pandas.Index, - which is already immutable. Dimensions, attributes and encodings are - always copied. + `deep` is ignored since data is stored in the form of + pandas.Index, which is already immutable. Dimensions, attributes + and encodings are always copied. + + Use `data` to create a new object with the same structure as + original but entirely new data. 
+ + Parameters + ---------- + deep : bool, optional + Deep is always ignored. + data : array_like, optional + Data to use in the new object. Must have same shape as original. + + Returns + ------- + object : Variable + New object with dimensions, attributes, encodings, and optionally + data copied from original. """ - return type(self)(self.dims, self._data, self._attrs, + if data is None: + data = self._data + else: + data = as_compatible_data(data) + if self.shape != data.shape: + raise ValueError("Data shape {} must match shape of object {}" + .format(data.shape, self.shape)) + return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True) def equals(self, other, equiv=None): diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index a4562894583..2b93e696d50 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3138,6 +3138,18 @@ def test_roll_coords_none(self): expected = DataArray([3, 1, 2], coords=[('x', [2, 0, 1])]) assert_identical(expected, actual) + def test_copy_with_data(self): + orig = DataArray(np.random.random(size=(2, 2)), + dims=('x', 'y'), + attrs={'attr1': 'value1'}, + coords={'x': [4, 3]}, + name='helloworld') + new_data = np.arange(4).reshape(2, 2) + actual = orig.copy(data=new_data) + expected = orig.copy() + expected.data = new_data + assert_identical(expected, actual) + def test_real_and_imag(self): array = DataArray(1 + 2j) assert_identical(array.real, DataArray(1)) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index d22d8470dc6..fc933960914 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1892,6 +1892,27 @@ def test_copy(self): v1 = copied.variables[k] assert v0 is not v1 + def test_copy_with_data(self): + orig = create_test_data() + new_data = {k: np.random.randn(*v.shape) + for k, v in iteritems(orig.data_vars)} + actual = orig.copy(data=new_data) + + expected = orig.copy() + for k, v in new_data.items(): + 
expected[k].data = v + assert_identical(expected, actual) + + def test_copy_with_data_errors(self): + orig = create_test_data() + new_var1 = np.arange(orig['var1'].size).reshape(orig['var1'].shape) + with raises_regex(ValueError, 'Data must be dict-like'): + orig.copy(data=new_var1) + with raises_regex(ValueError, 'only contain variables in original'): + orig.copy(data={'not_in_original': new_var1}) + with raises_regex(ValueError, 'contain all variables in original'): + orig.copy(data={'var1': new_var1}) + def test_rename(self): data = create_test_data() newnames = {'var1': 'renamed_var1', 'dim2': 'renamed_dim2'} diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 904940cbbf6..1263ac1df9e 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -505,6 +505,34 @@ def test_copy_index(self): assert isinstance(w.to_index(), pd.MultiIndex) assert_array_equal(v._data.array, w._data.array) + def test_copy_with_data(self): + orig = Variable(('x', 'y'), [[1.5, 2.0], [3.1, 4.3]], {'foo': 'bar'}) + new_data = np.array([[2.5, 5.0], [7.1, 43]]) + actual = orig.copy(data=new_data) + expected = orig.copy() + expected.data = new_data + assert_identical(expected, actual) + + def test_copy_with_data_errors(self): + orig = Variable(('x', 'y'), [[1.5, 2.0], [3.1, 4.3]], {'foo': 'bar'}) + new_data = [2.5, 5.0] + with raises_regex(ValueError, 'must match shape of object'): + orig.copy(data=new_data) + + def test_copy_index_with_data(self): + orig = IndexVariable('x', np.arange(5)) + new_data = np.arange(5, 10) + actual = orig.copy(data=new_data) + expected = orig.copy() + expected.data = new_data + assert_identical(expected, actual) + + def test_copy_index_with_data_errors(self): + orig = IndexVariable('x', np.arange(5)) + new_data = np.arange(5, 20) + with raises_regex(ValueError, 'must match shape of object'): + orig.copy(data=new_data) + def test_real_and_imag(self): v = self.cls('x', np.arange(3) - 1j * np.arange(3), {'foo': 
'bar'}) expected_re = self.cls('x', np.arange(3), {'foo': 'bar'}) From a0b5af5a1945ccac3704df0ff2acaf55f2db2de6 Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Tue, 18 Sep 2018 22:59:27 -0700 Subject: [PATCH 22/51] Update NumFOCUS donate link (#2421) * add some blurbs about numfocus sponsorship to docs * add numfocus to history blurb * add missing logo file * markdown to rst in readme * update numfocus donate links in documentation * more links to numfocus --- README.rst | 5 +++-- doc/index.rst | 9 ++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/README.rst b/README.rst index 12650b1db1b..0ac71d33954 100644 --- a/README.rst +++ b/README.rst @@ -110,13 +110,14 @@ NumFOCUS .. image:: https://numfocus.org/wp-content/uploads/2017/07/NumFocus_LRG.png :scale: 25 % + :target: https://numfocus.org/ -Xarray is a fiscally sponsored project of NumFOCUS, a nonprofit dedicated +Xarray is a fiscally sponsored project of NumFOCUS_, a nonprofit dedicated to supporting the open source scientific computing community. If you like Xarray and want to support our mission, please consider making a donation_ to support our efforts. -.. _donation: https://www.flipcause.com/secure/cause_pdetails/MjE3OQ== +.. _donation: https://www.flipcause.com/secure/cause_pdetails/NDE2NTU= History ------- diff --git a/doc/index.rst b/doc/index.rst index 6c1a8519507..45897f4bccb 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -125,13 +125,16 @@ NumFOCUS .. image:: _static/numfocus_logo.png :scale: 50 % + :target: https://numfocus.org/ -Xarray is a fiscally sponsored project of NumFOCUS, a nonprofit dedicated +Xarray is a fiscally sponsored project of NumFOCUS_, a nonprofit dedicated to supporting the open source scientific computing community. 
If you like -Xarray and want to support our mission, please consider making a -[donation](https://www.flipcause.com/secure/cause_pdetails/MjE3OQ==) +Xarray and want to support our mission, please consider making a donation_ to support our efforts. +.. _donation: https://www.flipcause.com/secure/cause_pdetails/NDE2NTU= + + History ------- From 5b87b6e2f159b827f739e12d4faae57a0b6f6178 Mon Sep 17 00:00:00 2001 From: Spencer Clark Date: Wed, 19 Sep 2018 16:24:39 -0400 Subject: [PATCH 23/51] WIP Add a CFTimeIndex-enabled xr.cftime_range function (#2301) * Initial work on adding a cftime-compatible date_range function Add docstring for xr.date_range Fix failing test Fix test skipping logic Coerce result of zip to a list in test setup Add and clean up tests Fix skip logic Skip roll_forward and roll_backward tests if cftime is not installed Expose all possible arguments to pd.date_range Add more detail to docstrings flake8 Add a what's new entry Add a short example to time-series.rst * Allow empty CFTimeIndexes; add calendar to CFTimeIndex repr * Enable CFTimeIndex constructor to optionally take date_range arguments * Simplify date_range to use new CFTimeIndex constructor * Rename xr.date_range to xr.cftime_range * Follow pandas behavior/naming for rollforward, rollback, and onOffset * Update docstring * Add pandas copyright notice to cftime_offsets.py * Check validity of offset constructor arguments * Fix TypeError versus ValueError uses * Use a module-level importorskip in test_cftime_offsets.py * Only return a CFTimeIndex from cftime_range * Keep CFTimeIndex constructor simple * Add some explicitly calendar-specific tests * Revert back to default repr * lint * return NotImplemented * Convert pandas copyright notices to comments * test_calendar_leap_year_length -> test_calendar_year_length * Use return NotImplemented in __apply__ too --- doc/api.rst | 7 + doc/time-series.rst | 11 +- doc/whats-new.rst | 3 + xarray/__init__.py | 1 + xarray/coding/cftime_offsets.py | 736 
+++++++++++++++++++++++++ xarray/coding/cftimeindex.py | 88 ++- xarray/tests/test_cftime_offsets.py | 801 ++++++++++++++++++++++++++++ xarray/tests/test_cftimeindex.py | 39 +- 8 files changed, 1665 insertions(+), 21 deletions(-) create mode 100644 xarray/coding/cftime_offsets.py create mode 100644 xarray/tests/test_cftime_offsets.py diff --git a/doc/api.rst b/doc/api.rst index 927c0aa072c..89fee10506d 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -555,6 +555,13 @@ Custom Indexes CFTimeIndex +Creating custom indexes +----------------------- +.. autosummary:: + :toctree: generated/ + + cftime_range + Plotting ======== diff --git a/doc/time-series.rst b/doc/time-series.rst index a7ce9226d4d..d99c3218d18 100644 --- a/doc/time-series.rst +++ b/doc/time-series.rst @@ -258,7 +258,16 @@ coordinate with a no-leap calendar within a context manager setting the calendar, its times will be decoded into ``cftime.datetime`` objects, regardless of whether or not they can be represented using ``np.datetime64[ns]`` objects. - + +xarray also includes a :py:func:`cftime_range` function, which enables creating a +``CFTimeIndex`` with regularly-spaced dates. For instance, we can create the +same dates and DataArray we created above using: + +.. ipython:: python + + dates = xr.cftime_range(start='0001', periods=24, freq='MS', calendar='noleap') + da = xr.DataArray(np.arange(24), coords=[dates], dims=['time'], name='foo') + For data indexed by a ``CFTimeIndex`` xarray currently supports: - `Partial datetime string indexing`_ using strictly `ISO 8601-format`_ partial diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b16533ad33b..8c34ddf3fa9 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -54,6 +54,9 @@ Enhancements now displayed as `a b ... y z` rather than `a b c d ...`. (:issue:`1186`) By `Seth P `_. +- A new CFTimeIndex-enabled :py:func:`cftime_range` function for use in + generating dates from standard or non-standard calendars. By `Spencer Clark + `_. 
- When interpolating over a ``datetime64`` axis, you can now provide a datetime string instead of a ``datetime64`` object. E.g. ``da.interp(time='1991-02-01')`` (:issue:`2284`) diff --git a/xarray/__init__.py b/xarray/__init__.py index 7cc7811b783..e2d24e6c294 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -26,6 +26,7 @@ from .conventions import decode_cf, SerializationWarning +from .coding.cftime_offsets import cftime_range from .coding.cftimeindex import CFTimeIndex from .util.print_versions import show_versions diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py new file mode 100644 index 00000000000..3fbb44f4ed3 --- /dev/null +++ b/xarray/coding/cftime_offsets.py @@ -0,0 +1,736 @@ +"""Time offset classes for use with cftime.datetime objects""" +# The offset classes and mechanisms for generating time ranges defined in +# this module were copied/adapted from those defined in pandas. See in +# particular the objects and methods defined in pandas.tseries.offsets +# and pandas.core.indexes.datetimes. + +# For reference, here is a copy of the pandas copyright notice: + +# (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team +# All rights reserved. + +# Copyright (c) 2008-2011 AQR Capital Management, LLC +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: + +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. 
+ +# * Neither the name of the copyright holder nor the names of any +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import re + +from datetime import timedelta +from functools import partial + +import numpy as np + +from .cftimeindex import _parse_iso8601_with_reso, CFTimeIndex +from .times import format_cftime_datetime +from ..core.pycompat import basestring + + +def get_date_type(calendar): + """Return the cftime date type for a given calendar name.""" + try: + import cftime + except ImportError: + raise ImportError( + 'cftime is required for dates with non-standard calendars') + else: + calendars = { + 'noleap': cftime.DatetimeNoLeap, + '360_day': cftime.Datetime360Day, + '365_day': cftime.DatetimeNoLeap, + '366_day': cftime.DatetimeAllLeap, + 'gregorian': cftime.DatetimeGregorian, + 'proleptic_gregorian': cftime.DatetimeProlepticGregorian, + 'julian': cftime.DatetimeJulian, + 'all_leap': cftime.DatetimeAllLeap, + 'standard': cftime.DatetimeProlepticGregorian + } + return calendars[calendar] + + +class BaseCFTimeOffset(object): + _freq = None + + def __init__(self, n=1): + if not isinstance(n, int): + 
raise TypeError( + "The provided multiple 'n' must be an integer. " + "Instead a value of type {!r} was provided.".format(type(n))) + self.n = n + + def rule_code(self): + return self._freq + + def __eq__(self, other): + return self.n == other.n and self.rule_code() == other.rule_code() + + def __ne__(self, other): + return not self == other + + def __add__(self, other): + return self.__apply__(other) + + def __sub__(self, other): + import cftime + + if isinstance(other, cftime.datetime): + raise TypeError('Cannot subtract a cftime.datetime ' + 'from a time offset.') + elif type(other) == type(self): + return type(self)(self.n - other.n) + else: + return NotImplemented + + def __mul__(self, other): + return type(self)(n=other * self.n) + + def __neg__(self): + return self * -1 + + def __rmul__(self, other): + return self.__mul__(other) + + def __radd__(self, other): + return self.__add__(other) + + def __rsub__(self, other): + if isinstance(other, BaseCFTimeOffset) and type(self) != type(other): + raise TypeError('Cannot subtract cftime offsets of differing ' + 'types') + return -self + other + + def __apply__(self): + return NotImplemented + + def onOffset(self, date): + """Check if the given date is in the set of possible dates created + using a length-one version of this offset class.""" + test_date = (self + date) - self + return date == test_date + + def rollforward(self, date): + if self.onOffset(date): + return date + else: + return date + type(self)() + + def rollback(self, date): + if self.onOffset(date): + return date + else: + return date - type(self)() + + def __str__(self): + return '<{}: n={}>'.format(type(self).__name__, self.n) + + def __repr__(self): + return str(self) + + +def _days_in_month(date): + """The number of days in the month of the given date""" + if date.month == 12: + reference = type(date)(date.year + 1, 1, 1) + else: + reference = type(date)(date.year, date.month + 1, 1) + return (reference - timedelta(days=1)).day + + +def 
_adjust_n_months(other_day, n, reference_day): + """Adjust the number of times a monthly offset is applied based + on the day of a given date, and the reference day provided. + """ + if n > 0 and other_day < reference_day: + n = n - 1 + elif n <= 0 and other_day > reference_day: + n = n + 1 + return n + + +def _adjust_n_years(other, n, month, reference_day): + """Adjust the number of times an annual offset is applied based on + another date, and the reference day provided""" + if n > 0: + if other.month < month or (other.month == month and + other.day < reference_day): + n -= 1 + else: + if other.month > month or (other.month == month and + other.day > reference_day): + n += 1 + return n + + +def _shift_months(date, months, day_option='start'): + """Shift the date to a month start or end a given number of months away. + """ + delta_year = (date.month + months) // 12 + month = (date.month + months) % 12 + + if month == 0: + month = 12 + delta_year = delta_year - 1 + year = date.year + delta_year + + if day_option == 'start': + day = 1 + elif day_option == 'end': + reference = type(date)(year, month, 1) + day = _days_in_month(reference) + else: + raise ValueError(day_option) + return date.replace(year=year, month=month, day=day) + + +class MonthBegin(BaseCFTimeOffset): + _freq = 'MS' + + def __apply__(self, other): + n = _adjust_n_months(other.day, self.n, 1) + return _shift_months(other, n, 'start') + + def onOffset(self, date): + """Check if the given date is in the set of possible dates created + using a length-one version of this offset class.""" + return date.day == 1 + + +class MonthEnd(BaseCFTimeOffset): + _freq = 'M' + + def __apply__(self, other): + n = _adjust_n_months(other.day, self.n, _days_in_month(other)) + return _shift_months(other, n, 'end') + + def onOffset(self, date): + """Check if the given date is in the set of possible dates created + using a length-one version of this offset class.""" + return date.day == _days_in_month(date) + + 
+_MONTH_ABBREVIATIONS = { + 1: 'JAN', + 2: 'FEB', + 3: 'MAR', + 4: 'APR', + 5: 'MAY', + 6: 'JUN', + 7: 'JUL', + 8: 'AUG', + 9: 'SEP', + 10: 'OCT', + 11: 'NOV', + 12: 'DEC' +} + + +class YearOffset(BaseCFTimeOffset): + _freq = None + _day_option = None + _default_month = None + + def __init__(self, n=1, month=None): + BaseCFTimeOffset.__init__(self, n) + if month is None: + self.month = self._default_month + else: + self.month = month + if not isinstance(self.month, int): + raise TypeError("'self.month' must be an integer value between 1 " + "and 12. Instead, it was set to a value of " + "{!r}".format(self.month)) + elif not (1 <= self.month <= 12): + raise ValueError("'self.month' must be an integer value between 1 " + "and 12. Instead, it was set to a value of " + "{!r}".format(self.month)) + + def __apply__(self, other): + if self._day_option == 'start': + reference_day = 1 + elif self._day_option == 'end': + reference_day = _days_in_month(other) + else: + raise ValueError(self._day_option) + years = _adjust_n_years(other, self.n, self.month, reference_day) + months = years * 12 + (self.month - other.month) + return _shift_months(other, months, self._day_option) + + def __sub__(self, other): + import cftime + + if isinstance(other, cftime.datetime): + raise TypeError('Cannot subtract cftime.datetime from offset.') + elif type(other) == type(self) and other.month == self.month: + return type(self)(self.n - other.n, month=self.month) + else: + return NotImplemented + + def __mul__(self, other): + return type(self)(n=other * self.n, month=self.month) + + def rule_code(self): + return '{}-{}'.format(self._freq, _MONTH_ABBREVIATIONS[self.month]) + + def __str__(self): + return '<{}: n={}, month={}>'.format( + type(self).__name__, self.n, self.month) + + +class YearBegin(YearOffset): + _freq = 'AS' + _day_option = 'start' + _default_month = 1 + + def onOffset(self, date): + """Check if the given date is in the set of possible dates created + using a length-one version 
of this offset class.""" + return date.day == 1 and date.month == self.month + + def rollforward(self, date): + """Roll date forward to nearest start of year""" + if self.onOffset(date): + return date + else: + return date + YearBegin(month=self.month) + + def rollback(self, date): + """Roll date backward to nearest start of year""" + if self.onOffset(date): + return date + else: + return date - YearBegin(month=self.month) + + +class YearEnd(YearOffset): + _freq = 'A' + _day_option = 'end' + _default_month = 12 + + def onOffset(self, date): + """Check if the given date is in the set of possible dates created + using a length-one version of this offset class.""" + return date.day == _days_in_month(date) and date.month == self.month + + def rollforward(self, date): + """Roll date forward to nearest end of year""" + if self.onOffset(date): + return date + else: + return date + YearEnd(month=self.month) + + def rollback(self, date): + """Roll date backward to nearest end of year""" + if self.onOffset(date): + return date + else: + return date - YearEnd(month=self.month) + + +class Day(BaseCFTimeOffset): + _freq = 'D' + + def __apply__(self, other): + return other + timedelta(days=self.n) + + +class Hour(BaseCFTimeOffset): + _freq = 'H' + + def __apply__(self, other): + return other + timedelta(hours=self.n) + + +class Minute(BaseCFTimeOffset): + _freq = 'T' + + def __apply__(self, other): + return other + timedelta(minutes=self.n) + + +class Second(BaseCFTimeOffset): + _freq = 'S' + + def __apply__(self, other): + return other + timedelta(seconds=self.n) + + +_FREQUENCIES = { + 'A': YearEnd, + 'AS': YearBegin, + 'Y': YearEnd, + 'YS': YearBegin, + 'M': MonthEnd, + 'MS': MonthBegin, + 'D': Day, + 'H': Hour, + 'T': Minute, + 'min': Minute, + 'S': Second, + 'AS-JAN': partial(YearBegin, month=1), + 'AS-FEB': partial(YearBegin, month=2), + 'AS-MAR': partial(YearBegin, month=3), + 'AS-APR': partial(YearBegin, month=4), + 'AS-MAY': partial(YearBegin, month=5), + 'AS-JUN': 
partial(YearBegin, month=6), + 'AS-JUL': partial(YearBegin, month=7), + 'AS-AUG': partial(YearBegin, month=8), + 'AS-SEP': partial(YearBegin, month=9), + 'AS-OCT': partial(YearBegin, month=10), + 'AS-NOV': partial(YearBegin, month=11), + 'AS-DEC': partial(YearBegin, month=12), + 'A-JAN': partial(YearEnd, month=1), + 'A-FEB': partial(YearEnd, month=2), + 'A-MAR': partial(YearEnd, month=3), + 'A-APR': partial(YearEnd, month=4), + 'A-MAY': partial(YearEnd, month=5), + 'A-JUN': partial(YearEnd, month=6), + 'A-JUL': partial(YearEnd, month=7), + 'A-AUG': partial(YearEnd, month=8), + 'A-SEP': partial(YearEnd, month=9), + 'A-OCT': partial(YearEnd, month=10), + 'A-NOV': partial(YearEnd, month=11), + 'A-DEC': partial(YearEnd, month=12) +} + + +_FREQUENCY_CONDITION = '|'.join(_FREQUENCIES.keys()) +_PATTERN = '^((?P\d+)|())(?P({0}))$'.format( + _FREQUENCY_CONDITION) + + +def to_offset(freq): + """Convert a frequency string to the appropriate subclass of + BaseCFTimeOffset.""" + if isinstance(freq, BaseCFTimeOffset): + return freq + else: + try: + freq_data = re.match(_PATTERN, freq).groupdict() + except AttributeError: + raise ValueError('Invalid frequency string provided') + + freq = freq_data['freq'] + multiples = freq_data['multiple'] + if multiples is None: + multiples = 1 + else: + multiples = int(multiples) + + return _FREQUENCIES[freq](n=multiples) + + +def to_cftime_datetime(date_str_or_date, calendar=None): + import cftime + + if isinstance(date_str_or_date, basestring): + if calendar is None: + raise ValueError( + 'If converting a string to a cftime.datetime object, ' + 'a calendar type must be provided') + date, _ = _parse_iso8601_with_reso(get_date_type(calendar), + date_str_or_date) + return date + elif isinstance(date_str_or_date, cftime.datetime): + return date_str_or_date + else: + raise TypeError("date_str_or_date must be a string or a " + 'subclass of cftime.datetime. 
Instead got ' + '{!r}.'.format(date_str_or_date)) + + +def normalize_date(date): + """Round datetime down to midnight.""" + return date.replace(hour=0, minute=0, second=0, microsecond=0) + + +def _maybe_normalize_date(date, normalize): + """Round datetime down to midnight if normalize is True.""" + if normalize: + return normalize_date(date) + else: + return date + + +def _generate_linear_range(start, end, periods): + """Generate an equally-spaced sequence of cftime.datetime objects between + and including two dates (whose length equals the number of periods).""" + import cftime + + total_seconds = (end - start).total_seconds() + values = np.linspace(0., total_seconds, periods, endpoint=True) + units = 'seconds since {}'.format(format_cftime_datetime(start)) + calendar = start.calendar + return cftime.num2date(values, units=units, calendar=calendar, + only_use_cftime_datetimes=True) + + +def _generate_range(start, end, periods, offset): + """Generate a regular range of cftime.datetime objects with a + given time offset. + + Adapted from pandas.tseries.offsets.generate_range. 
+ + Parameters + ---------- + start : cftime.datetime, or None + Start of range + end : cftime.datetime, or None + End of range + periods : int, or None + Number of elements in the sequence + offset : BaseCFTimeOffset + An offset class designed for working with cftime.datetime objects + + Returns + ------- + A generator object + """ + if start: + start = offset.rollforward(start) + + if end: + end = offset.rollback(end) + + if periods is None and end < start: + end = None + periods = 0 + + if end is None: + end = start + (periods - 1) * offset + + if start is None: + start = end - (periods - 1) * offset + + current = start + if offset.n >= 0: + while current <= end: + yield current + + next_date = current + offset + if next_date <= current: + raise ValueError('Offset {offset} did not increment date' + .format(offset=offset)) + current = next_date + else: + while current >= end: + yield current + + next_date = current + offset + if next_date >= current: + raise ValueError('Offset {offset} did not decrement date' + .format(offset=offset)) + current = next_date + + +def _count_not_none(*args): + """Compute the number of non-None arguments.""" + return sum([arg is not None for arg in args]) + + +def cftime_range(start=None, end=None, periods=None, freq='D', + tz=None, normalize=False, name=None, closed=None, + calendar='standard'): + """Return a fixed frequency CFTimeIndex. + + Parameters + ---------- + start : str or cftime.datetime, optional + Left bound for generating dates. + end : str or cftime.datetime, optional + Right bound for generating dates. + periods : integer, optional + Number of periods to generate. + freq : str, default 'D', BaseCFTimeOffset, or None + Frequency strings can have multiples, e.g. '5H'. + normalize : bool, default False + Normalize start/end dates to midnight before generating date range. 
+ name : str, default None + Name of the resulting index + closed : {None, 'left', 'right'}, optional + Make the interval closed with respect to the given frequency to the + 'left', 'right', or both sides (None, the default). + calendar : str + Calendar type for the datetimes (default 'standard'). + + Returns + ------- + CFTimeIndex + + Notes + ----- + + This function is an analog of ``pandas.date_range`` for use in generating + sequences of ``cftime.datetime`` objects. It supports most of the + features of ``pandas.date_range`` (e.g. specifying how the index is + ``closed`` on either side, or whether or not to ``normalize`` the start and + end bounds); however, there are some notable exceptions: + + - You cannot specify a ``tz`` (time zone) argument. + - Start or end dates specified as partial-datetime strings must use the + `ISO-8601 format `_. + - It supports many, but not all, frequencies supported by + ``pandas.date_range``. For example it does not currently support any of + the business-related, semi-monthly, or sub-second frequencies. + - Compound sub-monthly frequencies are not supported, e.g. '1H1min', as + these can easily be written in terms of the finest common resolution, + e.g. '61min'. + + Valid simple frequency strings for use with ``cftime``-calendars include + any multiples of the following. 
+ + +--------+-----------------------+ + | Alias | Description | + +========+=======================+ + | A, Y | Year-end frequency | + +--------+-----------------------+ + | AS, YS | Year-start frequency | + +--------+-----------------------+ + | M | Month-end frequency | + +--------+-----------------------+ + | MS | Month-start frequency | + +--------+-----------------------+ + | D | Day frequency | + +--------+-----------------------+ + | H | Hour frequency | + +--------+-----------------------+ + | T, min | Minute frequency | + +--------+-----------------------+ + | S | Second frequency | + +--------+-----------------------+ + + Any multiples of the following anchored offsets are also supported. + + +----------+-------------------------------------------------------------------+ + | Alias | Description | + +==========+===================================================================+ + | A(S)-JAN | Annual frequency, anchored at the end (or beginning) of January | + +----------+-------------------------------------------------------------------+ + | A(S)-FEB | Annual frequency, anchored at the end (or beginning) of February | + +----------+-------------------------------------------------------------------+ + | A(S)-MAR | Annual frequency, anchored at the end (or beginning) of March | + +----------+-------------------------------------------------------------------+ + | A(S)-APR | Annual frequency, anchored at the end (or beginning) of April | + +----------+-------------------------------------------------------------------+ + | A(S)-MAY | Annual frequency, anchored at the end (or beginning) of May | + +----------+-------------------------------------------------------------------+ + | A(S)-JUN | Annual frequency, anchored at the end (or beginning) of June | + +----------+-------------------------------------------------------------------+ + | A(S)-JUL | Annual frequency, anchored at the end (or beginning) of July | + 
+----------+-------------------------------------------------------------------+ + | A(S)-AUG | Annual frequency, anchored at the end (or beginning) of August | + +----------+-------------------------------------------------------------------+ + | A(S)-SEP | Annual frequency, anchored at the end (or beginning) of September | + +----------+-------------------------------------------------------------------+ + | A(S)-OCT | Annual frequency, anchored at the end (or beginning) of October | + +----------+-------------------------------------------------------------------+ + | A(S)-NOV | Annual frequency, anchored at the end (or beginning) of November | + +----------+-------------------------------------------------------------------+ + | A(S)-DEC | Annual frequency, anchored at the end (or beginning) of December | + +----------+-------------------------------------------------------------------+ + + Finally, the following calendar aliases are supported. + + +--------------------------------+---------------------------------------+ + | Alias | Date type | + +================================+=======================================+ + | standard, proleptic_gregorian | ``cftime.DatetimeProlepticGregorian`` | + +--------------------------------+---------------------------------------+ + | gregorian | ``cftime.DatetimeGregorian`` | + +--------------------------------+---------------------------------------+ + | noleap, 365_day | ``cftime.DatetimeNoLeap`` | + +--------------------------------+---------------------------------------+ + | all_leap, 366_day | ``cftime.DatetimeAllLeap`` | + +--------------------------------+---------------------------------------+ + | 360_day | ``cftime.Datetime360Day`` | + +--------------------------------+---------------------------------------+ + | julian | ``cftime.DatetimeJulian`` | + +--------------------------------+---------------------------------------+ + + Examples + -------- + + This function returns a ``CFTimeIndex``, populated with 
``cftime.datetime`` + objects associated with the specified calendar type, e.g. + + >>> xr.cftime_range(start='2000', periods=6, freq='2MS', calendar='noleap') + CFTimeIndex([2000-01-01 00:00:00, 2000-03-01 00:00:00, 2000-05-01 00:00:00, + 2000-07-01 00:00:00, 2000-09-01 00:00:00, 2000-11-01 00:00:00], + dtype='object') + + As in the standard pandas function, three of the ``start``, ``end``, + ``periods``, or ``freq`` arguments must be specified at a given time, with + the other set to ``None``. See the `pandas documentation + `_ + for more examples of the behavior of ``date_range`` with each of the + parameters. + + See Also + -------- + pandas.date_range + """ # noqa: E501 + # Adapted from pandas.core.indexes.datetimes._generate_range. + if _count_not_none(start, end, periods, freq) != 3: + raise ValueError( + "Of the arguments 'start', 'end', 'periods', and 'freq', three " + "must be specified at a time.") + + if start is not None: + start = to_cftime_datetime(start, calendar) + start = _maybe_normalize_date(start, normalize) + if end is not None: + end = to_cftime_datetime(end, calendar) + end = _maybe_normalize_date(end, normalize) + + if freq is None: + dates = _generate_linear_range(start, end, periods) + else: + offset = to_offset(freq) + dates = np.array(list(_generate_range(start, end, periods, offset))) + + left_closed = False + right_closed = False + + if closed is None: + left_closed = True + right_closed = True + elif closed == 'left': + left_closed = True + elif closed == 'right': + right_closed = True + else: + raise ValueError("Closed must be either 'left', 'right' or None") + + if (not left_closed and len(dates) and + start is not None and dates[0] == start): + dates = dates[1:] + if (not right_closed and len(dates) and + end is not None and dates[-1] == end): + dates = dates[:-1] + + return CFTimeIndex(dates, name=name) diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index eb8cae2f398..ea2bcbc5858 100644 --- 
a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -1,3 +1,44 @@ +"""DatetimeIndex analog for cftime.datetime objects""" +# The pandas.Index subclass defined here was copied and adapted for +# use with cftime.datetime objects based on the source code defining +# pandas.DatetimeIndex. + +# For reference, here is a copy of the pandas copyright notice: + +# (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team +# All rights reserved. + +# Copyright (c) 2008-2011 AQR Capital Management, LLC +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: + +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. + +# * Neither the name of the copyright holder nor the names of any +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + from __future__ import absolute_import import re from datetime import timedelta @@ -116,28 +157,43 @@ def f(self): def get_date_type(self): - return type(self._data[0]) + if self.data: + return type(self._data[0]) + else: + return None def assert_all_valid_date_type(data): import cftime - sample = data[0] - date_type = type(sample) - if not isinstance(sample, cftime.datetime): - raise TypeError( - 'CFTimeIndex requires cftime.datetime ' - 'objects. Got object of {}.'.format(date_type)) - if not all(isinstance(value, date_type) for value in data): - raise TypeError( - 'CFTimeIndex requires using datetime ' - 'objects of all the same type. Got\n{}.'.format(data)) + if data.size: + sample = data[0] + date_type = type(sample) + if not isinstance(sample, cftime.datetime): + raise TypeError( + 'CFTimeIndex requires cftime.datetime ' + 'objects. Got object of {}.'.format(date_type)) + if not all(isinstance(value, date_type) for value in data): + raise TypeError( + 'CFTimeIndex requires using datetime ' + 'objects of all the same type. Got\n{}.'.format(data)) class CFTimeIndex(pd.Index): """Custom Index for working with CF calendars and dates All elements of a CFTimeIndex must be cftime.datetime objects. 
+ + Parameters + ---------- + data : array or CFTimeIndex + Sequence of cftime.datetime objects to use in index + name : str, default None + Name of the resulting index + + See Also + -------- + cftime_range """ year = _field_accessor('year', 'The year of the datetime') month = _field_accessor('month', 'The month of the datetime') @@ -149,10 +205,14 @@ class CFTimeIndex(pd.Index): 'The microseconds of the datetime') date_type = property(get_date_type) - def __new__(cls, data): + def __new__(cls, data, name=None): + if name is None and hasattr(data, 'name'): + name = data.name + result = object.__new__(cls) - assert_all_valid_date_type(data) - result._data = np.array(data) + result._data = np.array(data, dtype='O') + assert_all_valid_date_type(result._data) + result.name = name return result def _partial_date_slice(self, resolution, parsed): diff --git a/xarray/tests/test_cftime_offsets.py b/xarray/tests/test_cftime_offsets.py new file mode 100644 index 00000000000..6d7990689ed --- /dev/null +++ b/xarray/tests/test_cftime_offsets.py @@ -0,0 +1,801 @@ +import pytest + +from itertools import product + +import numpy as np + +from xarray.coding.cftime_offsets import ( + BaseCFTimeOffset, YearBegin, YearEnd, MonthBegin, MonthEnd, + Day, Hour, Minute, Second, _days_in_month, + to_offset, get_date_type, _MONTH_ABBREVIATIONS, to_cftime_datetime, + cftime_range) +from xarray import CFTimeIndex + +cftime = pytest.importorskip('cftime') + + +_CFTIME_CALENDARS = ['365_day', '360_day', 'julian', 'all_leap', + '366_day', 'gregorian', 'proleptic_gregorian', 'standard'] + + +def _id_func(param): + """Called on each parameter passed to pytest.mark.parametrize""" + return str(param) + + +@pytest.fixture(params=_CFTIME_CALENDARS) +def calendar(request): + return request.param + + +@pytest.mark.parametrize( + ('offset', 'expected_n'), + [(BaseCFTimeOffset(), 1), + (YearBegin(), 1), + (YearEnd(), 1), + (BaseCFTimeOffset(n=2), 2), + (YearBegin(n=2), 2), + (YearEnd(n=2), 2)], + 
ids=_id_func +) +def test_cftime_offset_constructor_valid_n(offset, expected_n): + assert offset.n == expected_n + + +@pytest.mark.parametrize( + ('offset', 'invalid_n'), + [(BaseCFTimeOffset, 1.5), + (YearBegin, 1.5), + (YearEnd, 1.5)], + ids=_id_func +) +def test_cftime_offset_constructor_invalid_n(offset, invalid_n): + with pytest.raises(TypeError): + offset(n=invalid_n) + + +@pytest.mark.parametrize( + ('offset', 'expected_month'), + [(YearBegin(), 1), + (YearEnd(), 12), + (YearBegin(month=5), 5), + (YearEnd(month=5), 5)], + ids=_id_func +) +def test_year_offset_constructor_valid_month(offset, expected_month): + assert offset.month == expected_month + + +@pytest.mark.parametrize( + ('offset', 'invalid_month', 'exception'), + [(YearBegin, 0, ValueError), + (YearEnd, 0, ValueError), + (YearBegin, 13, ValueError,), + (YearEnd, 13, ValueError), + (YearBegin, 1.5, TypeError), + (YearEnd, 1.5, TypeError)], + ids=_id_func +) +def test_year_offset_constructor_invalid_month( + offset, invalid_month, exception): + with pytest.raises(exception): + offset(month=invalid_month) + + +@pytest.mark.parametrize( + ('offset', 'expected'), + [(BaseCFTimeOffset(), None), + (MonthBegin(), 'MS'), + (YearBegin(), 'AS-JAN')], + ids=_id_func +) +def test_rule_code(offset, expected): + assert offset.rule_code() == expected + + +@pytest.mark.parametrize( + ('offset', 'expected'), + [(BaseCFTimeOffset(), ''), + (YearBegin(), '')], + ids=_id_func +) +def test_str_and_repr(offset, expected): + assert str(offset) == expected + assert repr(offset) == expected + + +@pytest.mark.parametrize( + 'offset', + [BaseCFTimeOffset(), MonthBegin(), YearBegin()], + ids=_id_func +) +def test_to_offset_offset_input(offset): + assert to_offset(offset) == offset + + +@pytest.mark.parametrize( + ('freq', 'expected'), + [('M', MonthEnd()), + ('2M', MonthEnd(n=2)), + ('MS', MonthBegin()), + ('2MS', MonthBegin(n=2)), + ('D', Day()), + ('2D', Day(n=2)), + ('H', Hour()), + ('2H', Hour(n=2)), + ('T', Minute()), + 
('2T', Minute(n=2)), + ('min', Minute()), + ('2min', Minute(n=2)), + ('S', Second()), + ('2S', Second(n=2))], + ids=_id_func +) +def test_to_offset_sub_annual(freq, expected): + assert to_offset(freq) == expected + + +_ANNUAL_OFFSET_TYPES = { + 'A': YearEnd, + 'AS': YearBegin +} + + +@pytest.mark.parametrize(('month_int', 'month_label'), + list(_MONTH_ABBREVIATIONS.items()) + [('', '')]) +@pytest.mark.parametrize('multiple', [None, 2]) +@pytest.mark.parametrize('offset_str', ['AS', 'A']) +def test_to_offset_annual(month_label, month_int, multiple, offset_str): + freq = offset_str + offset_type = _ANNUAL_OFFSET_TYPES[offset_str] + if month_label: + freq = '-'.join([freq, month_label]) + if multiple: + freq = '{}'.format(multiple) + freq + result = to_offset(freq) + + if multiple and month_int: + expected = offset_type(n=multiple, month=month_int) + elif multiple: + expected = offset_type(n=multiple) + elif month_int: + expected = offset_type(month=month_int) + else: + expected = offset_type() + assert result == expected + + +@pytest.mark.parametrize('freq', ['Z', '7min2', 'AM', 'M-', 'AS-', '1H1min']) +def test_invalid_to_offset_str(freq): + with pytest.raises(ValueError): + to_offset(freq) + + +@pytest.mark.parametrize( + ('argument', 'expected_date_args'), + [('2000-01-01', (2000, 1, 1)), + ((2000, 1, 1), (2000, 1, 1))], + ids=_id_func +) +def test_to_cftime_datetime(calendar, argument, expected_date_args): + date_type = get_date_type(calendar) + expected = date_type(*expected_date_args) + if isinstance(argument, tuple): + argument = date_type(*argument) + result = to_cftime_datetime(argument, calendar=calendar) + assert result == expected + + +def test_to_cftime_datetime_error_no_calendar(): + with pytest.raises(ValueError): + to_cftime_datetime('2000') + + +def test_to_cftime_datetime_error_type_error(): + with pytest.raises(TypeError): + to_cftime_datetime(1) + + +_EQ_TESTS_A = [ + BaseCFTimeOffset(), YearBegin(), YearEnd(), YearBegin(month=2), + 
YearEnd(month=2), MonthBegin(), MonthEnd(), Day(), Hour(), Minute(), + Second() +] +_EQ_TESTS_B = [ + BaseCFTimeOffset(n=2), YearBegin(n=2), YearEnd(n=2), + YearBegin(n=2, month=2), YearEnd(n=2, month=2), MonthBegin(n=2), + MonthEnd(n=2), Day(n=2), Hour(n=2), Minute(n=2), Second(n=2) +] + + +@pytest.mark.parametrize( + ('a', 'b'), product(_EQ_TESTS_A, _EQ_TESTS_B), ids=_id_func +) +def test_neq(a, b): + assert a != b + + +_EQ_TESTS_B_COPY = [ + BaseCFTimeOffset(n=2), YearBegin(n=2), YearEnd(n=2), + YearBegin(n=2, month=2), YearEnd(n=2, month=2), MonthBegin(n=2), + MonthEnd(n=2), Day(n=2), Hour(n=2), Minute(n=2), Second(n=2) +] + + +@pytest.mark.parametrize( + ('a', 'b'), zip(_EQ_TESTS_B, _EQ_TESTS_B_COPY), ids=_id_func +) +def test_eq(a, b): + assert a == b + + +_MUL_TESTS = [ + (BaseCFTimeOffset(), BaseCFTimeOffset(n=3)), + (YearEnd(), YearEnd(n=3)), + (YearBegin(), YearBegin(n=3)), + (MonthEnd(), MonthEnd(n=3)), + (MonthBegin(), MonthBegin(n=3)), + (Day(), Day(n=3)), + (Hour(), Hour(n=3)), + (Minute(), Minute(n=3)), + (Second(), Second(n=3)) +] + + +@pytest.mark.parametrize(('offset', 'expected'), _MUL_TESTS, ids=_id_func) +def test_mul(offset, expected): + assert offset * 3 == expected + + +@pytest.mark.parametrize(('offset', 'expected'), _MUL_TESTS, ids=_id_func) +def test_rmul(offset, expected): + assert 3 * offset == expected + + +@pytest.mark.parametrize( + ('offset', 'expected'), + [(BaseCFTimeOffset(), BaseCFTimeOffset(n=-1)), + (YearEnd(), YearEnd(n=-1)), + (YearBegin(), YearBegin(n=-1)), + (MonthEnd(), MonthEnd(n=-1)), + (MonthBegin(), MonthBegin(n=-1)), + (Day(), Day(n=-1)), + (Hour(), Hour(n=-1)), + (Minute(), Minute(n=-1)), + (Second(), Second(n=-1))], + ids=_id_func) +def test_neg(offset, expected): + assert -offset == expected + + +_ADD_TESTS = [ + (Day(n=2), (1, 1, 3)), + (Hour(n=2), (1, 1, 1, 2)), + (Minute(n=2), (1, 1, 1, 0, 2)), + (Second(n=2), (1, 1, 1, 0, 0, 2)) +] + + +@pytest.mark.parametrize( + ('offset', 'expected_date_args'), + 
_ADD_TESTS, + ids=_id_func +) +def test_add_sub_monthly(offset, expected_date_args, calendar): + date_type = get_date_type(calendar) + initial = date_type(1, 1, 1) + expected = date_type(*expected_date_args) + result = offset + initial + assert result == expected + + +@pytest.mark.parametrize( + ('offset', 'expected_date_args'), + _ADD_TESTS, + ids=_id_func +) +def test_radd_sub_monthly(offset, expected_date_args, calendar): + date_type = get_date_type(calendar) + initial = date_type(1, 1, 1) + expected = date_type(*expected_date_args) + result = initial + offset + assert result == expected + + +@pytest.mark.parametrize( + ('offset', 'expected_date_args'), + [(Day(n=2), (1, 1, 1)), + (Hour(n=2), (1, 1, 2, 22)), + (Minute(n=2), (1, 1, 2, 23, 58)), + (Second(n=2), (1, 1, 2, 23, 59, 58))], + ids=_id_func +) +def test_rsub_sub_monthly(offset, expected_date_args, calendar): + date_type = get_date_type(calendar) + initial = date_type(1, 1, 3) + expected = date_type(*expected_date_args) + result = initial - offset + assert result == expected + + +@pytest.mark.parametrize('offset', _EQ_TESTS_A, ids=_id_func) +def test_sub_error(offset, calendar): + date_type = get_date_type(calendar) + initial = date_type(1, 1, 1) + with pytest.raises(TypeError): + offset - initial + + +@pytest.mark.parametrize( + ('a', 'b'), + zip(_EQ_TESTS_A, _EQ_TESTS_B), + ids=_id_func +) +def test_minus_offset(a, b): + result = b - a + expected = a + assert result == expected + + +@pytest.mark.parametrize( + ('a', 'b'), + list(zip(np.roll(_EQ_TESTS_A, 1), _EQ_TESTS_B)) + + [(YearEnd(month=1), YearEnd(month=2))], + ids=_id_func +) +def test_minus_offset_error(a, b): + with pytest.raises(TypeError): + b - a + + +def test_days_in_month_non_december(calendar): + date_type = get_date_type(calendar) + reference = date_type(1, 4, 1) + assert _days_in_month(reference) == 30 + + +def test_days_in_month_december(calendar): + if calendar == '360_day': + expected = 30 + else: + expected = 31 + date_type = 
get_date_type(calendar) + reference = date_type(1, 12, 5) + assert _days_in_month(reference) == expected + + +@pytest.mark.parametrize( + ('initial_date_args', 'offset', 'expected_date_args'), + [((1, 1, 1), MonthBegin(), (1, 2, 1)), + ((1, 1, 1), MonthBegin(n=2), (1, 3, 1)), + ((1, 1, 7), MonthBegin(), (1, 2, 1)), + ((1, 1, 7), MonthBegin(n=2), (1, 3, 1)), + ((1, 3, 1), MonthBegin(n=-1), (1, 2, 1)), + ((1, 3, 1), MonthBegin(n=-2), (1, 1, 1)), + ((1, 3, 3), MonthBegin(n=-1), (1, 3, 1)), + ((1, 3, 3), MonthBegin(n=-2), (1, 2, 1)), + ((1, 2, 1), MonthBegin(n=14), (2, 4, 1)), + ((2, 4, 1), MonthBegin(n=-14), (1, 2, 1)), + ((1, 1, 1, 5, 5, 5, 5), MonthBegin(), (1, 2, 1, 5, 5, 5, 5)), + ((1, 1, 3, 5, 5, 5, 5), MonthBegin(), (1, 2, 1, 5, 5, 5, 5)), + ((1, 1, 3, 5, 5, 5, 5), MonthBegin(n=-1), (1, 1, 1, 5, 5, 5, 5))], + ids=_id_func +) +def test_add_month_begin( + calendar, initial_date_args, offset, expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ('initial_date_args', 'offset', 'expected_year_month', + 'expected_sub_day'), + [((1, 1, 1), MonthEnd(), (1, 1), ()), + ((1, 1, 1), MonthEnd(n=2), (1, 2), ()), + ((1, 3, 1), MonthEnd(n=-1), (1, 2), ()), + ((1, 3, 1), MonthEnd(n=-2), (1, 1), ()), + ((1, 2, 1), MonthEnd(n=14), (2, 3), ()), + ((2, 4, 1), MonthEnd(n=-14), (1, 2), ()), + ((1, 1, 1, 5, 5, 5, 5), MonthEnd(), (1, 1), (5, 5, 5, 5)), + ((1, 2, 1, 5, 5, 5, 5), MonthEnd(n=-1), (1, 1), (5, 5, 5, 5))], + ids=_id_func +) +def test_add_month_end( + calendar, initial_date_args, offset, expected_year_month, + expected_sub_day +): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used 
+ expected_date_args = (expected_year_month + + (_days_in_month(reference),) + expected_sub_day) + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ('initial_year_month', 'initial_sub_day', 'offset', 'expected_year_month', + 'expected_sub_day'), + [((1, 1), (), MonthEnd(), (1, 2), ()), + ((1, 1), (), MonthEnd(n=2), (1, 3), ()), + ((1, 3), (), MonthEnd(n=-1), (1, 2), ()), + ((1, 3), (), MonthEnd(n=-2), (1, 1), ()), + ((1, 2), (), MonthEnd(n=14), (2, 4), ()), + ((2, 4), (), MonthEnd(n=-14), (1, 2), ()), + ((1, 1), (5, 5, 5, 5), MonthEnd(), (1, 2), (5, 5, 5, 5)), + ((1, 2), (5, 5, 5, 5), MonthEnd(n=-1), (1, 1), (5, 5, 5, 5))], + ids=_id_func +) +def test_add_month_end_onOffset( + calendar, initial_year_month, initial_sub_day, offset, expected_year_month, + expected_sub_day +): + date_type = get_date_type(calendar) + reference_args = initial_year_month + (1,) + reference = date_type(*reference_args) + initial_date_args = (initial_year_month + (_days_in_month(reference),) + + initial_sub_day) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used + expected_date_args = (expected_year_month + + (_days_in_month(reference),) + expected_sub_day) + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ('initial_date_args', 'offset', 'expected_date_args'), + [((1, 1, 1), YearBegin(), (2, 1, 1)), + ((1, 1, 1), YearBegin(n=2), (3, 1, 1)), + ((1, 1, 1), YearBegin(month=2), (1, 2, 1)), + ((1, 1, 7), YearBegin(n=2), (3, 1, 1)), + ((2, 2, 1), YearBegin(n=-1), (2, 1, 1)), + ((1, 1, 2), YearBegin(n=-1), (1, 1, 1)), + ((1, 1, 1, 5, 5, 5, 5), YearBegin(), (2, 1, 1, 5, 5, 5, 5)), + ((2, 1, 1, 5, 5, 5, 5), YearBegin(n=-1), (1, 1, 1, 5, 5, 5, 5))], + ids=_id_func +) +def test_add_year_begin(calendar, 
initial_date_args, offset, + expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ('initial_date_args', 'offset', 'expected_year_month', + 'expected_sub_day'), + [((1, 1, 1), YearEnd(), (1, 12), ()), + ((1, 1, 1), YearEnd(n=2), (2, 12), ()), + ((1, 1, 1), YearEnd(month=1), (1, 1), ()), + ((2, 3, 1), YearEnd(n=-1), (1, 12), ()), + ((1, 3, 1), YearEnd(n=-1, month=2), (1, 2), ()), + ((1, 1, 1, 5, 5, 5, 5), YearEnd(), (1, 12), (5, 5, 5, 5)), + ((1, 1, 1, 5, 5, 5, 5), YearEnd(n=2), (2, 12), (5, 5, 5, 5))], + ids=_id_func +) +def test_add_year_end( + calendar, initial_date_args, offset, expected_year_month, + expected_sub_day +): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used + expected_date_args = (expected_year_month + + (_days_in_month(reference),) + expected_sub_day) + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ('initial_year_month', 'initial_sub_day', 'offset', 'expected_year_month', + 'expected_sub_day'), + [((1, 12), (), YearEnd(), (2, 12), ()), + ((1, 12), (), YearEnd(n=2), (3, 12), ()), + ((2, 12), (), YearEnd(n=-1), (1, 12), ()), + ((3, 12), (), YearEnd(n=-2), (1, 12), ()), + ((1, 1), (), YearEnd(month=2), (1, 2), ()), + ((1, 12), (5, 5, 5, 5), YearEnd(), (2, 12), (5, 5, 5, 5)), + ((2, 12), (5, 5, 5, 5), YearEnd(n=-1), (1, 12), (5, 5, 5, 5))], + ids=_id_func +) +def test_add_year_end_onOffset( + calendar, initial_year_month, initial_sub_day, offset, expected_year_month, + expected_sub_day +): + date_type = get_date_type(calendar) + reference_args = initial_year_month + (1,) + reference = 
date_type(*reference_args) + initial_date_args = (initial_year_month + (_days_in_month(reference),) + + initial_sub_day) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used + expected_date_args = (expected_year_month + + (_days_in_month(reference),) + expected_sub_day) + expected = date_type(*expected_date_args) + assert result == expected + + +# Note for all sub-monthly offsets, pandas always returns True for onOffset +@pytest.mark.parametrize( + ('date_args', 'offset', 'expected'), + [((1, 1, 1), MonthBegin(), True), + ((1, 1, 1, 1), MonthBegin(), True), + ((1, 1, 5), MonthBegin(), False), + ((1, 1, 5), MonthEnd(), False), + ((1, 1, 1), YearBegin(), True), + ((1, 1, 1, 1), YearBegin(), True), + ((1, 1, 5), YearBegin(), False), + ((1, 12, 1), YearEnd(), False), + ((1, 1, 1), Day(), True), + ((1, 1, 1, 1), Day(), True), + ((1, 1, 1), Hour(), True), + ((1, 1, 1), Minute(), True), + ((1, 1, 1), Second(), True)], + ids=_id_func +) +def test_onOffset(calendar, date_args, offset, expected): + date_type = get_date_type(calendar) + date = date_type(*date_args) + result = offset.onOffset(date) + assert result == expected + + +@pytest.mark.parametrize( + ('year_month_args', 'sub_day_args', 'offset'), + [((1, 1), (), MonthEnd()), + ((1, 1), (1,), MonthEnd()), + ((1, 12), (), YearEnd()), + ((1, 1), (), YearEnd(month=1))], + ids=_id_func +) +def test_onOffset_month_or_year_end( + calendar, year_month_args, sub_day_args, offset): + date_type = get_date_type(calendar) + reference_args = year_month_args + (1,) + reference = date_type(*reference_args) + date_args = year_month_args + (_days_in_month(reference),) + sub_day_args + date = date_type(*date_args) + result = offset.onOffset(date) + assert result + + +@pytest.mark.parametrize( + ('offset', 'initial_date_args', 'partial_expected_date_args'), + 
[(YearBegin(), (1, 3, 1), (2, 1)), + (YearBegin(), (1, 1, 1), (1, 1)), + (YearBegin(n=2), (1, 3, 1), (2, 1)), + (YearBegin(n=2, month=2), (1, 3, 1), (2, 2)), + (YearEnd(), (1, 3, 1), (1, 12)), + (YearEnd(n=2), (1, 3, 1), (1, 12)), + (YearEnd(n=2, month=2), (1, 3, 1), (2, 2)), + (YearEnd(n=2, month=4), (1, 4, 30), (1, 4)), + (MonthBegin(), (1, 3, 2), (1, 4)), + (MonthBegin(), (1, 3, 1), (1, 3)), + (MonthBegin(n=2), (1, 3, 2), (1, 4)), + (MonthEnd(), (1, 3, 2), (1, 3)), + (MonthEnd(), (1, 4, 30), (1, 4)), + (MonthEnd(n=2), (1, 3, 2), (1, 3)), + (Day(), (1, 3, 2, 1), (1, 3, 2, 1)), + (Hour(), (1, 3, 2, 1, 1), (1, 3, 2, 1, 1)), + (Minute(), (1, 3, 2, 1, 1, 1), (1, 3, 2, 1, 1, 1)), + (Second(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1))], + ids=_id_func +) +def test_rollforward(calendar, offset, initial_date_args, + partial_expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + if isinstance(offset, (MonthBegin, YearBegin)): + expected_date_args = partial_expected_date_args + (1,) + elif isinstance(offset, (MonthEnd, YearEnd)): + reference_args = partial_expected_date_args + (1,) + reference = date_type(*reference_args) + expected_date_args = (partial_expected_date_args + + (_days_in_month(reference),)) + else: + expected_date_args = partial_expected_date_args + expected = date_type(*expected_date_args) + result = offset.rollforward(initial) + assert result == expected + + +@pytest.mark.parametrize( + ('offset', 'initial_date_args', 'partial_expected_date_args'), + [(YearBegin(), (1, 3, 1), (1, 1)), + (YearBegin(n=2), (1, 3, 1), (1, 1)), + (YearBegin(n=2, month=2), (1, 3, 1), (1, 2)), + (YearBegin(), (1, 1, 1), (1, 1)), + (YearBegin(n=2, month=2), (1, 2, 1), (1, 2)), + (YearEnd(), (2, 3, 1), (1, 12)), + (YearEnd(n=2), (2, 3, 1), (1, 12)), + (YearEnd(n=2, month=2), (2, 3, 1), (2, 2)), + (YearEnd(month=4), (1, 4, 30), (1, 4)), + (MonthBegin(), (1, 3, 2), (1, 3)), + (MonthBegin(n=2), (1, 3, 2), (1, 3)), + (MonthBegin(), 
(1, 3, 1), (1, 3)), + (MonthEnd(), (1, 3, 2), (1, 2)), + (MonthEnd(n=2), (1, 3, 2), (1, 2)), + (MonthEnd(), (1, 4, 30), (1, 4)), + (Day(), (1, 3, 2, 1), (1, 3, 2, 1)), + (Hour(), (1, 3, 2, 1, 1), (1, 3, 2, 1, 1)), + (Minute(), (1, 3, 2, 1, 1, 1), (1, 3, 2, 1, 1, 1)), + (Second(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1))], + ids=_id_func +) +def test_rollback(calendar, offset, initial_date_args, + partial_expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + if isinstance(offset, (MonthBegin, YearBegin)): + expected_date_args = partial_expected_date_args + (1,) + elif isinstance(offset, (MonthEnd, YearEnd)): + reference_args = partial_expected_date_args + (1,) + reference = date_type(*reference_args) + expected_date_args = (partial_expected_date_args + + (_days_in_month(reference),)) + else: + expected_date_args = partial_expected_date_args + expected = date_type(*expected_date_args) + result = offset.rollback(initial) + assert result == expected + + +_CFTIME_RANGE_TESTS = [ + ('0001-01-01', '0001-01-04', None, 'D', None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ('0001-01-01', '0001-01-04', None, 'D', 'left', False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3)]), + ('0001-01-01', '0001-01-04', None, 'D', 'right', False, + [(1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ('0001-01-01T01:00:00', '0001-01-04', None, 'D', None, False, + [(1, 1, 1, 1), (1, 1, 2, 1), (1, 1, 3, 1)]), + ('0001-01-01T01:00:00', '0001-01-04', None, 'D', None, True, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ('0001-01-01', None, 4, 'D', None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + (None, '0001-01-04', 4, 'D', None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ((1, 1, 1), '0001-01-04', None, 'D', None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ((1, 1, 1), (1, 1, 4), None, 'D', None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ('0001-01-30', '0011-02-01', None, '3AS-JUN', None, 
False, + [(1, 6, 1), (4, 6, 1), (7, 6, 1), (10, 6, 1)]), + ('0001-01-04', '0001-01-01', None, 'D', None, False, + []), + ('0010', None, 4, YearBegin(n=-2), None, False, + [(10, 1, 1), (8, 1, 1), (6, 1, 1), (4, 1, 1)]), + ('0001-01-01', '0001-01-04', 4, None, None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]) +] + + +@pytest.mark.parametrize( + ('start', 'end', 'periods', 'freq', 'closed', 'normalize', + 'expected_date_args'), + _CFTIME_RANGE_TESTS, ids=_id_func +) +def test_cftime_range( + start, end, periods, freq, closed, normalize, calendar, + expected_date_args): + date_type = get_date_type(calendar) + expected_dates = [date_type(*args) for args in expected_date_args] + + if isinstance(start, tuple): + start = date_type(*start) + if isinstance(end, tuple): + end = date_type(*end) + + result = cftime_range( + start=start, end=end, periods=periods, freq=freq, closed=closed, + normalize=normalize, calendar=calendar) + resulting_dates = result.values + + assert isinstance(result, CFTimeIndex) + + if freq is not None: + np.testing.assert_equal(resulting_dates, expected_dates) + else: + # If we create a linear range of dates using cftime.num2date + # we will not get exact round number dates. This is because + # datetime arithmetic in cftime is accurate approximately to + # 1 millisecond (see https://unidata.github.io/cftime/api.html). 
+ deltas = resulting_dates - expected_dates + deltas = np.array([delta.total_seconds() for delta in deltas]) + assert np.max(np.abs(deltas)) < 0.001 + + +def test_cftime_range_name(): + result = cftime_range(start='2000', periods=4, name='foo') + assert result.name == 'foo' + + result = cftime_range(start='2000', periods=4) + assert result.name is None + + +@pytest.mark.parametrize( + ('start', 'end', 'periods', 'freq', 'closed'), + [(None, None, 5, 'A', None), + ('2000', None, None, 'A', None), + (None, '2000', None, 'A', None), + ('2000', '2001', None, None, None), + (None, None, None, None, None), + ('2000', '2001', None, 'A', 'up'), + ('2000', '2001', 5, 'A', None)] +) +def test_invalid_cftime_range_inputs(start, end, periods, freq, closed): + with pytest.raises(ValueError): + cftime_range(start, end, periods, freq, closed=closed) + + +_CALENDAR_SPECIFIC_MONTH_END_TESTS = [ + ('2M', 'noleap', + [(2, 28), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ('2M', 'all_leap', + [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ('2M', '360_day', + [(2, 30), (4, 30), (6, 30), (8, 30), (10, 30), (12, 30)]), + ('2M', 'standard', + [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ('2M', 'gregorian', + [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ('2M', 'julian', + [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]) +] + + +@pytest.mark.parametrize( + ('freq', 'calendar', 'expected_month_day'), + _CALENDAR_SPECIFIC_MONTH_END_TESTS, ids=_id_func +) +def test_calendar_specific_month_end(freq, calendar, expected_month_day): + year = 2000 # Use a leap-year to highlight calendar differences + result = cftime_range( + start='2000-02', end='2001', freq=freq, calendar=calendar).values + date_type = get_date_type(calendar) + expected = [date_type(year, *args) for args in expected_month_day] + np.testing.assert_equal(result, expected) + + +@pytest.mark.parametrize( + ('calendar', 'start', 'end', 'expected_number_of_days'), + 
[('noleap', '2000', '2001', 365), + ('all_leap', '2000', '2001', 366), + ('360_day', '2000', '2001', 360), + ('standard', '2000', '2001', 366), + ('gregorian', '2000', '2001', 366), + ('julian', '2000', '2001', 366), + ('noleap', '2001', '2002', 365), + ('all_leap', '2001', '2002', 366), + ('360_day', '2001', '2002', 360), + ('standard', '2001', '2002', 365), + ('gregorian', '2001', '2002', 365), + ('julian', '2001', '2002', 365)] +) +def test_calendar_year_length( + calendar, start, end, expected_number_of_days): + result = cftime_range(start, end, freq='D', closed='left', + calendar=calendar) + assert len(result) == expected_number_of_days diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index 6f102b60b9d..f72c6904f0e 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -2,6 +2,7 @@ import pytest +import numpy as np import pandas as pd import xarray as xr @@ -121,22 +122,42 @@ def dec_days(date_type): return 31 +@pytest.fixture +def index_with_name(date_type): + dates = [date_type(1, 1, 1), date_type(1, 2, 1), + date_type(2, 1, 1), date_type(2, 2, 1)] + return CFTimeIndex(dates, name='foo') + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize( + ('name', 'expected_name'), + [('bar', 'bar'), + (None, 'foo')]) +def test_constructor_with_name(index_with_name, name, expected_name): + result = CFTimeIndex(index_with_name, name=name).name + assert result == expected_name + + @pytest.mark.skipif(not has_cftime, reason='cftime not installed') def test_assert_all_valid_date_type(date_type, index): import cftime if date_type is cftime.DatetimeNoLeap: - mixed_date_types = [date_type(1, 1, 1), - cftime.DatetimeAllLeap(1, 2, 1)] + mixed_date_types = np.array( + [date_type(1, 1, 1), + cftime.DatetimeAllLeap(1, 2, 1)]) else: - mixed_date_types = [date_type(1, 1, 1), - cftime.DatetimeNoLeap(1, 2, 1)] + mixed_date_types = np.array( + [date_type(1, 1, 1), + 
cftime.DatetimeNoLeap(1, 2, 1)]) with pytest.raises(TypeError): assert_all_valid_date_type(mixed_date_types) with pytest.raises(TypeError): - assert_all_valid_date_type([1, date_type(1, 1, 1)]) + assert_all_valid_date_type(np.array([1, date_type(1, 1, 1)])) - assert_all_valid_date_type([date_type(1, 1, 1), date_type(1, 2, 1)]) + assert_all_valid_date_type( + np.array([date_type(1, 1, 1), date_type(1, 2, 1)])) @pytest.mark.skipif(not has_cftime, reason='cftime not installed') @@ -589,3 +610,9 @@ def test_concat_cftimeindex(date_type, enable_cftimeindex): else: assert isinstance(da.indexes['time'], pd.Index) assert not isinstance(da.indexes['time'], CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_empty_cftimeindex(): + index = CFTimeIndex([]) + assert index.date_type is None From c1c576f75a2c4c2f8fad314c10588f45c5a5c573 Mon Sep 17 00:00:00 2001 From: Fabien Maussion Date: Fri, 21 Sep 2018 19:36:20 +0200 Subject: [PATCH 24/51] Plotting: restore xyincrease kwarg default to True (#2425) --- xarray/plot/plot.py | 8 ++++++-- xarray/tests/test_plot.py | 22 ++++++++++++++++++++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 10fca44b417..b92429b857d 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -624,7 +624,7 @@ def _plot2d(plotfunc): @functools.wraps(plotfunc) def newplotfunc(darray, x=None, y=None, figsize=None, size=None, aspect=None, ax=None, row=None, col=None, - col_wrap=None, xincrease=None, yincrease=None, + col_wrap=None, xincrease=True, yincrease=True, add_colorbar=None, add_labels=True, vmin=None, vmax=None, cmap=None, center=None, robust=False, extend=None, levels=None, infer_intervals=None, colors=None, @@ -776,6 +776,10 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, raise ValueError("cbar_ax and cbar_kwargs can't be used with " "add_colorbar=False.") + # origin kwarg overrides yincrease + if 'origin' in kwargs: + 
yincrease = None + _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) @@ -794,7 +798,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, @functools.wraps(newplotfunc) def plotmethod(_PlotMethods_obj, x=None, y=None, figsize=None, size=None, aspect=None, ax=None, row=None, col=None, col_wrap=None, - xincrease=None, yincrease=None, add_colorbar=None, + xincrease=True, yincrease=True, add_colorbar=None, add_labels=True, vmin=None, vmax=None, cmap=None, colors=None, center=None, robust=False, extend=None, levels=None, infer_intervals=None, subplot_kws=None, diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 15cb6af5fb1..e27f03630b7 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -762,6 +762,24 @@ def test_nonnumeric_index_raises_typeerror(self): def test_can_pass_in_axis(self): self.pass_in_axis(self.plotmethod) + def test_xyincrease_defaults(self): + + # With default settings the axis must be ordered regardless + # of the coords order. 
+ self.plotfunc(DataArray(easy_array((3, 2)), coords=[[1, 2, 3], + [1, 2]])) + bounds = plt.gca().get_ylim() + assert bounds[0] < bounds[1] + bounds = plt.gca().get_xlim() + assert bounds[0] < bounds[1] + # Inverted coords + self.plotfunc(DataArray(easy_array((3, 2)), coords=[[3, 2, 1], + [2, 1]])) + bounds = plt.gca().get_ylim() + assert bounds[0] < bounds[1] + bounds = plt.gca().get_xlim() + assert bounds[0] < bounds[1] + def test_xyincrease_false_changes_axes(self): self.plotmethod(xincrease=False, yincrease=False) xlim = plt.gca().get_xlim() @@ -1308,8 +1326,8 @@ def test_regression_rgb_imshow_dim_size_one(self): da = DataArray(easy_array((1, 3, 3), start=0.0, stop=1.0)) da.plot.imshow() - def test_imshow_origin_kwarg(self): - da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4)) + def test_origin_overrides_xyincrease(self): + da = DataArray(easy_array((3, 2)), coords=[[-2, 0, 2], [-1, 1]]) da.plot.imshow(origin='upper') assert plt.xlim()[0] < 0 assert plt.ylim()[1] < 0 From ab96954883200f764a0dd50870e4db240c119265 Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Sat, 22 Sep 2018 05:02:42 +0900 Subject: [PATCH 25/51] implement Gradient (#2398) * Added xr.gradient, DataArray.gradient, Dataset.gradient * Working with np.backend * test is not passing * Docs * flake8 * support environment without dask * Support numpy < 1.13 * Support numpy 1.12 * simplify dask.gradient * lint * Use npcompat.gradient in tests * move gradient to dask_array_compat * gradient -> differentiate * lint * Update dask_array_compat * Added a link from diff * remove xr.differentiate * Added datetime support * Update via comment. Use utils.to_numeric also in interp * time_unit -> datetime_unit * Some more info in docs. * update test * Update via comments * Update docs. 
--- doc/api.rst | 2 + doc/computation.rst | 23 ++++ doc/whats-new.rst | 4 + xarray/__init__.py | 2 +- xarray/core/dask_array_compat.py | 124 +++++++++++++++++++ xarray/core/dataarray.py | 58 +++++++++ xarray/core/dataset.py | 65 +++++++++- xarray/core/duck_array_ops.py | 8 ++ xarray/core/missing.py | 6 +- xarray/core/npcompat.py | 185 ++++++++++++++++++++++++++++ xarray/core/utils.py | 21 ++++ xarray/tests/test_dataset.py | 77 +++++++++++- xarray/tests/test_duck_array_ops.py | 21 +++- 13 files changed, 586 insertions(+), 10 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 89fee10506d..d204fab3539 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -150,6 +150,7 @@ Computation Dataset.resample Dataset.diff Dataset.quantile + Dataset.differentiate **Aggregation**: :py:attr:`~Dataset.all` @@ -317,6 +318,7 @@ Computation DataArray.diff DataArray.dot DataArray.quantile + DataArray.differentiate **Aggregation**: :py:attr:`~DataArray.all` diff --git a/doc/computation.rst b/doc/computation.rst index 6793e667e06..67cda6f2191 100644 --- a/doc/computation.rst +++ b/doc/computation.rst @@ -200,6 +200,29 @@ You can also use ``construct`` to compute a weighted rolling sum: To avoid this, use ``skipna=False`` as the above example. +Computation using Coordinates +============================= + +Xarray objects have some handy methods for the computation with their +coordinates. :py:meth:`~xarray.DataArray.differentiate` computes derivatives by +central finite differences using their coordinates, + +.. ipython:: python + a = xr.DataArray([0, 1, 2, 3], dims=['x'], coords=[0.1, 0.11, 0.2, 0.3]) + a + a.differentiate('x') + +This method can be used also for multidimensional arrays, + +.. ipython:: python + a = xr.DataArray(np.arange(8).reshape(4, 2), dims=['x', 'y'], + coords=[0.1, 0.11, 0.2, 0.3]) + a.differentiate('x') + +.. note:: + This method is limited to simple cartesian geometry. Differentiation along + multidimensional coordinate is not supported. + .. 
_compute.broadcasting: Broadcasting by dimension name diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 8c34ddf3fa9..7240059bd10 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -36,6 +36,10 @@ Documentation Enhancements ~~~~~~~~~~~~ +- :py:meth:`~xarray.DataArray.differentiate` and + :py:meth:`~xarray.Dataset.differentiate` are newly added. + (:issue:`1332`) + By `Keisuke Fujii `_. - Default colormap for sequential and divergent data can now be set via :py:func:`~xarray.set_options()` (:issue:`2394`) diff --git a/xarray/__init__.py b/xarray/__init__.py index e2d24e6c294..e3898f348cc 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -10,7 +10,7 @@ from .core.alignment import align, broadcast, broadcast_arrays from .core.common import full_like, zeros_like, ones_like from .core.combine import concat, auto_combine -from .core.computation import apply_ufunc, where, dot +from .core.computation import apply_ufunc, dot, where from .core.extensions import (register_dataarray_accessor, register_dataset_accessor) from .core.variable import as_variable, Variable, IndexVariable, Coordinate diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py index c2417345f55..5e6b81a253d 100644 --- a/xarray/core/dask_array_compat.py +++ b/xarray/core/dask_array_compat.py @@ -1,6 +1,9 @@ from __future__ import absolute_import, division, print_function +from distutils.version import LooseVersion + import numpy as np +from dask import __version__ as dask_version import dask.array as da try: @@ -30,3 +33,124 @@ def isin(element, test_elements, assume_unique=False, invert=False): if invert: result = ~result return result + + +if LooseVersion(dask_version) > LooseVersion('1.19.2'): + gradient = da.gradient + +else: # pragma: no cover + # Copied from dask v0.19.2 + # Used under the terms of Dask's license, see licenses/DASK_LICENSE. 
+ import math + from numbers import Integral, Real + + AxisError = np.AxisError + + def validate_axis(axis, ndim): + """ Validate an input to axis= keywords """ + if isinstance(axis, (tuple, list)): + return tuple(validate_axis(ax, ndim) for ax in axis) + if not isinstance(axis, Integral): + raise TypeError("Axis value must be an integer, got %s" % axis) + if axis < -ndim or axis >= ndim: + raise AxisError("Axis %d is out of bounds for array of dimension " + "%d" % (axis, ndim)) + if axis < 0: + axis += ndim + return axis + + def _gradient_kernel(x, block_id, coord, axis, array_locs, grad_kwargs): + """ + x: nd-array + array of one block + coord: 1d-array or scalar + coordinate along which the gradient is computed. + axis: int + axis along which the gradient is computed + array_locs: + actual location along axis. None if coordinate is scalar + grad_kwargs: + keyword to be passed to np.gradient + """ + block_loc = block_id[axis] + if array_locs is not None: + coord = coord[array_locs[0][block_loc]:array_locs[1][block_loc]] + grad = np.gradient(x, coord, axis=axis, **grad_kwargs) + return grad + + def gradient(f, *varargs, **kwargs): + f = da.asarray(f) + + kwargs["edge_order"] = math.ceil(kwargs.get("edge_order", 1)) + if kwargs["edge_order"] > 2: + raise ValueError("edge_order must be less than or equal to 2.") + + drop_result_list = False + axis = kwargs.pop("axis", None) + if axis is None: + axis = tuple(range(f.ndim)) + elif isinstance(axis, Integral): + drop_result_list = True + axis = (axis,) + + axis = validate_axis(axis, f.ndim) + + if len(axis) != len(set(axis)): + raise ValueError("duplicate axes not allowed") + + axis = tuple(ax % f.ndim for ax in axis) + + if varargs == (): + varargs = (1,) + if len(varargs) == 1: + varargs = len(axis) * varargs + if len(varargs) != len(axis): + raise TypeError( + "Spacing must either be a single scalar, or a scalar / " + "1d-array per axis" + ) + + if issubclass(f.dtype.type, (np.bool8, Integral)): + f = f.astype(float) 
+ elif issubclass(f.dtype.type, Real) and f.dtype.itemsize < 4: + f = f.astype(float) + + results = [] + for i, ax in enumerate(axis): + for c in f.chunks[ax]: + if np.min(c) < kwargs["edge_order"] + 1: + raise ValueError( + 'Chunk size must be larger than edge_order + 1. ' + 'Minimum chunk for aixs {} is {}. Rechunk to ' + 'proceed.'.format(np.min(c), ax)) + + if np.isscalar(varargs[i]): + array_locs = None + else: + if isinstance(varargs[i], da.Array): + raise NotImplementedError( + 'dask array coordinated is not supported.') + # coordinate position for each block taking overlap into + # account + chunk = np.array(f.chunks[ax]) + array_loc_stop = np.cumsum(chunk) + 1 + array_loc_start = array_loc_stop - chunk - 2 + array_loc_stop[-1] -= 1 + array_loc_start[0] = 0 + array_locs = (array_loc_start, array_loc_stop) + + results.append(f.map_overlap( + _gradient_kernel, + dtype=f.dtype, + depth={j: 1 if j == ax else 0 for j in range(f.ndim)}, + boundary="none", + coord=varargs[i], + axis=ax, + array_locs=array_locs, + grad_kwargs=kwargs, + )) + + if drop_result_list: + results = results[0] + + return results diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 937d38d30fa..f131b003a69 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2073,6 +2073,9 @@ def diff(self, dim, n=1, label='upper'): Coordinates: * x (x) int64 3 4 + See Also + -------- + DataArray.differentiate """ ds = self._to_temp_dataset().diff(n=n, dim=dim, label=label) return self._from_temp_dataset(ds) @@ -2352,6 +2355,61 @@ def rank(self, dim, pct=False, keep_attrs=False): ds = self._to_temp_dataset().rank(dim, pct=pct, keep_attrs=keep_attrs) return self._from_temp_dataset(ds) + def differentiate(self, coord, edge_order=1, datetime_unit=None): + """ Differentiate the array with the second order accurate central + differences. + + .. note:: + This feature is limited to simple cartesian geometry, i.e. coord + must be one dimensional. 
+ + Parameters + ---------- + coord: str + The coordinate to be used to compute the gradient. + edge_order: 1 or 2. Default 1 + N-th order accurate differences at the boundaries. + datetime_unit: None or any of {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', + 'us', 'ns', 'ps', 'fs', 'as'} + Unit to compute gradient. Only valid for datetime coordinate. + + Returns + ------- + differentiated: DataArray + + See also + -------- + numpy.gradient: corresponding numpy function + + Examples + -------- + + >>> da = xr.DataArray(np.arange(12).reshape(4, 3), dims=['x', 'y'], + ... coords={'x': [0, 0.1, 1.1, 1.2]}) + >>> da + + array([[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8], + [ 9, 10, 11]]) + Coordinates: + * x (x) float64 0.0 0.1 1.1 1.2 + Dimensions without coordinates: y + >>> + >>> da.differentiate('x') + + array([[30. , 30. , 30. ], + [27.545455, 27.545455, 27.545455], + [27.545455, 27.545455, 27.545455], + [30. , 30. , 30. ]]) + Coordinates: + * x (x) float64 0.0 0.1 1.1 1.2 + Dimensions without coordinates: y + """ + ds = self._to_temp_dataset().differentiate( + coord, edge_order, datetime_unit) + return self._from_temp_dataset(ds) + # priority most be higher than Variable to properly work with binary ufuncs ops.inject_all_ops_and_reduce_methods(DataArray, priority=60) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 89dba6605a6..9cf304858a6 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -13,8 +13,8 @@ import xarray as xr from . import ( - alignment, duck_array_ops, formatting, groupby, indexing, ops, resample, - rolling, utils) + alignment, computation, duck_array_ops, formatting, groupby, indexing, ops, + resample, rolling, utils) from .. 
import conventions from .alignment import align from .common import ( @@ -31,7 +31,7 @@ OrderedDict, basestring, dask_array_type, integer_types, iteritems, range) from .utils import ( Frozen, SortedKeysDict, either_dict_or_kwargs, decode_numpy_dict_values, - ensure_us_time_resolution, hashable, maybe_wrap_array) + ensure_us_time_resolution, hashable, maybe_wrap_array, to_numeric) from .variable import IndexVariable, Variable, as_variable, broadcast_variables # list of attributes of pd.DatetimeIndex that are ndarrays of time info @@ -3417,6 +3417,9 @@ def diff(self, dim, n=1, label='upper'): Data variables: foo (x) int64 1 -1 + See Also + -------- + Dataset.differentiate """ if n == 0: return self @@ -3767,6 +3770,62 @@ def rank(self, dim, pct=False, keep_attrs=False): attrs = self.attrs if keep_attrs else None return self._replace_vars_and_dims(variables, coord_names, attrs=attrs) + def differentiate(self, coord, edge_order=1, datetime_unit=None): + """ Differentiate with the second order accurate central + differences. + + .. note:: + This feature is limited to simple cartesian geometry, i.e. coord + must be one dimensional. + + Parameters + ---------- + coord: str + The coordinate to be used to compute the gradient. + edge_order: 1 or 2. Default 1 + N-th order accurate differences at the boundaries. + datetime_unit: None or any of {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', + 'us', 'ns', 'ps', 'fs', 'as'} + Unit to compute gradient. Only valid for datetime coordinate. 
+ + Returns + ------- + differentiated: Dataset + + See also + -------- + numpy.gradient: corresponding numpy function + """ + from .variable import Variable + + if coord not in self.variables and coord not in self.dims: + raise ValueError('Coordinate {} does not exist.'.format(coord)) + + coord_var = self[coord].variable + if coord_var.ndim != 1: + raise ValueError('Coordinate {} must be 1 dimensional but is {}' + ' dimensional'.format(coord, coord_var.ndim)) + + dim = coord_var.dims[0] + coord_data = coord_var.data + if coord_data.dtype.kind in 'mM': + if datetime_unit is None: + datetime_unit, _ = np.datetime_data(coord_data.dtype) + coord_data = to_numeric(coord_data, datetime_unit=datetime_unit) + + variables = OrderedDict() + for k, v in self.variables.items(): + if (k in self.data_vars and dim in v.dims and + k not in self.coords): + v = to_numeric(v, datetime_unit=datetime_unit) + grad = duck_array_ops.gradient( + v.data, coord_data, edge_order=edge_order, + axis=v.get_axis_num(dim)) + variables[k] = Variable(v.dims, grad) + else: + variables[k] = v + return self._replace_vars_and_dims(variables) + @property def real(self): return self._unary_op(lambda x: x.real, keep_attrs=True)(self) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 17eb310f8db..ef89dba2ab8 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -93,6 +93,14 @@ def isnull(data): einsum = _dask_or_eager_func('einsum', array_args=slice(1, None), requires_dask='0.17.3') + +def gradient(x, coord, axis, edge_order): + if isinstance(x, dask_array_type): + return dask_array_compat.gradient( + x, coord, axis=axis, edge_order=edge_order) + return npcompat.gradient(x, coord, axis=axis, edge_order=edge_order) + + masked_invalid = _dask_or_eager_func( 'masked_invalid', eager_module=np.ma, dask_module=getattr(dask_array, 'ma', None)) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 90aa4ffaeda..afb34d99115 100644 --- 
a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -11,7 +11,7 @@ from . import rolling from .computation import apply_ufunc from .pycompat import iteritems -from .utils import is_scalar, OrderedSet +from .utils import is_scalar, OrderedSet, to_numeric from .variable import Variable, broadcast_variables from .duck_array_ops import dask_array_type @@ -414,8 +414,8 @@ def _floatize_x(x, new_x): # offset (min(x)) and the variation (x - min(x)) can be # represented by float. xmin = np.min(x[i]) - x[i] = (x[i] - xmin).astype(np.float64) - new_x[i] = (new_x[i] - xmin).astype(np.float64) + x[i] = to_numeric(x[i], offset=xmin, dtype=np.float64) + new_x[i] = to_numeric(new_x[i], offset=xmin, dtype=np.float64) return x, new_x diff --git a/xarray/core/npcompat.py b/xarray/core/npcompat.py index 6d4db063b98..22dff44acf8 100644 --- a/xarray/core/npcompat.py +++ b/xarray/core/npcompat.py @@ -1,5 +1,6 @@ from __future__ import absolute_import, division, print_function +from distutils.version import LooseVersion import numpy as np try: @@ -97,3 +98,187 @@ def isin(element, test_elements, assume_unique=False, invert=False): element = np.asarray(element) return np.in1d(element, test_elements, assume_unique=assume_unique, invert=invert).reshape(element.shape) + + +if LooseVersion(np.__version__) >= LooseVersion('1.13'): + gradient = np.gradient +else: + def normalize_axis_tuple(axes, N): + if isinstance(axes, int): + axes = (axes, ) + return tuple([N + a if a < 0 else a for a in axes]) + + def gradient(f, *varargs, **kwargs): + f = np.asanyarray(f) + N = f.ndim # number of dimensions + + axes = kwargs.pop('axis', None) + if axes is None: + axes = tuple(range(N)) + else: + axes = normalize_axis_tuple(axes, N) + + len_axes = len(axes) + n = len(varargs) + if n == 0: + # no spacing argument - use 1 in all axes + dx = [1.0] * len_axes + elif n == 1 and np.ndim(varargs[0]) == 0: + # single scalar for all axes + dx = varargs * len_axes + elif n == len_axes: + # scalar or 1d array for 
each axis + dx = list(varargs) + for i, distances in enumerate(dx): + if np.ndim(distances) == 0: + continue + elif np.ndim(distances) != 1: + raise ValueError("distances must be either scalars or 1d") + if len(distances) != f.shape[axes[i]]: + raise ValueError("when 1d, distances must match the " + "length of the corresponding dimension") + diffx = np.diff(distances) + # if distances are constant reduce to the scalar case + # since it brings a consistent speedup + if (diffx == diffx[0]).all(): + diffx = diffx[0] + dx[i] = diffx + else: + raise TypeError("invalid number of arguments") + + edge_order = kwargs.pop('edge_order', 1) + if kwargs: + raise TypeError('"{}" are not valid keyword arguments.'.format( + '", "'.join(kwargs.keys()))) + if edge_order > 2: + raise ValueError("'edge_order' greater than 2 not supported") + + # use central differences on interior and one-sided differences on the + # endpoints. This preserves second order-accuracy over the full domain. + + outvals = [] + + # create slice objects --- initially all are [:, :, ..., :] + slice1 = [slice(None)] * N + slice2 = [slice(None)] * N + slice3 = [slice(None)] * N + slice4 = [slice(None)] * N + + otype = f.dtype.char + if otype not in ['f', 'd', 'F', 'D', 'm', 'M']: + otype = 'd' + + # Difference of datetime64 elements results in timedelta64 + if otype == 'M': + # Need to use the full dtype name because it contains unit + # information + otype = f.dtype.name.replace('datetime', 'timedelta') + elif otype == 'm': + # Needs to keep the specific units, can't be a general unit + otype = f.dtype + + # Convert datetime64 data into ints. Make dummy variable `y` + # that is a view of ints if the data is datetime64, otherwise + # just set y equal to the array `f`. 
+ if f.dtype.char in ["M", "m"]: + y = f.view('int64') + else: + y = f + + for i, axis in enumerate(axes): + if y.shape[axis] < edge_order + 1: + raise ValueError( + "Shape of array too small to calculate a numerical " + "gradient, at least (edge_order + 1) elements are " + "required.") + # result allocation + out = np.empty_like(y, dtype=otype) + + uniform_spacing = np.ndim(dx[i]) == 0 + + # Numerical differentiation: 2nd order interior + slice1[axis] = slice(1, -1) + slice2[axis] = slice(None, -2) + slice3[axis] = slice(1, -1) + slice4[axis] = slice(2, None) + + if uniform_spacing: + out[slice1] = (f[slice4] - f[slice2]) / (2. * dx[i]) + else: + dx1 = dx[i][0:-1] + dx2 = dx[i][1:] + a = -(dx2) / (dx1 * (dx1 + dx2)) + b = (dx2 - dx1) / (dx1 * dx2) + c = dx1 / (dx2 * (dx1 + dx2)) + # fix the shape for broadcasting + shape = np.ones(N, dtype=int) + shape[axis] = -1 + a.shape = b.shape = c.shape = shape + # 1D equivalent -- + # out[1:-1] = a * f[:-2] + b * f[1:-1] + c * f[2:] + out[slice1] = a * f[slice2] + b * f[slice3] + c * f[slice4] + + # Numerical differentiation: 1st order edges + if edge_order == 1: + slice1[axis] = 0 + slice2[axis] = 1 + slice3[axis] = 0 + dx_0 = dx[i] if uniform_spacing else dx[i][0] + # 1D equivalent -- out[0] = (y[1] - y[0]) / (x[1] - x[0]) + out[slice1] = (y[slice2] - y[slice3]) / dx_0 + + slice1[axis] = -1 + slice2[axis] = -1 + slice3[axis] = -2 + dx_n = dx[i] if uniform_spacing else dx[i][-1] + # 1D equivalent -- out[-1] = (y[-1] - y[-2]) / (x[-1] - x[-2]) + out[slice1] = (y[slice2] - y[slice3]) / dx_n + + # Numerical differentiation: 2nd order edges + else: + slice1[axis] = 0 + slice2[axis] = 0 + slice3[axis] = 1 + slice4[axis] = 2 + if uniform_spacing: + a = -1.5 / dx[i] + b = 2. / dx[i] + c = -0.5 / dx[i] + else: + dx1 = dx[i][0] + dx2 = dx[i][1] + a = -(2. 
* dx1 + dx2) / (dx1 * (dx1 + dx2)) + b = (dx1 + dx2) / (dx1 * dx2) + c = - dx1 / (dx2 * (dx1 + dx2)) + # 1D equivalent -- out[0] = a * y[0] + b * y[1] + c * y[2] + out[slice1] = a * y[slice2] + b * y[slice3] + c * y[slice4] + + slice1[axis] = -1 + slice2[axis] = -3 + slice3[axis] = -2 + slice4[axis] = -1 + if uniform_spacing: + a = 0.5 / dx[i] + b = -2. / dx[i] + c = 1.5 / dx[i] + else: + dx1 = dx[i][-2] + dx2 = dx[i][-1] + a = (dx2) / (dx1 * (dx1 + dx2)) + b = - (dx2 + dx1) / (dx1 * dx2) + c = (2. * dx2 + dx1) / (dx2 * (dx1 + dx2)) + # 1D equivalent -- out[-1] = a * f[-3] + b * f[-2] + c * f[-1] + out[slice1] = a * y[slice2] + b * y[slice3] + c * y[slice4] + + outvals.append(out) + + # reset the slice object in this dimension to ":" + slice1[axis] = slice(None) + slice2[axis] = slice(None) + slice3[axis] = slice(None) + slice4[axis] = slice(None) + + if len_axes == 1: + return outvals[0] + else: + return outvals diff --git a/xarray/core/utils.py b/xarray/core/utils.py index c3bb747fac5..9d129d5c4f4 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -591,3 +591,24 @@ def __iter__(self): def __len__(self): num_hidden = sum([k in self._hidden_keys for k in self._data]) return len(self._data) - num_hidden + + +def to_numeric(array, offset=None, datetime_unit=None, dtype=float): + """ + Make datetime array float + + offset: Scalar with the same type of array or None + If None, subtract minimum values to reduce round off error + datetime_unit: None or any of {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', + 'us', 'ns', 'ps', 'fs', 'as'} + dtype: target dtype + """ + if array.dtype.kind not in ['m', 'M']: + return array.astype(dtype) + if offset is None: + offset = np.min(array) + array = array - offset + + if datetime_unit: + return (array / np.timedelta64(1, datetime_unit)).astype(dtype) + return array.astype(dtype) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index fc933960914..f8fb9b98ac3 100644 --- a/xarray/tests/test_dataset.py 
+++ b/xarray/tests/test_dataset.py @@ -15,7 +15,7 @@ from xarray import ( DataArray, Dataset, IndexVariable, MergeError, Variable, align, backends, broadcast, open_dataset, set_options) -from xarray.core import indexing, utils +from xarray.core import indexing, npcompat, utils from xarray.core.common import full_like from xarray.core.pycompat import ( OrderedDict, integer_types, iteritems, unicode_type) @@ -4513,3 +4513,78 @@ def test_raise_no_warning_for_nan_in_binary_ops(): with pytest.warns(None) as record: Dataset(data_vars={'x': ('y', [1, 2, np.NaN])}) > 0 assert len(record) == 0 + + +@pytest.mark.parametrize('dask', [True, False]) +@pytest.mark.parametrize('edge_order', [1, 2]) +def test_gradient(dask, edge_order): + rs = np.random.RandomState(42) + coord = [0.2, 0.35, 0.4, 0.6, 0.7, 0.75, 0.76, 0.8] + + da = xr.DataArray(rs.randn(8, 6), dims=['x', 'y'], + coords={'x': coord, + 'z': 3, 'x2d': (('x', 'y'), rs.randn(8, 6))}) + if dask and has_dask: + da = da.chunk({'x': 4}) + + ds = xr.Dataset({'var': da}) + + # along x + actual = da.differentiate('x', edge_order) + expected_x = xr.DataArray( + npcompat.gradient(da, da['x'], axis=0, edge_order=edge_order), + dims=da.dims, coords=da.coords) + assert_equal(expected_x, actual) + assert_equal(ds['var'].differentiate('x', edge_order=edge_order), + ds.differentiate('x', edge_order=edge_order)['var']) + # coordinate should not change + assert_equal(da['x'], actual['x']) + + # along y + actual = da.differentiate('y', edge_order) + expected_y = xr.DataArray( + npcompat.gradient(da, da['y'], axis=1, edge_order=edge_order), + dims=da.dims, coords=da.coords) + assert_equal(expected_y, actual) + assert_equal(actual, ds.differentiate('y', edge_order=edge_order)['var']) + assert_equal(ds['var'].differentiate('y', edge_order=edge_order), + ds.differentiate('y', edge_order=edge_order)['var']) + + with pytest.raises(ValueError): + da.differentiate('x2d') + + +@pytest.mark.parametrize('dask', [True, False]) +def 
test_gradient_datetime(dask): + rs = np.random.RandomState(42) + coord = np.array( + ['2004-07-13', '2006-01-13', '2010-08-13', '2010-09-13', + '2010-10-11', '2010-12-13', '2011-02-13', '2012-08-13'], + dtype='datetime64') + + da = xr.DataArray(rs.randn(8, 6), dims=['x', 'y'], + coords={'x': coord, + 'z': 3, 'x2d': (('x', 'y'), rs.randn(8, 6))}) + if dask and has_dask: + da = da.chunk({'x': 4}) + + # along x + actual = da.differentiate('x', edge_order=1, datetime_unit='D') + expected_x = xr.DataArray( + npcompat.gradient( + da, utils.to_numeric(da['x'], datetime_unit='D'), + axis=0, edge_order=1), dims=da.dims, coords=da.coords) + assert_equal(expected_x, actual) + + actual2 = da.differentiate('x', edge_order=1, datetime_unit='h') + assert np.allclose(actual, actual2 * 24) + + # for datetime variable + actual = da['x'].differentiate('x', edge_order=1, datetime_unit='D') + assert np.allclose(actual, 1.0) + + # with different date unit + da = xr.DataArray(coord.astype('datetime64[ms]'), dims=['x'], + coords={'x': coord}) + actual = da.differentiate('x', edge_order=1) + assert np.allclose(actual, 1.0) diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 3f32fc49fd2..b9712f60290 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -12,8 +12,8 @@ from xarray import DataArray, Dataset, concat from xarray.core import duck_array_ops, dtypes from xarray.core.duck_array_ops import ( - array_notnull_equiv, concatenate, count, first, last, mean, rolling_window, - stack, where) + array_notnull_equiv, concatenate, count, first, gradient, last, mean, + rolling_window, stack, where) from xarray.core.pycompat import dask_array_type from xarray.testing import assert_allclose, assert_equal @@ -417,6 +417,23 @@ def test_dask_rolling(axis, window, center): fill_value=np.nan) +@pytest.mark.skipif(not has_dask, reason='This is for dask.') +@pytest.mark.parametrize('axis', [0, -1, 1]) 
+@pytest.mark.parametrize('edge_order', [1, 2]) +def test_dask_gradient(axis, edge_order): + import dask.array as da + + array = np.array(np.random.randn(100, 5, 40)) + x = np.exp(np.linspace(0, 1, array.shape[axis])) + + darray = da.from_array(array, chunks=[(6, 30, 30, 20, 14), 5, 8]) + expected = gradient(array, x, axis=axis, edge_order=edge_order) + actual = gradient(darray, x, axis=axis, edge_order=edge_order) + + assert isinstance(actual, da.Array) + assert_array_equal(actual, expected) + + @pytest.mark.parametrize('dim_num', [1, 2]) @pytest.mark.parametrize('dtype', [float, int, np.float32, np.bool_]) @pytest.mark.parametrize('dask', [False, True]) From 93f58a60fdb6260f5d5a156c78dca0d956c75fe3 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Fri, 21 Sep 2018 18:41:50 -0700 Subject: [PATCH 26/51] Doc fixes for v0.10.9 --- doc/computation.rst | 6 ++++-- doc/roadmap.rst | 2 ++ doc/whats-new.rst | 19 +++++++++++++------ 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/doc/computation.rst b/doc/computation.rst index 67cda6f2191..759c87a6cc7 100644 --- a/doc/computation.rst +++ b/doc/computation.rst @@ -208,15 +208,17 @@ coordinates. :py:meth:`~xarray.DataArray.differentiate` computes derivatives by central finite differences using their coordinates, .. ipython:: python - a = xr.DataArray([0, 1, 2, 3], dims=['x'], coords=[0.1, 0.11, 0.2, 0.3]) + + a = xr.DataArray([0, 1, 2, 3], dims=['x'], coords=[[0.1, 0.11, 0.2, 0.3]]) a a.differentiate('x') This method can be used also for multidimensional arrays, .. ipython:: python + a = xr.DataArray(np.arange(8).reshape(4, 2), dims=['x', 'y'], - coords=[0.1, 0.11, 0.2, 0.3]) + coords={'x': [0.1, 0.11, 0.2, 0.3]}) a.differentiate('x') .. note:: diff --git a/doc/roadmap.rst b/doc/roadmap.rst index 2708cb7cf8f..34d203c3f48 100644 --- a/doc/roadmap.rst +++ b/doc/roadmap.rst @@ -1,3 +1,5 @@ +.. 
_roadmap: + Development roadmap =================== diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 7240059bd10..4a6886159d1 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -27,11 +27,17 @@ What's New .. _whats-new.0.10.9: -v0.10.9 (unreleased) --------------------- +v0.10.9 (21 September 2018) +--------------------------- -Documentation -~~~~~~~~~~~~~ +This minor release contains a number of backwards compatible enhancements. + +Announcements of note: + +- Xarray is now a NumFOCUS fiscally sponsored project! Read + `the announcement `_ + for more details. +- We have a new :doc:`roadmap` that outlines our future development plans. Enhancements ~~~~~~~~~~~~ @@ -51,7 +57,8 @@ Enhancements (:issue:`2230`) By `Keisuke Fujii `_. -- :py:meth:`plot()` now accepts the kwargs ``xscale, yscale, xlim, ylim, xticks, yticks`` just like Pandas. Also ``xincrease=False, yincrease=False`` now use matplotlib's axis inverting methods instead of setting limits. +- :py:meth:`plot()` now accepts the kwargs + ``xscale, yscale, xlim, ylim, xticks, yticks`` just like Pandas. Also ``xincrease=False, yincrease=False`` now use matplotlib's axis inverting methods instead of setting limits. By `Deepak Cherian `_. (:issue:`2224`) - DataArray coordinates and Dataset coordinates and data variables are @@ -118,7 +125,7 @@ Bug fixes By `Keisuke Fujii `_. - Now :py:func:`xr.apply_ufunc` raises a ValueError when the size of -``input_core_dims`` is inconsistent with the number of arguments. + ``input_core_dims`` is inconsistent with the number of arguments. (:issue:`2341`) By `Keisuke Fujii `_. 
From c1ea99212bc3e26789090f6800d468ed7fdd1bb8 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Fri, 21 Sep 2018 18:53:20 -0700 Subject: [PATCH 27/51] Revert to dev version for 0.11 --- doc/whats-new.rst | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4a6886159d1..3bc2c521568 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,6 +25,18 @@ What's New - `Python 3 Statement `__ - `Tips on porting to Python 3 `__ +.. _whats-new.0.11.0: + +v0.11.0 (unreleased) +-------------------- + +Enhancements +~~~~~~~~~~~~ + +Bug fixes +~~~~~~~~~ + + .. _whats-new.0.10.9: v0.10.9 (21 September 2018) From 4577ed891f9839722fd9e606e4f3bdb8e6acef4f Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 22 Sep 2018 13:01:20 +0900 Subject: [PATCH 28/51] misc plotting fixes (#2426) 1. Don't explicitly set rotation for colorbar label. Labels now work better with horizontal colorbar. 2. facetgrid: check if artist is Mappable --- xarray/plot/facetgrid.py | 9 ++++++--- xarray/plot/plot.py | 2 +- xarray/plot/utils.py | 3 ++- xarray/tests/test_plot.py | 13 +++++++++++++ 4 files changed, 22 insertions(+), 5 deletions(-) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index a0d7c4dd5e2..792b7829bf4 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -502,9 +502,12 @@ def map(self, func, *args, **kwargs): data = self.data.loc[namedict] plt.sca(ax) innerargs = [data[a].values for a in args] - # TODO: is it possible to verify that an artist is mappable? - mappable = func(*innerargs, **kwargs) - self._mappables.append(mappable) + maybe_mappable = func(*innerargs, **kwargs) + # TODO: better way to verify that an artist is mappable? 
+ # https://stackoverflow.com/questions/33023036/is-it-possible-to-detect-if-a-matplotlib-artist-is-a-mappable-suitable-for-use-w#33023522 + if (maybe_mappable and + hasattr(maybe_mappable, 'autoscale_None')): + self._mappables.append(maybe_mappable) self._finalize_grid(*args[:2]) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index b92429b857d..a6add44682f 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -770,7 +770,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, cbar_kwargs.setdefault('cax', cbar_ax) cbar = plt.colorbar(primitive, **cbar_kwargs) if add_labels and 'label' not in cbar_kwargs: - cbar.set_label(label_from_attrs(darray), rotation=90) + cbar.set_label(label_from_attrs(darray)) elif cbar_ax is not None or cbar_kwargs is not None: # inform the user about keywords which aren't used raise ValueError("cbar_ax and cbar_kwargs can't be used with " diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 9af0624dbfc..455d27c3987 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -1,10 +1,11 @@ from __future__ import absolute_import, division, print_function +import textwrap import warnings import numpy as np -import textwrap +from ..core.options import OPTIONS from ..core.pycompat import basestring from ..core.utils import is_scalar from ..core.options import OPTIONS diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index e27f03630b7..b3bc687a5c5 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1051,6 +1051,19 @@ def test_convenient_facetgrid_4d(self): for ax in g.axes.flat: assert ax.has_data() + @pytest.mark.filterwarnings('ignore:This figure includes') + def test_facetgrid_map_only_appends_mappables(self): + a = easy_array((10, 15, 2, 3)) + d = DataArray(a, dims=['y', 'x', 'columns', 'rows']) + g = self.plotfunc(d, x='x', y='y', col='columns', row='rows') + + expected = g._mappables + + g.map(lambda: plt.plot(1, 1)) + actual = g._mappables + + assert 
expected == actual + def test_facetgrid_cmap(self): # Regression test for GH592 data = (np.random.random(size=(20, 25, 12)) + np.linspace(-3, 3, 12)) From 04253f271c66a12366a82d357c2a889dd3eea42f Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Sat, 22 Sep 2018 13:13:28 -0700 Subject: [PATCH 29/51] dev/test build for python 3.7 (#2271) * dev/test build for python 3.7 * docs * fixup whatsnew * update 3.7 travis test --- .travis.yml | 2 ++ ci/requirements-py37.yml | 13 +++++++++++++ doc/installing.rst | 2 +- doc/whats-new.rst | 8 +++++--- 4 files changed, 21 insertions(+), 4 deletions(-) create mode 100644 ci/requirements-py37.yml diff --git a/.travis.yml b/.travis.yml index 0e51e946da0..1e6c3254cdd 100644 --- a/.travis.yml +++ b/.travis.yml @@ -18,6 +18,8 @@ matrix: env: CONDA_ENV=py35 - python: 3.6 env: CONDA_ENV=py36 + - python: 3.6 # TODO: change this to 3.7 once https://github.com/travis-ci/travis-ci/issues/9815 is fixed + env: CONDA_ENV=py37 - python: 3.6 env: - CONDA_ENV=py36 diff --git a/ci/requirements-py37.yml b/ci/requirements-py37.yml new file mode 100644 index 00000000000..5f973936f63 --- /dev/null +++ b/ci/requirements-py37.yml @@ -0,0 +1,13 @@ +name: test_env +channels: + - defaults +dependencies: + - python=3.7 + - pip: + - pytest + - flake8 + - mock + - numpy + - pandas + - coveralls + - pytest-cov diff --git a/doc/installing.rst b/doc/installing.rst index 85cd5a02568..eb74eb7162b 100644 --- a/doc/installing.rst +++ b/doc/installing.rst @@ -6,7 +6,7 @@ Installation Required dependencies --------------------- -- Python 2.7 [1]_, 3.5, or 3.6 +- Python 2.7 [1]_, 3.5, 3.6, or 3.7 - `numpy `__ (1.12 or later) - `pandas `__ (0.19.2 or later) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 3bc2c521568..67d0d548ec5 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -33,6 +33,9 @@ v0.11.0 (unreleased) Enhancements ~~~~~~~~~~~~ +- Added support for Python 3.7. (:issue:`2271`). + By `Joe Hamman `_. 
+ Bug fixes ~~~~~~~~~ @@ -78,8 +81,8 @@ Enhancements (:issue:`1186`) By `Seth P `_. - A new CFTimeIndex-enabled :py:func:`cftime_range` function for use in - generating dates from standard or non-standard calendars. By `Spencer Clark - `_. + generating dates from standard or non-standard calendars. By `Spencer Clark + `_. - When interpolating over a ``datetime64`` axis, you can now provide a datetime string instead of a ``datetime64`` object. E.g. ``da.interp(time='1991-02-01')`` (:issue:`2284`) @@ -178,7 +181,6 @@ Enhancements :py:meth:`~xarray.DataArray.from_cdms2` (:issue:`2262`). By `Stephane Raynaud `_. - Bug fixes ~~~~~~~~~ From 1ec83a75c409c68683ac035dfee1c26f8cbc6695 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 26 Sep 2018 09:55:32 +0900 Subject: [PATCH 30/51] make facetgrid execute _finalize() only once (#2435) * only apply finalize_grid once. * Add test --- xarray/plot/facetgrid.py | 16 ++++++++++------ xarray/tests/test_plot.py | 2 ++ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 792b7829bf4..79a3993e23b 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -188,6 +188,7 @@ def __init__(self, data, col=None, row=None, col_wrap=None, self._y_var = None self._cmap_extend = None self._mappables = [] + self._finalized = False @property def _left_axes(self): @@ -308,13 +309,16 @@ def map_dataarray_line(self, x=None, y=None, hue=None, **kwargs): def _finalize_grid(self, *axlabels): """Finalize the annotations and layout.""" - self.set_axis_labels(*axlabels) - self.set_titles() - self.fig.tight_layout() + if not self._finalized: + self.set_axis_labels(*axlabels) + self.set_titles() + self.fig.tight_layout() - for ax, namedict in zip(self.axes.flat, self.name_dicts.flat): - if namedict is None: - ax.set_visible(False) + for ax, namedict in zip(self.axes.flat, self.name_dicts.flat): + if namedict is None: + ax.set_visible(False) + + self._finalized = True def 
add_legend(self, **kwargs): figlegend = self.fig.legend( diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index b3bc687a5c5..1423f7ae853 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1523,7 +1523,9 @@ def test_num_ticks(self): @pytest.mark.slow def test_map(self): + assert self.g._finalized is False self.g.map(plt.contourf, 'x', 'y', Ellipsis) + assert self.g._finalized is True self.g.map(lambda: None) @pytest.mark.slow From 1857a7fc2ab3a472025ff5d69371feaf7e3c4d74 Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Wed, 26 Sep 2018 16:27:54 -0700 Subject: [PATCH 31/51] switch travis language to generic (#2432) * dev/test build for python 3.7 * docs * fixup whatsnew * update 3.7 travis test * generic travis config * switch to minimal --- .travis.yml | 80 +++++++++++++++++++---------------------------------- 1 file changed, 28 insertions(+), 52 deletions(-) diff --git a/.travis.yml b/.travis.yml index 1e6c3254cdd..defb37ec8aa 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,5 @@ # Based on http://conda.pydata.org/docs/travis.html -language: python +language: minimal sudo: false # use container based build notifications: email: false @@ -10,74 +10,48 @@ branches: matrix: fast_finish: true include: - - python: 2.7 - env: CONDA_ENV=py27-min - - python: 2.7 - env: CONDA_ENV=py27-cdat+iris+pynio - - python: 3.5 - env: CONDA_ENV=py35 - - python: 3.6 - env: CONDA_ENV=py36 - - python: 3.6 # TODO: change this to 3.7 once https://github.com/travis-ci/travis-ci/issues/9815 is fixed - env: CONDA_ENV=py37 - - python: 3.6 - env: + - env: CONDA_ENV=py27-min + - env: CONDA_ENV=py27-cdat+iris+pynio + - env: CONDA_ENV=py35 + - env: CONDA_ENV=py36 + - env: CONDA_ENV=py37 + - env: - CONDA_ENV=py36 - EXTRA_FLAGS="--run-flaky --run-network-tests" - - python: 3.6 - env: CONDA_ENV=py36-netcdf4-dev + - env: CONDA_ENV=py36-netcdf4-dev addons: apt_packages: - libhdf5-serial-dev - netcdf-bin - libnetcdf-dev - - python: 3.6 - env: 
CONDA_ENV=py36-dask-dev - - python: 3.6 - env: CONDA_ENV=py36-pandas-dev - - python: 3.6 - env: CONDA_ENV=py36-bottleneck-dev - - python: 3.6 - env: CONDA_ENV=py36-condaforge-rc - - python: 3.6 - env: CONDA_ENV=py36-pynio-dev - - python: 3.6 - env: CONDA_ENV=py36-rasterio-0.36 - - python: 3.6 - env: CONDA_ENV=py36-zarr-dev - - python: 3.5 - env: CONDA_ENV=docs - - python: 3.6 - env: CONDA_ENV=py36-hypothesis + - env: CONDA_ENV=py36-dask-dev + - env: CONDA_ENV=py36-pandas-dev + - env: CONDA_ENV=py36-bottleneck-dev + - env: CONDA_ENV=py36-condaforge-rc + - env: CONDA_ENV=py36-pynio-dev + - env: CONDA_ENV=py36-rasterio-0.36 + - env: CONDA_ENV=py36-zarr-dev + - env: CONDA_ENV=docs + - env: CONDA_ENV=py36-hypothesis + allow_failures: - - python: 3.6 - env: + - env: - CONDA_ENV=py36 - EXTRA_FLAGS="--run-flaky --run-network-tests" - - python: 3.6 - env: CONDA_ENV=py36-netcdf4-dev + - env: CONDA_ENV=py36-netcdf4-dev addons: apt_packages: - libhdf5-serial-dev - netcdf-bin - libnetcdf-dev - - python: 3.6 - env: CONDA_ENV=py36-pandas-dev - - python: 3.6 - env: CONDA_ENV=py36-bottleneck-dev - - python: 3.6 - env: CONDA_ENV=py36-condaforge-rc - - python: 3.6 - env: CONDA_ENV=py36-pynio-dev - - python: 3.6 - env: CONDA_ENV=py36-zarr-dev + - env: CONDA_ENV=py36-pandas-dev + - env: CONDA_ENV=py36-bottleneck-dev + - env: CONDA_ENV=py36-condaforge-rc + - env: CONDA_ENV=py36-pynio-dev + - env: CONDA_ENV=py36-zarr-dev before_install: - - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then - wget http://repo.continuum.io/miniconda/Miniconda-3.16.0-Linux-x86_64.sh -O miniconda.sh; - else - wget http://repo.continuum.io/miniconda/Miniconda3-3.16.0-Linux-x86_64.sh -O miniconda.sh; - fi + - wget http://repo.continuum.io/miniconda/Miniconda3-3.16.0-Linux-x86_64.sh -O miniconda.sh; - bash miniconda.sh -b -p $HOME/miniconda - export PATH="$HOME/miniconda/bin:$PATH" - hash -r @@ -97,6 +71,8 @@ install: - python xarray/util/print_versions.py script: + - which python + - python --version - python 
-OO -c "import xarray" - if [[ "$CONDA_ENV" == "docs" ]]; then conda install -c conda-forge sphinx sphinx_rtd_theme sphinx-gallery numpydoc; From 96dde664eda26a76f934151dd10dc02f6cb0000b Mon Sep 17 00:00:00 2001 From: Zac Hatfield-Dodds Date: Thu, 27 Sep 2018 09:47:27 +1000 Subject: [PATCH 32/51] Use profile mechanism, not no-op mutation (#2442) --- properties/test_encode_decode.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/properties/test_encode_decode.py b/properties/test_encode_decode.py index 8d84c0f6815..7b3e75fbf0c 100644 --- a/properties/test_encode_decode.py +++ b/properties/test_encode_decode.py @@ -13,7 +13,8 @@ import xarray as xr # Run for a while - arrays are a bigger search space than usual -settings.deadline = None +settings.register_profile("ci", deadline=None) +settings.load_profile("ci") an_array = npst.arrays( From 78058e2c1f39cbfae6eddb30e3b7d4a81b54ad8b Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 27 Sep 2018 18:37:02 +0200 Subject: [PATCH 33/51] Remove incorrect statement about "drop" in the text docs (#2439) As pointed out by FaustinCarter in GH1949 --- doc/data-structures.rst | 7 ------- 1 file changed, 7 deletions(-) diff --git a/doc/data-structures.rst b/doc/data-structures.rst index 10d83ca448f..618ccccff3e 100644 --- a/doc/data-structures.rst +++ b/doc/data-structures.rst @@ -408,13 +408,6 @@ operations keep around coordinates: list(ds[['x']]) list(ds.drop('temperature')) -If a dimension name is given as an argument to ``drop``, it also drops all -variables that use that dimension: - -.. ipython:: python - - list(ds.drop('time')) - As an alternate to dictionary-like modifications, you can use :py:meth:`~xarray.Dataset.assign` and :py:meth:`~xarray.Dataset.assign_coords`. 
These methods return a new dataset with additional (or replaced) or values: From 638b251c622359b665208276a2cb23b0fbc5141b Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Fri, 28 Sep 2018 08:54:29 +0200 Subject: [PATCH 34/51] Future warning for default reduction dimension of groupby (#2366) * warn the default reduction dimension of groupby * Only use DEFAULT_DIMS in groupby/resample * Restore unintended line break * Lint * Add whatsnew * Support dataset.groupby * Add a version in the warning message. * is -> == * Update tests * Update docs for DatasetResample.reduce * Match dataset.resample behavior to the current one. * Update via comments. --- doc/whats-new.rst | 10 ++++++ xarray/__init__.py | 2 ++ xarray/core/common.py | 6 +++- xarray/core/dataset.py | 11 +++--- xarray/core/groupby.py | 64 ++++++++++++++++++++++++++++++++-- xarray/core/resample.py | 12 ++++--- xarray/core/variable.py | 2 ++ xarray/tests/test_dask.py | 4 +-- xarray/tests/test_dataarray.py | 36 +++++++++++++++---- xarray/tests/test_dataset.py | 23 +++++++----- 10 files changed, 141 insertions(+), 29 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 67d0d548ec5..4e1607f0e42 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -30,6 +30,16 @@ What's New v0.11.0 (unreleased) -------------------- +Breaking changes +~~~~~~~~~~~~~~~~ + +- Reduction of :py:meth:`DataArray.groupby` and :py:meth:`DataArray.resample` + without dimension argument will change in the next release. + Now we warn a FutureWarning. + By `Keisuke Fujii `_. + +Documentation +~~~~~~~~~~~~~ Enhancements ~~~~~~~~~~~~ diff --git a/xarray/__init__.py b/xarray/__init__.py index e3898f348cc..59a961c6b56 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -34,3 +34,5 @@ from . import tutorial from . import ufuncs from . 
import testing + +from .core.common import ALL_DIMS diff --git a/xarray/core/common.py b/xarray/core/common.py index 280034a30dd..41e4fec2982 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -10,7 +10,11 @@ from . import duck_array_ops, dtypes, formatting, ops from .arithmetic import SupportsArithmetic from .pycompat import OrderedDict, basestring, dask_array_type, suppress -from .utils import either_dict_or_kwargs, Frozen, SortedKeysDict +from .utils import either_dict_or_kwargs, Frozen, SortedKeysDict, ReprObject + + +# Used as a sentinel value to indicate a all dimensions +ALL_DIMS = ReprObject('') class ImplementsArrayReduce(object): diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 9cf304858a6..981ad3157ba 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -18,7 +18,8 @@ from .. import conventions from .alignment import align from .common import ( - DataWithCoords, ImplementsDatasetReduce, _contains_datetime_like_objects) + ALL_DIMS, DataWithCoords, ImplementsDatasetReduce, + _contains_datetime_like_objects) from .coordinates import ( DatasetCoordinates, Indexes, LevelCoordinatesSource, assert_coordinate_consistent, remap_label_indexers) @@ -743,7 +744,7 @@ def copy(self, deep=False, data=None): Shallow copy versus deep copy >>> da = xr.DataArray(np.random.randn(2, 3)) - >>> ds = xr.Dataset({'foo': da, 'bar': ('x', [-1, 2])}, + >>> ds = xr.Dataset({'foo': da, 'bar': ('x', [-1, 2])}, coords={'x': ['one', 'two']}) >>> ds.copy() @@ -775,7 +776,7 @@ def copy(self, deep=False, data=None): foo (dim_0, dim_1) float64 7.0 0.3897 -1.862 -0.6091 -1.051 -0.3003 bar (x) int64 -1 2 - Changing the data using the ``data`` argument maintains the + Changing the data using the ``data`` argument maintains the structure of the original object, but with the new data. Original object is unaffected. 
@@ -826,7 +827,7 @@ def copy(self, deep=False, data=None): # skip __init__ to avoid costly validation return self._construct_direct(variables, self._coord_names.copy(), self._dims.copy(), self._attrs_copy(), - encoding=self.encoding) + encoding=self.encoding) def _subset_with_all_valid_coords(self, variables, coord_names, attrs): needed_dims = set() @@ -2893,6 +2894,8 @@ def reduce(self, func, dim=None, keep_attrs=False, numeric_only=False, Dataset with this object's DataArrays replaced with new DataArrays of summarized data and the indicated dimension(s) removed. """ + if dim is ALL_DIMS: + dim = None if isinstance(dim, basestring): dims = set([dim]) elif dim is None: diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 7068f8e6cae..3842c642047 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1,14 +1,15 @@ from __future__ import absolute_import, division, print_function import functools +import warnings import numpy as np import pandas as pd -from . import dtypes, duck_array_ops, nputils, ops +from . import dtypes, duck_array_ops, nputils, ops, utils from .arithmetic import SupportsArithmetic from .combine import concat -from .common import ImplementsArrayReduce, ImplementsDatasetReduce +from .common import ALL_DIMS, ImplementsArrayReduce, ImplementsDatasetReduce from .pycompat import integer_types, range, zip from .utils import hashable, maybe_wrap_array, peek_at, safe_cast_to_index from .variable import IndexVariable, Variable, as_variable @@ -567,10 +568,39 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=False, Array with summarized data and the indicated dimension(s) removed. """ + if dim == DEFAULT_DIMS: + dim = ALL_DIMS + # TODO change this to dim = self._group_dim after + # the deprecation process + if self._obj.ndim > 1: + warnings.warn( + "Default reduction dimension will be changed to the " + "grouped dimension after xarray 0.12. 
To silence this " + "warning, pass dim=xarray.ALL_DIMS explicitly.", + FutureWarning, stacklevel=2) + def reduce_array(ar): return ar.reduce(func, dim, axis, keep_attrs=keep_attrs, **kwargs) return self.apply(reduce_array, shortcut=shortcut) + # TODO remove the following class method and DEFAULT_DIMS after the + # deprecation cycle + @classmethod + def _reduce_method(cls, func, include_skipna, numeric_only): + if include_skipna: + def wrapped_func(self, dim=DEFAULT_DIMS, axis=None, skipna=None, + keep_attrs=False, **kwargs): + return self.reduce(func, dim, axis, keep_attrs=keep_attrs, + skipna=skipna, allow_lazy=True, **kwargs) + else: + def wrapped_func(self, dim=DEFAULT_DIMS, axis=None, + keep_attrs=False, **kwargs): + return self.reduce(func, dim, axis, keep_attrs=keep_attrs, + allow_lazy=True, **kwargs) + return wrapped_func + + +DEFAULT_DIMS = utils.ReprObject('') ops.inject_reduce_methods(DataArrayGroupBy) ops.inject_binary_ops(DataArrayGroupBy) @@ -649,10 +679,40 @@ def reduce(self, func, dim=None, keep_attrs=False, **kwargs): Array with summarized data and the indicated dimension(s) removed. """ + if dim == DEFAULT_DIMS: + dim = ALL_DIMS + # TODO change this to dim = self._group_dim after + # the deprecation process. Do not forget to remove _reduce_method + warnings.warn( + "Default reduction dimension will be changed to the " + "grouped dimension after xarray 0.12. 
To silence this " + "warning, pass dim=xarray.ALL_DIMS explicitly.", + FutureWarning, stacklevel=2) + elif dim is None: + dim = self._group_dim + def reduce_dataset(ds): return ds.reduce(func, dim, keep_attrs, **kwargs) return self.apply(reduce_dataset) + # TODO remove the following class method and DEFAULT_DIMS after the + # deprecation cycle + @classmethod + def _reduce_method(cls, func, include_skipna, numeric_only): + if include_skipna: + def wrapped_func(self, dim=DEFAULT_DIMS, keep_attrs=False, + skipna=None, **kwargs): + return self.reduce(func, dim, keep_attrs, skipna=skipna, + numeric_only=numeric_only, allow_lazy=True, + **kwargs) + else: + def wrapped_func(self, dim=DEFAULT_DIMS, keep_attrs=False, + **kwargs): + return self.reduce(func, dim, keep_attrs, + numeric_only=numeric_only, allow_lazy=True, + **kwargs) + return wrapped_func + def assign(self, **kwargs): """Assign data variables by group. diff --git a/xarray/core/resample.py b/xarray/core/resample.py index 4933a09b257..25c149c51af 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, division, print_function from . import ops -from .groupby import DataArrayGroupBy, DatasetGroupBy +from .groupby import DataArrayGroupBy, DatasetGroupBy, DEFAULT_DIMS from .pycompat import OrderedDict, dask_array_type RESAMPLE_DIM = '__resample_dim__' @@ -277,15 +277,14 @@ def reduce(self, func, dim=None, keep_attrs=False, **kwargs): """Reduce the items in this group by applying `func` along the pre-defined resampling dimension. - Note that `dim` is by default here and ignored if passed by the user; - this ensures compatibility with the existing reduce interface. - Parameters ---------- func : function Function which can be called in the form `func(x, axis=axis, **kwargs)` to return the result of collapsing an np.ndarray over an integer valued axis. + dim : str or sequence of str, optional + Dimension(s) over which to apply `func`. 
keep_attrs : bool, optional If True, the datasets's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new @@ -299,8 +298,11 @@ def reduce(self, func, dim=None, keep_attrs=False, **kwargs): Array with summarized data and the indicated dimension(s) removed. """ + if dim == DEFAULT_DIMS: + dim = None + return super(DatasetResample, self).reduce( - func, self._dim, keep_attrs, **kwargs) + func, dim, keep_attrs, **kwargs) def _interpolate(self, kind='linear'): """Apply scipy.interpolate.interp1d along resampling dimension.""" diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 86629cc2a28..c003d52aab2 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1333,6 +1333,8 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=False, Array with summarized data and the indicated dimension(s) removed. """ + if dim is common.ALL_DIMS: + dim = None if dim is not None and axis is not None: raise ValueError("cannot supply both 'axis' and 'dim' arguments") diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 6ca83ab73ab..43fa35473ce 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -385,8 +385,8 @@ def test_groupby(self): u = self.eager_array v = self.lazy_array - expected = u.groupby('x').mean() - actual = v.groupby('x').mean() + expected = u.groupby('x').mean(xr.ALL_DIMS) + actual = v.groupby('x').mean(xr.ALL_DIMS) self.assertLazyAndAllClose(expected, actual) def test_groupby_first(self): diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 2b93e696d50..f8b288f4ab0 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2,6 +2,7 @@ import pickle from copy import deepcopy +from distutils.version import LooseVersion from textwrap import dedent import warnings @@ -14,7 +15,7 @@ DataArray, Dataset, IndexVariable, Variable, align, broadcast, set_options) from xarray.convert import from_cdms2 from 
xarray.coding.times import CFDatetimeCoder, _import_cftime -from xarray.core.common import full_like +from xarray.core.common import full_like, ALL_DIMS from xarray.core.pycompat import OrderedDict, iteritems from xarray.tests import ( ReturnItem, TestCase, assert_allclose, assert_array_equal, assert_equal, @@ -2000,15 +2001,15 @@ def test_groupby_sum(self): self.x[:, 10:].sum(), self.x[:, 9:10].sum()]).T), 'abc': Variable(['abc'], np.array(['a', 'b', 'c']))})['foo'] - assert_allclose(expected_sum_all, grouped.reduce(np.sum)) - assert_allclose(expected_sum_all, grouped.sum()) + assert_allclose(expected_sum_all, grouped.reduce(np.sum, dim=ALL_DIMS)) + assert_allclose(expected_sum_all, grouped.sum(ALL_DIMS)) expected = DataArray([array['y'].values[idx].sum() for idx in [slice(9), slice(10, None), slice(9, 10)]], [['a', 'b', 'c']], ['abc']) actual = array['y'].groupby('abc').apply(np.sum) assert_allclose(expected, actual) - actual = array['y'].groupby('abc').sum() + actual = array['y'].groupby('abc').sum(ALL_DIMS) assert_allclose(expected, actual) expected_sum_axis1 = Dataset( @@ -2019,6 +2020,27 @@ def test_groupby_sum(self): assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, 'y')) assert_allclose(expected_sum_axis1, grouped.sum('y')) + def test_groupby_warning(self): + array = self.make_groupby_example_array() + grouped = array.groupby('y') + with pytest.warns(FutureWarning): + grouped.sum() + + @pytest.mark.skipif(LooseVersion(xr.__version__) < LooseVersion('0.12'), + reason="not to forget the behavior change") + def test_groupby_sum_default(self): + array = self.make_groupby_example_array() + grouped = array.groupby('abc') + + expected_sum_all = Dataset( + {'foo': Variable(['x', 'abc'], + np.array([self.x[:, :9].sum(axis=-1), + self.x[:, 10:].sum(axis=-1), + self.x[:, 9:10].sum(axis=-1)]).T), + 'abc': Variable(['abc'], np.array(['a', 'b', 'c']))})['foo'] + + assert_allclose(expected_sum_all, grouped.sum()) + def test_groupby_count(self): array = DataArray( 
[0, 0, np.nan, np.nan, 0, 0], @@ -2099,9 +2121,9 @@ def test_groupby_math(self): assert_identical(expected, actual) grouped = array.groupby('abc') - expected_agg = (grouped.mean() - np.arange(3)).rename(None) + expected_agg = (grouped.mean(ALL_DIMS) - np.arange(3)).rename(None) actual = grouped - DataArray(range(3), [('abc', ['a', 'b', 'c'])]) - actual_agg = actual.groupby('abc').mean() + actual_agg = actual.groupby('abc').mean(ALL_DIMS) assert_allclose(expected_agg, actual_agg) with raises_regex(TypeError, 'only support binary ops'): @@ -2175,7 +2197,7 @@ def test_groupby_multidim(self): ('lon', DataArray([5, 28, 23], coords=[('lon', [30., 40., 50.])])), ('lat', DataArray([16, 40], coords=[('lat', [10., 20.])]))]: - actual_sum = array.groupby(dim).sum() + actual_sum = array.groupby(dim).sum(ALL_DIMS) assert_identical(expected_sum, actual_sum) def test_groupby_multidim_apply(self): diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index f8fb9b98ac3..237dc09d06a 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -14,7 +14,7 @@ import xarray as xr from xarray import ( DataArray, Dataset, IndexVariable, MergeError, Variable, align, backends, - broadcast, open_dataset, set_options) + broadcast, open_dataset, set_options, ALL_DIMS) from xarray.core import indexing, npcompat, utils from xarray.core.common import full_like from xarray.core.pycompat import ( @@ -2648,20 +2648,28 @@ def test_groupby_reduce(self): expected = data.mean('y') expected['yonly'] = expected['yonly'].variable.set_dims({'x': 3}) - actual = data.groupby('x').mean() + actual = data.groupby('x').mean(ALL_DIMS) assert_allclose(expected, actual) actual = data.groupby('x').mean('y') assert_allclose(expected, actual) letters = data['letters'] - expected = Dataset({'xy': data['xy'].groupby(letters).mean(), + expected = Dataset({'xy': data['xy'].groupby(letters).mean(ALL_DIMS), 'xonly': (data['xonly'].mean().variable .set_dims({'letters': 2})), 'yonly': 
data['yonly'].groupby(letters).mean()}) - actual = data.groupby('letters').mean() + actual = data.groupby('letters').mean(ALL_DIMS) assert_allclose(expected, actual) + def test_groupby_warn(self): + data = Dataset({'xy': (['x', 'y'], np.random.randn(3, 4)), + 'xonly': ('x', np.random.randn(3)), + 'yonly': ('y', np.random.randn(4)), + 'letters': ('y', ['a', 'a', 'b', 'b'])}) + with pytest.warns(FutureWarning): + data.groupby('x').mean() + def test_groupby_math(self): def reorder_dims(x): return x.transpose('dim1', 'dim2', 'dim3', 'time') @@ -2716,7 +2724,7 @@ def test_groupby_math_virtual(self): ds = Dataset({'x': ('t', [1, 2, 3])}, {'t': pd.date_range('20100101', periods=3)}) grouped = ds.groupby('t.day') - actual = grouped - grouped.mean() + actual = grouped - grouped.mean(ALL_DIMS) expected = Dataset({'x': ('t', [0, 0, 0])}, ds[['t', 't.day']]) assert_identical(actual, expected) @@ -2725,18 +2733,17 @@ def test_groupby_nan(self): # nan should be excluded from groupby ds = Dataset({'foo': ('x', [1, 2, 3, 4])}, {'bar': ('x', [1, 1, 2, np.nan])}) - actual = ds.groupby('bar').mean() + actual = ds.groupby('bar').mean(ALL_DIMS) expected = Dataset({'foo': ('bar', [1.5, 3]), 'bar': [1, 2]}) assert_identical(actual, expected) def test_groupby_order(self): # groupby should preserve variables order - ds = Dataset() for vn in ['a', 'b', 'c']: ds[vn] = DataArray(np.arange(10), dims=['t']) data_vars_ref = list(ds.data_vars.keys()) - ds = ds.groupby('t').mean() + ds = ds.groupby('t').mean(ALL_DIMS) data_vars = list(ds.data_vars.keys()) assert data_vars == data_vars_ref # coords are now at the end of the list, so the test below fails From 458cf51ce20e8d924b38b59c8fbc3bb10f39148e Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Fri, 28 Sep 2018 15:44:28 +0200 Subject: [PATCH 35/51] restore ddof support in std (#2447) * restore ddof support in std * whats new --- doc/whats-new.rst | 3 +++ xarray/core/nanops.py | 4 ++-- xarray/tests/test_duck_array_ops.py | 2 +- 3 files changed, 
6 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4e1607f0e42..40d21cc5346 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -49,6 +49,9 @@ Enhancements Bug fixes ~~~~~~~~~ +- ``xarray.DataArray.std()`` now correctly accepts ``ddof`` keyword argument. + (:issue:`2240`) + By `Keisuke Fujii `_. .. _whats-new.0.10.9: diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 2309ed9619d..9549c8e77b9 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -184,9 +184,9 @@ def nanvar(a, axis=None, dtype=None, out=None, ddof=0): a, axis=axis, dtype=dtype, ddof=ddof) -def nanstd(a, axis=None, dtype=None, out=None): +def nanstd(a, axis=None, dtype=None, out=None, ddof=0): return _dask_or_eager_func('nanstd', eager_module=nputils)( - a, axis=axis, dtype=dtype) + a, axis=axis, dtype=dtype, ddof=ddof) def nanprod(a, axis=None, dtype=None, out=None, min_count=None): diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index b9712f60290..aab5d305a82 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -309,7 +309,7 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): assert_allclose(actual, expected, rtol=rtol) # make sure the compatiblility with pandas' results. 
- if func == 'var': + if func in ['var', 'std']: expected = series_reduce(da, func, skipna=skipna, dim=aggdim, ddof=0) assert_allclose(actual, expected, rtol=rtol) From c2b09d697c741b5d6ddede0ba01076c0cb09cf19 Mon Sep 17 00:00:00 2001 From: Spencer Clark Date: Fri, 28 Sep 2018 09:44:54 -0400 Subject: [PATCH 36/51] Enable use of cftime.datetime coordinates with differentiate and interp (#2434) * Enable use of cftime.datetime coords with differentiate and interp * Raise TypeError for non-datetime x_new * Rename to_numeric to datetime_to_numeric --- doc/interpolation.rst | 3 ++ doc/time-series.rst | 56 ++++++++++++++------- doc/whats-new.rst | 7 +++ xarray/coding/cftimeindex.py | 29 +++++++++++ xarray/core/dataset.py | 51 +++++++++++++------ xarray/core/missing.py | 12 +++-- xarray/core/utils.py | 17 ++++--- xarray/tests/test_cftimeindex.py | 23 ++++++++- xarray/tests/test_dataset.py | 40 +++++++++++++-- xarray/tests/test_interp.py | 85 +++++++++++++++++++++++++++++++- xarray/tests/test_utils.py | 43 +++++++++++++++- 11 files changed, 312 insertions(+), 54 deletions(-) diff --git a/doc/interpolation.rst b/doc/interpolation.rst index e5230e95dae..10e46331d0a 100644 --- a/doc/interpolation.rst +++ b/doc/interpolation.rst @@ -63,6 +63,9 @@ by specifing the time periods required. da_dt64.interp(time=pd.date_range('1/1/2000', '1/3/2000', periods=3)) +Interpolation of data indexed by a :py:class:`~xarray.CFTimeIndex` is also +allowed. See :ref:`CFTimeIndex` for examples. + .. note:: Currently, our interpolation only works for regular grids. diff --git a/doc/time-series.rst b/doc/time-series.rst index d99c3218d18..c1a686b409f 100644 --- a/doc/time-series.rst +++ b/doc/time-series.rst @@ -70,9 +70,9 @@ You can manual decode arrays in this form by passing a dataset to One unfortunate limitation of using ``datetime64[ns]`` is that it limits the native representation of dates to those that fall between the years 1678 and 2262. 
When a netCDF file contains dates outside of these bounds, dates will be -returned as arrays of ``cftime.datetime`` objects and a ``CFTimeIndex`` -can be used for indexing. The ``CFTimeIndex`` enables only a subset of -the indexing functionality of a ``pandas.DatetimeIndex`` and is only enabled +returned as arrays of :py:class:`cftime.datetime` objects and a :py:class:`~xarray.CFTimeIndex` +can be used for indexing. The :py:class:`~xarray.CFTimeIndex` enables only a subset of +the indexing functionality of a :py:class:`pandas.DatetimeIndex` and is only enabled when using the standalone version of ``cftime`` (not the version packaged with earlier versions ``netCDF4``). See :ref:`CFTimeIndex` for more information. @@ -219,12 +219,12 @@ Non-standard calendars and dates outside the Timestamp-valid range ------------------------------------------------------------------ Through the standalone ``cftime`` library and a custom subclass of -``pandas.Index``, xarray supports a subset of the indexing functionality enabled -through the standard ``pandas.DatetimeIndex`` for dates from non-standard -calendars or dates using a standard calendar, but outside the -`Timestamp-valid range`_ (approximately between years 1678 and 2262). This -behavior has not yet been turned on by default; to take advantage of this -functionality, you must have the ``enable_cftimeindex`` option set to +:py:class:`pandas.Index`, xarray supports a subset of the indexing +functionality enabled through the standard :py:class:`pandas.DatetimeIndex` for +dates from non-standard calendars or dates using a standard calendar, but +outside the `Timestamp-valid range`_ (approximately between years 1678 and +2262). This behavior has not yet been turned on by default; to take advantage +of this functionality, you must have the ``enable_cftimeindex`` option set to ``True`` within your context (see :py:func:`~xarray.set_options` for more information). 
It is expected that this will become the default behavior in xarray version 0.11. @@ -232,7 +232,7 @@ xarray version 0.11. For instance, you can create a DataArray indexed by a time coordinate with a no-leap calendar within a context manager setting the ``enable_cftimeindex`` option, and the time index will be cast to a -``CFTimeIndex``: +:py:class:`~xarray.CFTimeIndex`: .. ipython:: python @@ -247,28 +247,28 @@ coordinate with a no-leap calendar within a context manager setting the .. note:: - With the ``enable_cftimeindex`` option activated, a ``CFTimeIndex`` + With the ``enable_cftimeindex`` option activated, a :py:class:`~xarray.CFTimeIndex` will be used for time indexing if any of the following are true: - The dates are from a non-standard calendar - Any dates are outside the Timestamp-valid range - Otherwise a ``pandas.DatetimeIndex`` will be used. In addition, if any + Otherwise a :py:class:`pandas.DatetimeIndex` will be used. In addition, if any variable (not just an index variable) is encoded using a non-standard - calendar, its times will be decoded into ``cftime.datetime`` objects, + calendar, its times will be decoded into :py:class:`cftime.datetime` objects, regardless of whether or not they can be represented using ``np.datetime64[ns]`` objects. -xarray also includes a :py:func:`cftime_range` function, which enables creating a -``CFTimeIndex`` with regularly-spaced dates. For instance, we can create the -same dates and DataArray we created above using: +xarray also includes a :py:func:`~xarray.cftime_range` function, which enables +creating a :py:class:`~xarray.CFTimeIndex` with regularly-spaced dates. For instance, we can +create the same dates and DataArray we created above using: .. 
ipython:: python dates = xr.cftime_range(start='0001', periods=24, freq='MS', calendar='noleap') da = xr.DataArray(np.arange(24), coords=[dates], dims=['time'], name='foo') -For data indexed by a ``CFTimeIndex`` xarray currently supports: +For data indexed by a :py:class:`~xarray.CFTimeIndex` xarray currently supports: - `Partial datetime string indexing`_ using strictly `ISO 8601-format`_ partial datetime strings: @@ -294,7 +294,25 @@ For data indexed by a ``CFTimeIndex`` xarray currently supports: .. ipython:: python da.groupby('time.month').sum() - + +- Interpolation using :py:class:`cftime.datetime` objects: + +.. ipython:: python + + da.interp(time=[DatetimeNoLeap(1, 1, 15), DatetimeNoLeap(1, 2, 15)]) + +- Interpolation using datetime strings: + +.. ipython:: python + + da.interp(time=['0001-01-15', '0001-02-15']) + +- Differentiation: + +.. ipython:: python + + da.differentiate('time') + - And serialization: .. ipython:: python @@ -305,7 +323,7 @@ For data indexed by a ``CFTimeIndex`` xarray currently supports: .. note:: Currently resampling along the time dimension for data indexed by a - ``CFTimeIndex`` is not supported. + :py:class:`~xarray.CFTimeIndex` is not supported. .. _Timestamp-valid range: https://pandas.pydata.org/pandas-docs/stable/timeseries.html#timestamp-limitations .. _ISO 8601-format: https://en.wikipedia.org/wiki/ISO_8601 diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 40d21cc5346..a5b7b36142e 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -46,6 +46,13 @@ Enhancements - Added support for Python 3.7. (:issue:`2271`). By `Joe Hamman `_. +- Added support for using ``cftime.datetime`` coordinates with + :py:meth:`~xarray.DataArray.differentiate`, + :py:meth:`~xarray.Dataset.differentiate`, + :py:meth:`~xarray.DataArray.interp`, and + :py:meth:`~xarray.Dataset.interp`. 
+ By `Spencer Clark `_ + Bug fixes ~~~~~~~~~ diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index ea2bcbc5858..e236dca3693 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -314,3 +314,32 @@ def __contains__(self, key): def contains(self, key): """Needed for .loc based partial-string indexing""" return self.__contains__(key) + + +def _parse_iso8601_without_reso(date_type, datetime_str): + date, _ = _parse_iso8601_with_reso(date_type, datetime_str) + return date + + +def _parse_array_of_cftime_strings(strings, date_type): + """Create a numpy array from an array of strings. + + For use in generating dates from strings for use with interp. Assumes the + array is either 0-dimensional or 1-dimensional. + + Parameters + ---------- + strings : array of strings + Strings to convert to dates + date_type : cftime.datetime type + Calendar type to use for dates + + Returns + ------- + np.array + """ + if strings.ndim == 0: + return np.array(_parse_iso8601_without_reso(date_type, strings.item())) + else: + return np.array([_parse_iso8601_without_reso(date_type, s) + for s in strings]) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 981ad3157ba..4ad5902ebad 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -32,9 +32,11 @@ OrderedDict, basestring, dask_array_type, integer_types, iteritems, range) from .utils import ( Frozen, SortedKeysDict, either_dict_or_kwargs, decode_numpy_dict_values, - ensure_us_time_resolution, hashable, maybe_wrap_array, to_numeric) + ensure_us_time_resolution, hashable, maybe_wrap_array, datetime_to_numeric) from .variable import IndexVariable, Variable, as_variable, broadcast_variables +from ..coding.cftimeindex import _parse_array_of_cftime_strings + # list of attributes of pd.DatetimeIndex that are ndarrays of time info _DATETIMEINDEX_COMPONENTS = ['year', 'month', 'day', 'hour', 'minute', 'second', 'microsecond', 'nanosecond', 'date', @@ -1413,8 +1415,8 @@ def 
_validate_indexers(self, indexers): """ Here we make sure + indexer has a valid keys + indexer is in a valid data type - * string indexers are cast to datetime64 - if associated index is DatetimeIndex + + string indexers are cast to the appropriate date type if the + associated index is a DatetimeIndex or CFTimeIndex """ from .dataarray import DataArray @@ -1436,10 +1438,12 @@ def _validate_indexers(self, indexers): else: v = np.asarray(v) - if ((v.dtype.kind == 'U' or v.dtype.kind == 'S') - and isinstance(self.coords[k].to_index(), - pd.DatetimeIndex)): - v = v.astype('datetime64[ns]') + if v.dtype.kind == 'U' or v.dtype.kind == 'S': + index = self.indexes[k] + if isinstance(index, pd.DatetimeIndex): + v = v.astype('datetime64[ns]') + elif isinstance(index, xr.CFTimeIndex): + v = _parse_array_of_cftime_strings(v, index.date_type) if v.ndim == 0: v = as_variable(v) @@ -1981,11 +1985,26 @@ def maybe_variable(obj, k): except KeyError: return as_variable((k, range(obj.dims[k]))) + def _validate_interp_indexer(x, new_x): + # In the case of datetimes, the restrictions placed on indexers + # used with interp are stronger than those which are placed on + # isel, so we need an additional check after _validate_indexers. + if (_contains_datetime_like_objects(x) and + not _contains_datetime_like_objects(new_x)): + raise TypeError('When interpolating over a datetime-like ' + 'coordinate, the coordinates to ' + 'interpolate to must be either datetime ' + 'strings or datetimes. 
' + 'Instead got\n{}'.format(new_x)) + else: + return (x, new_x) + variables = OrderedDict() for name, var in iteritems(obj._variables): if name not in indexers: if var.dtype.kind in 'uifc': - var_indexers = {k: (maybe_variable(obj, k), v) for k, v + var_indexers = {k: _validate_interp_indexer( + maybe_variable(obj, k), v) for k, v in indexers.items() if k in var.dims} variables[name] = missing.interp( var, var_indexers, method, **kwargs) @@ -3810,19 +3829,21 @@ def differentiate(self, coord, edge_order=1, datetime_unit=None): ' dimensional'.format(coord, coord_var.ndim)) dim = coord_var.dims[0] - coord_data = coord_var.data - if coord_data.dtype.kind in 'mM': - if datetime_unit is None: - datetime_unit, _ = np.datetime_data(coord_data.dtype) - coord_data = to_numeric(coord_data, datetime_unit=datetime_unit) + if _contains_datetime_like_objects(coord_var): + if coord_var.dtype.kind in 'mM' and datetime_unit is None: + datetime_unit, _ = np.datetime_data(coord_var.dtype) + elif datetime_unit is None: + datetime_unit = 's' # Default to seconds for cftime objects + coord_var = datetime_to_numeric(coord_var, datetime_unit=datetime_unit) variables = OrderedDict() for k, v in self.variables.items(): if (k in self.data_vars and dim in v.dims and k not in self.coords): - v = to_numeric(v, datetime_unit=datetime_unit) + if _contains_datetime_like_objects(v): + v = datetime_to_numeric(v, datetime_unit=datetime_unit) grad = duck_array_ops.gradient( - v.data, coord_data, edge_order=edge_order, + v.data, coord_var, edge_order=edge_order, axis=v.get_axis_num(dim)) variables[k] = Variable(v.dims, grad) else: diff --git a/xarray/core/missing.py b/xarray/core/missing.py index afb34d99115..0b560c277ae 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -9,9 +9,10 @@ import pandas as pd from . 
import rolling +from .common import _contains_datetime_like_objects from .computation import apply_ufunc from .pycompat import iteritems -from .utils import is_scalar, OrderedSet, to_numeric +from .utils import is_scalar, OrderedSet, datetime_to_numeric from .variable import Variable, broadcast_variables from .duck_array_ops import dask_array_type @@ -407,15 +408,16 @@ def _floatize_x(x, new_x): x = list(x) new_x = list(new_x) for i in range(len(x)): - if x[i].dtype.kind in 'Mm': + if _contains_datetime_like_objects(x[i]): # Scipy casts coordinates to np.float64, which is not accurate # enough for datetime64 (uses 64bit integer). # We assume that the most of the bits are used to represent the # offset (min(x)) and the variation (x - min(x)) can be # represented by float. - xmin = np.min(x[i]) - x[i] = to_numeric(x[i], offset=xmin, dtype=np.float64) - new_x[i] = to_numeric(new_x[i], offset=xmin, dtype=np.float64) + xmin = x[i].min() + x[i] = datetime_to_numeric(x[i], offset=xmin, dtype=np.float64) + new_x[i] = datetime_to_numeric( + new_x[i], offset=xmin, dtype=np.float64) return x, new_x diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 9d129d5c4f4..c39a07e1b5a 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -593,20 +593,25 @@ def __len__(self): return len(self._data) - num_hidden -def to_numeric(array, offset=None, datetime_unit=None, dtype=float): - """ - Make datetime array float +def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float): + """Convert an array containing datetime-like data to an array of floats. 
+ Parameters + ---------- + array : array + Input data offset: Scalar with the same type of array or None If None, subtract minimum values to reduce round off error datetime_unit: None or any of {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', 'ps', 'fs', 'as'} dtype: target dtype + + Returns + ------- + array """ - if array.dtype.kind not in ['m', 'M']: - return array.astype(dtype) if offset is None: - offset = np.min(array) + offset = array.min() array = array - offset if datetime_unit: diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index f72c6904f0e..62a29a15247 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -9,10 +9,11 @@ from datetime import timedelta from xarray.coding.cftimeindex import ( parse_iso8601, CFTimeIndex, assert_all_valid_date_type, - _parsed_string_to_bounds, _parse_iso8601_with_reso) + _parsed_string_to_bounds, _parse_iso8601_with_reso, + _parse_array_of_cftime_strings) from xarray.tests import assert_array_equal, assert_identical -from . import has_cftime, has_cftime_or_netCDF4 +from .
import has_cftime, has_cftime_or_netCDF4, requires_cftime from .test_coding_times import _all_cftime_date_types @@ -616,3 +617,21 @@ def test_concat_cftimeindex(date_type, enable_cftimeindex): def test_empty_cftimeindex(): index = CFTimeIndex([]) assert index.date_type is None + + +@requires_cftime +def test_parse_array_of_cftime_strings(): + from cftime import DatetimeNoLeap + + strings = np.array(['2000-01-01', '2000-01-02']) + expected = np.array([DatetimeNoLeap(2000, 1, 1), + DatetimeNoLeap(2000, 1, 2)]) + + result = _parse_array_of_cftime_strings(strings, DatetimeNoLeap) + np.testing.assert_array_equal(result, expected) + + # Test scalar array case + strings = np.array('2000-01-01') + expected = np.array(DatetimeNoLeap(2000, 1, 1)) + result = _parse_array_of_cftime_strings(strings, DatetimeNoLeap) + np.testing.assert_array_equal(result, expected) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 237dc09d06a..c42b84c05fc 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -22,8 +22,9 @@ from . 
import ( InaccessibleArray, TestCase, UnexpectedDataAccess, assert_allclose, - assert_array_equal, assert_equal, assert_identical, has_dask, raises_regex, - requires_bottleneck, requires_dask, requires_scipy, source_ndarray) + assert_array_equal, assert_equal, assert_identical, has_cftime, + has_dask, raises_regex, requires_bottleneck, requires_dask, requires_scipy, + source_ndarray) try: import cPickle as pickle @@ -4524,7 +4525,7 @@ def test_raise_no_warning_for_nan_in_binary_ops(): @pytest.mark.parametrize('dask', [True, False]) @pytest.mark.parametrize('edge_order', [1, 2]) -def test_gradient(dask, edge_order): +def test_differentiate(dask, edge_order): rs = np.random.RandomState(42) coord = [0.2, 0.35, 0.4, 0.6, 0.7, 0.75, 0.76, 0.8] @@ -4562,7 +4563,7 @@ def test_gradient(dask, edge_order): @pytest.mark.parametrize('dask', [True, False]) -def test_gradient_datetime(dask): +def test_differentiate_datetime(dask): rs = np.random.RandomState(42) coord = np.array( ['2004-07-13', '2006-01-13', '2010-08-13', '2010-09-13', @@ -4579,7 +4580,7 @@ def test_gradient_datetime(dask): actual = da.differentiate('x', edge_order=1, datetime_unit='D') expected_x = xr.DataArray( npcompat.gradient( - da, utils.to_numeric(da['x'], datetime_unit='D'), + da, utils.datetime_to_numeric(da['x'], datetime_unit='D'), axis=0, edge_order=1), dims=da.dims, coords=da.coords) assert_equal(expected_x, actual) @@ -4595,3 +4596,32 @@ def test_gradient_datetime(dask): coords={'x': coord}) actual = da.differentiate('x', edge_order=1) assert np.allclose(actual, 1.0) + + +@pytest.mark.skipif(not has_cftime, reason='Test requires cftime.') +@pytest.mark.parametrize('dask', [True, False]) +def test_differentiate_cftime(dask): + rs = np.random.RandomState(42) + coord = xr.cftime_range('2000', periods=8, freq='2M') + + da = xr.DataArray( + rs.randn(8, 6), + coords={'time': coord, 'z': 3, 't2d': (('time', 'y'), rs.randn(8, 6))}, + dims=['time', 'y']) + + if dask and has_dask: + da = da.chunk({'time': 4}) 
+ + actual = da.differentiate('time', edge_order=1, datetime_unit='D') + expected_data = npcompat.gradient( + da, utils.datetime_to_numeric(da['time'], datetime_unit='D'), + axis=0, edge_order=1) + expected = xr.DataArray(expected_data, coords=da.coords, dims=da.dims) + assert_equal(expected, actual) + + actual2 = da.differentiate('time', edge_order=1, datetime_unit='h') + assert_allclose(actual, actual2 * 24) + + # Test the differentiation of datetimes themselves + actual = da['time'].differentiate('time', edge_order=1, datetime_unit='D') + assert_allclose(actual, xr.ones_like(da['time']).astype(float)) diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 4a8f4e6eedf..0778a1ff128 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -5,10 +5,13 @@ import pytest import xarray as xr -from xarray.tests import assert_allclose, assert_equal, requires_scipy +from xarray.tests import (assert_allclose, assert_equal, requires_cftime, + requires_scipy) from . 
import has_dask, has_scipy from .test_dataset import create_test_data +from ..coding.cftimeindex import _parse_array_of_cftime_strings + try: import scipy except ImportError: @@ -490,3 +493,83 @@ def test_datetime_single_string(): expected = xr.DataArray(0.5) assert_allclose(actual.drop('time'), expected) + + +@requires_cftime +@requires_scipy +def test_cftime(): + times = xr.cftime_range('2000', periods=24, freq='D') + da = xr.DataArray(np.arange(24), coords=[times], dims='time') + + times_new = xr.cftime_range('2000-01-01T12:00:00', periods=3, freq='D') + actual = da.interp(time=times_new) + expected = xr.DataArray([0.5, 1.5, 2.5], coords=[times_new], dims=['time']) + + assert_allclose(actual, expected) + + +@requires_cftime +@requires_scipy +def test_cftime_type_error(): + times = xr.cftime_range('2000', periods=24, freq='D') + da = xr.DataArray(np.arange(24), coords=[times], dims='time') + + times_new = xr.cftime_range('2000-01-01T12:00:00', periods=3, freq='D', + calendar='noleap') + with pytest.raises(TypeError): + da.interp(time=times_new) + + +@requires_cftime +@requires_scipy +def test_cftime_list_of_strings(): + from cftime import DatetimeProlepticGregorian + + times = xr.cftime_range('2000', periods=24, freq='D') + da = xr.DataArray(np.arange(24), coords=[times], dims='time') + + times_new = ['2000-01-01T12:00', '2000-01-02T12:00', '2000-01-03T12:00'] + actual = da.interp(time=times_new) + + times_new_array = _parse_array_of_cftime_strings( + np.array(times_new), DatetimeProlepticGregorian) + expected = xr.DataArray([0.5, 1.5, 2.5], coords=[times_new_array], + dims=['time']) + + assert_allclose(actual, expected) + + +@requires_cftime +@requires_scipy +def test_cftime_single_string(): + from cftime import DatetimeProlepticGregorian + + times = xr.cftime_range('2000', periods=24, freq='D') + da = xr.DataArray(np.arange(24), coords=[times], dims='time') + + times_new = '2000-01-01T12:00' + actual = da.interp(time=times_new) + + times_new_array = 
_parse_array_of_cftime_strings( + np.array(times_new), DatetimeProlepticGregorian) + expected = xr.DataArray(0.5, coords={'time': times_new_array}) + + assert_allclose(actual, expected) + + +@requires_scipy +def test_datetime_to_non_datetime_error(): + da = xr.DataArray(np.arange(24), dims='time', + coords={'time': pd.date_range('2000-01-01', periods=24)}) + with pytest.raises(TypeError): + da.interp(time=0.5) + + +@requires_cftime +@requires_scipy +def test_cftime_to_non_cftime_error(): + times = xr.cftime_range('2000', periods=24, freq='D') + da = xr.DataArray(np.arange(24), coords=[times], dims='time') + + with pytest.raises(TypeError): + da.interp(time=0.5) diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index ed8045b78e4..0c0e0f3f744 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -5,16 +5,18 @@ import numpy as np import pandas as pd import pytest +import xarray as xr from xarray.coding.cftimeindex import CFTimeIndex from xarray.core import duck_array_ops, utils from xarray.core.options import set_options from xarray.core.pycompat import OrderedDict from xarray.core.utils import either_dict_or_kwargs +from xarray.testing import assert_identical from . 
import ( TestCase, assert_array_equal, has_cftime, has_cftime_or_netCDF4, - requires_dask) + requires_dask, requires_cftime) from .test_coding_times import _all_cftime_date_types @@ -263,3 +265,42 @@ def test_either_dict_or_kwargs(): with pytest.raises(ValueError, match=r'foo'): result = either_dict_or_kwargs(dict(a=1), dict(a=1), 'foo') + + +def test_datetime_to_numeric_datetime64(): + times = pd.date_range('2000', periods=5, freq='7D') + da = xr.DataArray(times, coords=[times], dims=['time']) + result = utils.datetime_to_numeric(da, datetime_unit='h') + expected = 24 * xr.DataArray(np.arange(0, 35, 7), coords=da.coords) + assert_identical(result, expected) + + offset = da.isel(time=1) + result = utils.datetime_to_numeric(da, offset=offset, datetime_unit='h') + expected = 24 * xr.DataArray(np.arange(-7, 28, 7), coords=da.coords) + assert_identical(result, expected) + + dtype = np.float32 + result = utils.datetime_to_numeric(da, datetime_unit='h', dtype=dtype) + expected = 24 * xr.DataArray( + np.arange(0, 35, 7), coords=da.coords).astype(dtype) + assert_identical(result, expected) + + +@requires_cftime +def test_datetime_to_numeric_cftime(): + times = xr.cftime_range('2000', periods=5, freq='7D') + da = xr.DataArray(times, coords=[times], dims=['time']) + result = utils.datetime_to_numeric(da, datetime_unit='h') + expected = 24 * xr.DataArray(np.arange(0, 35, 7), coords=da.coords) + assert_identical(result, expected) + + offset = da.isel(time=1) + result = utils.datetime_to_numeric(da, offset=offset, datetime_unit='h') + expected = 24 * xr.DataArray(np.arange(-7, 28, 7), coords=da.coords) + assert_identical(result, expected) + + dtype = np.float32 + result = utils.datetime_to_numeric(da, datetime_unit='h', dtype=dtype) + expected = 24 * xr.DataArray( + np.arange(0, 35, 7), coords=da.coords).astype(dtype) + assert_identical(result, expected) From 23d1cda3b7da5c73a5f561a5c953b50beaa2bfe6 Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Fri, 28 Sep 2018 20:24:35 
+0200 Subject: [PATCH 37/51] fix:2445 (#2446) * fix:2445 * rename test_shift_multidim -> test_roll_multidim --- doc/whats-new.rst | 4 ++++ xarray/core/dataset.py | 3 ++- xarray/tests/test_dataset.py | 10 ++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a5b7b36142e..8b145924f2d 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -56,6 +56,10 @@ Enhancements Bug fixes ~~~~~~~~~ +- ``xarray.DataArray.roll`` correctly handles multidimensional arrays. + (:issue:`2445`) + By `Keisuke Fujii `_. + - ``xarray.DataArray.std()`` now correctly accepts ``ddof`` keyword argument. (:issue:`2240`) By `Keisuke Fujii `_. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4ad5902ebad..5e787c1587b 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3595,7 +3595,8 @@ def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): variables = OrderedDict() for k, v in iteritems(self.variables): if k not in unrolled_vars: - variables[k] = v.roll(**shifts) + variables[k] = v.roll(**{k: s for k, s in shifts.items() + if k in v.dims}) else: variables[k] = v diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index c42b84c05fc..2c964b81b98 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -3945,6 +3945,16 @@ def test_roll_coords_none(self): expected = Dataset({'foo': ('x', [3, 1, 2])}, ex_coords, attrs) assert_identical(expected, actual) + def test_roll_multidim(self): + # regression test for 2445 + arr = xr.DataArray( + [[1, 2, 3],[4, 5, 6]], coords={'x': range(3), 'y': range(2)}, + dims=('y','x')) + actual = arr.roll(x=1, roll_coords=True) + expected = xr.DataArray([[3, 1, 2],[6, 4, 5]], + coords=[('y', [0, 1]), ('x', [2, 0, 1])]) + assert_identical(expected, actual) + def test_real_and_imag(self): attrs = {'foo': 'bar'} ds = Dataset({'x': ((), 1 + 2j, attrs)}, attrs=attrs) From f9c4169150286fa1aac020ab965380ed21fe1148 Mon Sep 17 
00:00:00 2001 From: Spencer Clark Date: Sun, 30 Sep 2018 09:16:48 -0400 Subject: [PATCH 38/51] Fix FutureWarning in CFTimeIndex.date_type (#2448) --- xarray/coding/cftimeindex.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index e236dca3693..faf1a044505 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -157,7 +157,7 @@ def f(self): def get_date_type(self): - if self.data: + if self._data.size: return type(self._data[0]) else: return None From 8fb57f7b9ff683225650a928b8d7d287d8954e79 Mon Sep 17 00:00:00 2001 From: Spencer Clark Date: Tue, 2 Oct 2018 10:44:29 -0400 Subject: [PATCH 39/51] Add CFTimeIndex.shift (#2431) * Add CFTimeIndex.shift Update what's new Add bug fix note * Add example to docstring * Use pycompat for basestring * Generate an API reference page for CFTimeIndex.shift --- doc/api-hidden.rst | 2 + doc/whats-new.rst | 5 +++ xarray/coding/cftimeindex.py | 51 ++++++++++++++++++++++++ xarray/tests/test_cftimeindex.py | 66 ++++++++++++++++++++++++++++++++ 4 files changed, 124 insertions(+) diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 1826cc86892..0e8143c72ea 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -151,3 +151,5 @@ plot.FacetGrid.set_titles plot.FacetGrid.set_ticks plot.FacetGrid.map + + CFTimeIndex.shift diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 8b145924f2d..3f8d40910cd 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -46,6 +46,9 @@ Enhancements - Added support for Python 3.7. (:issue:`2271`). By `Joe Hamman `_. +- Added :py:meth:`~xarray.CFTimeIndex.shift` for shifting the values of a + CFTimeIndex by a specified frequency. (:issue:`2244`). By `Spencer Clark + `_. 
- Added support for using ``cftime.datetime`` coordinates with :py:meth:`~xarray.DataArray.differentiate`, :py:meth:`~xarray.Dataset.differentiate`, @@ -56,6 +59,8 @@ Enhancements Bug fixes ~~~~~~~~~ +- Addition and subtraction operators used with a CFTimeIndex now preserve the + index's type. (:issue:`2244`). By `Spencer Clark `_. - ``xarray.DataArray.roll`` correctly handles multidimensional arrays. (:issue:`2445`) By `Keisuke Fujii `_. diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index faf1a044505..341ecfed262 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -315,6 +315,57 @@ def contains(self, key): """Needed for .loc based partial-string indexing""" return self.__contains__(key) + def shift(self, n, freq): + """Shift the CFTimeIndex a multiple of the given frequency. + + See the documentation for :py:func:`~xarray.cftime_range` for a + complete listing of valid frequency strings. + + Parameters + ---------- + n : int + Periods to shift by + freq : str or datetime.timedelta + A frequency string or datetime.timedelta object to shift by + + Returns + ------- + CFTimeIndex + + See also + -------- + pandas.DatetimeIndex.shift + + Examples + -------- + >>> index = xr.cftime_range('2000', periods=1, freq='M') + >>> index + CFTimeIndex([2000-01-31 00:00:00], dtype='object') + >>> index.shift(1, 'M') + CFTimeIndex([2000-02-29 00:00:00], dtype='object') + """ + from .cftime_offsets import to_offset + + if not isinstance(n, int): + raise TypeError("'n' must be an int, got {}.".format(n)) + if isinstance(freq, timedelta): + return self + n * freq + elif isinstance(freq, pycompat.basestring): + return self + n * to_offset(freq) + else: + raise TypeError( + "'freq' must be of type " + "str or datetime.timedelta, got {}.".format(freq)) + + def __add__(self, other): + return CFTimeIndex(np.array(self) + other) + + def __radd__(self, other): + return CFTimeIndex(other + np.array(self)) + + def __sub__(self, other): + 
return CFTimeIndex(np.array(self) - other) + def _parse_iso8601_without_reso(date_type, datetime_str): date, _ = _parse_iso8601_with_reso(date_type, datetime_str) diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index 62a29a15247..a558ab9a784 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -619,6 +619,72 @@ def test_empty_cftimeindex(): assert index.date_type is None +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_add(index): + date_type = index.date_type + expected_dates = [date_type(1, 1, 2), date_type(1, 2, 2), + date_type(2, 1, 2), date_type(2, 2, 2)] + expected = CFTimeIndex(expected_dates) + result = index + timedelta(days=1) + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_radd(index): + date_type = index.date_type + expected_dates = [date_type(1, 1, 2), date_type(1, 2, 2), + date_type(2, 1, 2), date_type(2, 2, 2)] + expected = CFTimeIndex(expected_dates) + result = timedelta(days=1) + index + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_sub(index): + date_type = index.date_type + expected_dates = [date_type(1, 1, 2), date_type(1, 2, 2), + date_type(2, 1, 2), date_type(2, 2, 2)] + expected = CFTimeIndex(expected_dates) + result = index + timedelta(days=2) + result = result - timedelta(days=1) + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_rsub(index): + with pytest.raises(TypeError): + timedelta(days=1) - index + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('freq', ['D', timedelta(days=1)]) +def test_cftimeindex_shift(index, freq): + 
date_type = index.date_type + expected_dates = [date_type(1, 1, 3), date_type(1, 2, 3), + date_type(2, 1, 3), date_type(2, 2, 3)] + expected = CFTimeIndex(expected_dates) + result = index.shift(2, freq) + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_shift_invalid_n(): + index = xr.cftime_range('2000', periods=3) + with pytest.raises(TypeError): + index.shift('a', 'D') + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_shift_invalid_freq(): + index = xr.cftime_range('2000', periods=3) + with pytest.raises(TypeError): + index.shift(1, 1) + + @requires_cftime def test_parse_array_of_cftime_strings(): from cftime import DatetimeNoLeap From 1e7a1d348d927bfa4fd4fba58a3f7600314746cf Mon Sep 17 00:00:00 2001 From: Denis Rykov Date: Tue, 2 Oct 2018 17:05:25 +0200 Subject: [PATCH 40/51] np.AxisError was added in numpy 1.13 (#2455) --- xarray/core/dask_array_compat.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py index 5e6b81a253d..2196dba7f86 100644 --- a/xarray/core/dask_array_compat.py +++ b/xarray/core/dask_array_compat.py @@ -44,7 +44,13 @@ def isin(element, test_elements, assume_unique=False, invert=False): import math from numbers import Integral, Real - AxisError = np.AxisError + try: + AxisError = np.AxisError + except AttributeError: + try: + np.array([0]).sum(axis=5) + except Exception as e: + AxisError = type(e) def validate_axis(axis, ndim): """ Validate an input to axis= keywords """ From 0f70a876759197388d32d6d9f0317f0fe63e0336 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 3 Oct 2018 00:45:59 +0900 Subject: [PATCH 41/51] plot.contour: Don't make cmap if colors is a single color. (#2453) By default, matplotlib draw dashed negative contours for a single color. 
We lost this feature by manually specifying cmap everytime. --- doc/whats-new.rst | 4 ++++ xarray/plot/plot.py | 5 +++++ xarray/tests/test_plot.py | 4 ++-- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 3f8d40910cd..8e0526f8b8b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -68,6 +68,10 @@ Bug fixes - ``xarray.DataArray.std()`` now correctly accepts ``ddof`` keyword argument. (:issue:`2240`) By `Keisuke Fujii `_. +- Restore matplotlib's default of plotting dashed negative contours when + a single color is passed to ``DataArray.contour()`` e.g. ``colors='k'``. + By `Deepak Cherian `_. + .. _whats-new.0.10.9: diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index a6add44682f..3f9f1090c70 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -737,6 +737,11 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, # pcolormesh kwargs['extend'] = cmap_params['extend'] kwargs['levels'] = cmap_params['levels'] + # if colors == a single color, matplotlib draws dashed negative + # contours. 
we lose this feature if we pass cmap and not colors + if isinstance(colors, basestring): + cmap_params['cmap'] = None + kwargs['colors'] = colors if 'pcolormesh' == plotfunc.__name__: kwargs['infer_intervals'] = infer_intervals diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 1423f7ae853..98265149122 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1140,9 +1140,9 @@ def test_colors(self): def _color_as_tuple(c): return tuple(c[:3]) + # with single color, we don't want rgb array artist = self.plotmethod(colors='k') - assert _color_as_tuple(artist.cmap.colors[0]) == \ - (0.0, 0.0, 0.0) + assert artist.cmap.colors[0] == 'k' artist = self.plotmethod(colors=['k', 'b']) assert _color_as_tuple(artist.cmap.colors[1]) == \ From 3cef8d730d5bbd699a393fa15266064ebb9849e2 Mon Sep 17 00:00:00 2001 From: Spencer Clark Date: Fri, 5 Oct 2018 04:02:17 -0400 Subject: [PATCH 42/51] Clean up _parse_array_of_cftime_strings (#2464) * Make _parse_array_of_cftime_strings more robust * lint --- xarray/coding/cftimeindex.py | 7 ++----- xarray/tests/test_cftimeindex.py | 8 +++++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index 341ecfed262..75a1fc9bd1a 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -389,8 +389,5 @@ def _parse_array_of_cftime_strings(strings, date_type): ------- np.array """ - if strings.ndim == 0: - return np.array(_parse_iso8601_without_reso(date_type, strings.item())) - else: - return np.array([_parse_iso8601_without_reso(date_type, s) - for s in strings]) + return np.array([_parse_iso8601_without_reso(date_type, s) + for s in strings.ravel()]).reshape(strings.shape) diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index a558ab9a784..33bf2cbce0d 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -689,9 +689,11 @@ def 
test_cftimeindex_shift_invalid_freq(): def test_parse_array_of_cftime_strings(): from cftime import DatetimeNoLeap - strings = np.array(['2000-01-01', '2000-01-02']) - expected = np.array([DatetimeNoLeap(2000, 1, 1), - DatetimeNoLeap(2000, 1, 2)]) + strings = np.array([['2000-01-01', '2000-01-02'], + ['2000-01-03', '2000-01-04']]) + expected = np.array( + [[DatetimeNoLeap(2000, 1, 1), DatetimeNoLeap(2000, 1, 2)], + [DatetimeNoLeap(2000, 1, 3), DatetimeNoLeap(2000, 1, 4)]]) result = _parse_array_of_cftime_strings(strings, DatetimeNoLeap) np.testing.assert_array_equal(result, expected) From 3f697fe013dc510cebb6a64d0a2c760d6320573a Mon Sep 17 00:00:00 2001 From: Jie Chen Date: Fri, 5 Oct 2018 08:35:56 -0700 Subject: [PATCH 43/51] Update whats-new.rst (#2466) --- doc/whats-new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4a6886159d1..662d60e29e7 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -27,7 +27,7 @@ What's New .. _whats-new.0.10.9: -v0.10.9 (21 September 2019) +v0.10.9 (21 September 2018) --------------------------- This minor release contains a number of backwards compatible enhancements. 
From 4a7a103204989af7e2b6bc97a4109d81beebd34c Mon Sep 17 00:00:00 2001 From: David Hoese Date: Sat, 6 Oct 2018 02:48:57 -0500 Subject: [PATCH 44/51] Add python_requires to setup.py (#2465) --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 88c27c95118..68798bdf219 100644 --- a/setup.py +++ b/setup.py @@ -69,5 +69,6 @@ install_requires=INSTALL_REQUIRES, tests_require=TESTS_REQUIRE, url=URL, + python_requires='>=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*', packages=find_packages(), package_data={'xarray': ['tests/data/*']}) From bb87a9441d22b390e069d0fde58f297a054fd98a Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sat, 6 Oct 2018 13:09:13 -0400 Subject: [PATCH 45/51] Replace the last of unittest with pytest (#2467) * cleaning * remove assertEqual * remove assertItems * more removing assertitems * remove assertequal * remove TestCase * straggler * pep8replies * pep8replies2 * small flups * pytest.warns requires Warning class * tuple list comparisons * disable test check * set / list * the last unittest survivor, and automated formatting changes where possible --- xarray/tests/__init__.py | 40 +--- xarray/tests/test_accessors.py | 9 +- xarray/tests/test_backends.py | 281 ++++++++++++++-------------- xarray/tests/test_combine.py | 10 +- xarray/tests/test_conventions.py | 23 +-- xarray/tests/test_dask.py | 12 +- xarray/tests/test_dataarray.py | 64 ++++--- xarray/tests/test_dataset.py | 81 ++++---- xarray/tests/test_duck_array_ops.py | 12 +- xarray/tests/test_extensions.py | 4 +- xarray/tests/test_formatting.py | 8 +- xarray/tests/test_indexing.py | 18 +- xarray/tests/test_merge.py | 10 +- xarray/tests/test_plot.py | 61 +++--- xarray/tests/test_tutorial.py | 8 +- xarray/tests/test_utils.py | 23 +-- xarray/tests/test_variable.py | 48 ++--- 17 files changed, 346 insertions(+), 366 deletions(-) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 
33a8da6bbfb..285c1f03a26 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -9,11 +9,10 @@ import numpy as np from numpy.testing import assert_array_equal # noqa: F401 -from xarray.core.duck_array_ops import allclose_or_equiv +from xarray.core.duck_array_ops import allclose_or_equiv # noqa import pytest from xarray.core import utils -from xarray.core.pycompat import PY3 from xarray.core.indexing import ExplicitlyIndexed from xarray.testing import (assert_equal, assert_identical, # noqa: F401 assert_allclose) @@ -25,10 +24,6 @@ # old location, for pandas < 0.20 from pandas.util.testing import assert_frame_equal # noqa: F401 -try: - import unittest2 as unittest -except ImportError: - import unittest try: from unittest import mock @@ -117,39 +112,6 @@ def _importorskip(modname, minversion=None): "internet connection") -class TestCase(unittest.TestCase): - """ - These functions are all deprecated. Instead, use functions in xr.testing - """ - if PY3: - # Python 3 assertCountEqual is roughly equivalent to Python 2 - # assertItemsEqual - def assertItemsEqual(self, first, second, msg=None): - __tracebackhide__ = True # noqa: F841 - return self.assertCountEqual(first, second, msg) - - @contextmanager - def assertWarns(self, message): - __tracebackhide__ = True # noqa: F841 - with warnings.catch_warnings(record=True) as w: - warnings.filterwarnings('always', message) - yield - assert len(w) > 0 - assert any(message in str(wi.message) for wi in w) - - def assertVariableNotEqual(self, v1, v2): - __tracebackhide__ = True # noqa: F841 - assert not v1.equals(v2) - - def assertEqual(self, a1, a2): - __tracebackhide__ = True # noqa: F841 - assert a1 == a2 or (a1 != a1 and a2 != a2) - - def assertAllClose(self, a1, a2, rtol=1e-05, atol=1e-8): - __tracebackhide__ = True # noqa: F841 - assert allclose_or_equiv(a1, a2, rtol=rtol, atol=atol) - - @contextmanager def raises_regex(error, pattern): __tracebackhide__ = True # noqa: F841 diff --git 
a/xarray/tests/test_accessors.py b/xarray/tests/test_accessors.py index e1b3a95b942..38038fc8f65 100644 --- a/xarray/tests/test_accessors.py +++ b/xarray/tests/test_accessors.py @@ -7,12 +7,13 @@ import xarray as xr from . import ( - TestCase, assert_array_equal, assert_equal, raises_regex, requires_dask, - has_cftime, has_dask, has_cftime_or_netCDF4) + assert_array_equal, assert_equal, has_cftime, has_cftime_or_netCDF4, + has_dask, raises_regex, requires_dask) -class TestDatetimeAccessor(TestCase): - def setUp(self): +class TestDatetimeAccessor(object): + @pytest.fixture(autouse=True) + def setup(self): nt = 100 data = np.random.rand(10, 10, nt) lons = np.linspace(0, 11, 10) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 8b469761ccd..a2e1cb4c0fa 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2,6 +2,7 @@ import contextlib import itertools +import math import os.path import pickle import shutil @@ -19,8 +20,8 @@ from xarray import ( DataArray, Dataset, backends, open_dataarray, open_dataset, open_mfdataset, save_mfdataset) -from xarray.backends.common import (robust_getitem, - PickleByReconstructionWrapper) +from xarray.backends.common import ( + PickleByReconstructionWrapper, robust_getitem) from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding from xarray.backends.pydap_ import PydapDataStore from xarray.core import indexing @@ -29,12 +30,11 @@ from xarray.tests import mock from . 
import ( - TestCase, assert_allclose, assert_array_equal, assert_equal, - assert_identical, has_dask, has_netCDF4, has_scipy, network, raises_regex, + assert_allclose, assert_array_equal, assert_equal, assert_identical, + has_dask, has_netCDF4, has_scipy, network, raises_regex, requires_cftime, requires_dask, requires_h5netcdf, requires_netCDF4, requires_pathlib, - requires_pydap, requires_pynio, requires_rasterio, requires_scipy, - requires_scipy_or_netCDF4, requires_zarr, requires_pseudonetcdf, - requires_cftime) + requires_pseudonetcdf, requires_pydap, requires_pynio, requires_rasterio, + requires_scipy, requires_scipy_or_netCDF4, requires_zarr) from .test_dataset import create_test_data try: @@ -106,7 +106,7 @@ def create_boolean_data(): return Dataset({'x': ('t', [True, False, False, True], attributes)}) -class TestCommon(TestCase): +class TestCommon(object): def test_robust_getitem(self): class UnreliableArrayFailure(Exception): @@ -126,11 +126,11 @@ def __getitem__(self, key): array = UnreliableArray([0]) with pytest.raises(UnreliableArrayFailure): array[0] - self.assertEqual(array[0], 0) + assert array[0] == 0 actual = robust_getitem(array, 0, catch=UnreliableArrayFailure, initial_delay=0) - self.assertEqual(actual, 0) + assert actual == 0 class NetCDF3Only(object): @@ -222,11 +222,11 @@ def assert_loads(vars=None): with self.roundtrip(expected) as actual: for k, v in actual.variables.items(): # IndexVariables are eagerly loaded into memory - self.assertEqual(v._in_memory, k in actual.dims) + assert v._in_memory == (k in actual.dims) yield actual for k, v in actual.variables.items(): if k in vars: - self.assertTrue(v._in_memory) + assert v._in_memory assert_identical(expected, actual) with pytest.raises(AssertionError): @@ -252,14 +252,14 @@ def test_dataset_compute(self): # Test Dataset.compute() for k, v in actual.variables.items(): # IndexVariables are eagerly cached - self.assertEqual(v._in_memory, k in actual.dims) + assert v._in_memory == (k in 
actual.dims) computed = actual.compute() for k, v in actual.variables.items(): - self.assertEqual(v._in_memory, k in actual.dims) + assert v._in_memory == (k in actual.dims) for v in computed.variables.values(): - self.assertTrue(v._in_memory) + assert v._in_memory assert_identical(expected, actual) assert_identical(expected, computed) @@ -343,12 +343,12 @@ def test_roundtrip_string_encoded_characters(self): expected['x'].encoding['dtype'] = 'S1' with self.roundtrip(expected) as actual: assert_identical(expected, actual) - self.assertEqual(actual['x'].encoding['_Encoding'], 'utf-8') + assert actual['x'].encoding['_Encoding'] == 'utf-8' expected['x'].encoding['_Encoding'] = 'ascii' with self.roundtrip(expected) as actual: assert_identical(expected, actual) - self.assertEqual(actual['x'].encoding['_Encoding'], 'ascii') + assert actual['x'].encoding['_Encoding'] == 'ascii' def test_roundtrip_numpy_datetime_data(self): times = pd.to_datetime(['2000-01-01', '2000-01-02', 'NaT']) @@ -434,10 +434,10 @@ def test_roundtrip_coordinates_with_space(self): def test_roundtrip_boolean_dtype(self): original = create_boolean_data() - self.assertEqual(original['x'].dtype, 'bool') + assert original['x'].dtype == 'bool' with self.roundtrip(original) as actual: assert_identical(original, actual) - self.assertEqual(actual['x'].dtype, 'bool') + assert actual['x'].dtype == 'bool' def test_orthogonal_indexing(self): in_memory = create_test_data() @@ -626,20 +626,20 @@ def test_unsigned_roundtrip_mask_and_scale(self): encoded = create_encoded_unsigned_masked_scaled_data() with self.roundtrip(decoded) as actual: for k in decoded.variables: - self.assertEqual(decoded.variables[k].dtype, - actual.variables[k].dtype) + assert (decoded.variables[k].dtype == + actual.variables[k].dtype) assert_allclose(decoded, actual, decode_bytes=False) with self.roundtrip(decoded, open_kwargs=dict(decode_cf=False)) as actual: for k in encoded.variables: - self.assertEqual(encoded.variables[k].dtype, - 
actual.variables[k].dtype) + assert (encoded.variables[k].dtype == + actual.variables[k].dtype) assert_allclose(encoded, actual, decode_bytes=False) with self.roundtrip(encoded, open_kwargs=dict(decode_cf=False)) as actual: for k in encoded.variables: - self.assertEqual(encoded.variables[k].dtype, - actual.variables[k].dtype) + assert (encoded.variables[k].dtype == + actual.variables[k].dtype) assert_allclose(encoded, actual, decode_bytes=False) # make sure roundtrip encoding didn't change the # original dataset. @@ -647,14 +647,14 @@ def test_unsigned_roundtrip_mask_and_scale(self): encoded, create_encoded_unsigned_masked_scaled_data()) with self.roundtrip(encoded) as actual: for k in decoded.variables: - self.assertEqual(decoded.variables[k].dtype, - actual.variables[k].dtype) + assert decoded.variables[k].dtype == \ + actual.variables[k].dtype assert_allclose(decoded, actual, decode_bytes=False) with self.roundtrip(encoded, open_kwargs=dict(decode_cf=False)) as actual: for k in encoded.variables: - self.assertEqual(encoded.variables[k].dtype, - actual.variables[k].dtype) + assert encoded.variables[k].dtype == \ + actual.variables[k].dtype assert_allclose(encoded, actual, decode_bytes=False) def test_roundtrip_mask_and_scale(self): @@ -692,12 +692,11 @@ def equals_latlon(obj): with create_tmp_file() as tmp_file: original.to_netcdf(tmp_file) with open_dataset(tmp_file, decode_coords=False) as ds: - self.assertTrue(equals_latlon(ds['temp'].attrs['coordinates'])) - self.assertTrue( - equals_latlon(ds['precip'].attrs['coordinates'])) - self.assertNotIn('coordinates', ds.attrs) - self.assertNotIn('coordinates', ds['lat'].attrs) - self.assertNotIn('coordinates', ds['lon'].attrs) + assert equals_latlon(ds['temp'].attrs['coordinates']) + assert equals_latlon(ds['precip'].attrs['coordinates']) + assert 'coordinates' not in ds.attrs + assert 'coordinates' not in ds['lat'].attrs + assert 'coordinates' not in ds['lon'].attrs modified = original.drop(['temp', 'precip']) with 
self.roundtrip(modified) as actual: @@ -705,9 +704,9 @@ def equals_latlon(obj): with create_tmp_file() as tmp_file: modified.to_netcdf(tmp_file) with open_dataset(tmp_file, decode_coords=False) as ds: - self.assertTrue(equals_latlon(ds.attrs['coordinates'])) - self.assertNotIn('coordinates', ds['lat'].attrs) - self.assertNotIn('coordinates', ds['lon'].attrs) + assert equals_latlon(ds.attrs['coordinates']) + assert 'coordinates' not in ds['lat'].attrs + assert 'coordinates' not in ds['lon'].attrs def test_roundtrip_endian(self): ds = Dataset({'x': np.arange(3, 10, dtype='>i2'), @@ -743,8 +742,8 @@ def test_encoding_kwarg(self): ds = Dataset({'x': ('y', np.arange(10.0))}) kwargs = dict(encoding={'x': {'dtype': 'f4'}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual.x.encoding['dtype'], 'f4') - self.assertEqual(ds.x.encoding, {}) + assert actual.x.encoding['dtype'] == 'f4' + assert ds.x.encoding == {} kwargs = dict(encoding={'x': {'foo': 'bar'}}) with raises_regex(ValueError, 'unexpected encoding'): @@ -766,7 +765,7 @@ def test_encoding_kwarg_dates(self): units = 'days since 1900-01-01' kwargs = dict(encoding={'t': {'units': units}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual.t.encoding['units'], units) + assert actual.t.encoding['units'] == units assert_identical(actual, ds) def test_encoding_kwarg_fixed_width_string(self): @@ -778,7 +777,7 @@ def test_encoding_kwarg_fixed_width_string(self): ds = Dataset({'x': strings}) kwargs = dict(encoding={'x': {'dtype': 'S1'}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual['x'].encoding['dtype'], 'S1') + assert actual['x'].encoding['dtype'] == 'S1' assert_identical(actual, ds) def test_default_fill_value(self): @@ -786,9 +785,8 @@ def test_default_fill_value(self): ds = Dataset({'x': ('y', np.arange(10.0))}) kwargs = dict(encoding={'x': {'dtype': 'f4'}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - 
self.assertEqual(actual.x.encoding['_FillValue'], - np.nan) - self.assertEqual(ds.x.encoding, {}) + assert math.isnan(actual.x.encoding['_FillValue']) + assert ds.x.encoding == {} # Test default encoding for int: ds = Dataset({'x': ('y', np.arange(10.0))}) @@ -797,14 +795,14 @@ def test_default_fill_value(self): warnings.filterwarnings( 'ignore', '.*floating point data as an integer') with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertTrue('_FillValue' not in actual.x.encoding) - self.assertEqual(ds.x.encoding, {}) + assert '_FillValue' not in actual.x.encoding + assert ds.x.encoding == {} # Test default encoding for implicit int: ds = Dataset({'x': ('y', np.arange(10, dtype='int16'))}) with self.roundtrip(ds) as actual: - self.assertTrue('_FillValue' not in actual.x.encoding) - self.assertEqual(ds.x.encoding, {}) + assert '_FillValue' not in actual.x.encoding + assert ds.x.encoding == {} def test_explicitly_omit_fill_value(self): ds = Dataset({'x': ('y', [np.pi, -np.pi])}) @@ -817,7 +815,7 @@ def test_explicitly_omit_fill_value_via_encoding_kwarg(self): kwargs = dict(encoding={'x': {'_FillValue': None}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: assert '_FillValue' not in actual.x.encoding - self.assertEqual(ds.y.encoding, {}) + assert ds.y.encoding == {} def test_explicitly_omit_fill_value_in_coord(self): ds = Dataset({'x': ('y', [np.pi, -np.pi])}, coords={'y': [0.0, 1.0]}) @@ -830,14 +828,14 @@ def test_explicitly_omit_fill_value_in_coord_via_encoding_kwarg(self): kwargs = dict(encoding={'y': {'_FillValue': None}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: assert '_FillValue' not in actual.y.encoding - self.assertEqual(ds.y.encoding, {}) + assert ds.y.encoding == {} def test_encoding_same_dtype(self): ds = Dataset({'x': ('y', np.arange(10.0, dtype='f4'))}) kwargs = dict(encoding={'x': {'dtype': 'f4'}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual.x.encoding['dtype'], 'f4') - 
self.assertEqual(ds.x.encoding, {}) + assert actual.x.encoding['dtype'] == 'f4' + assert ds.x.encoding == {} def test_append_write(self): # regression for GH1215 @@ -1015,7 +1013,7 @@ def test_default_to_char_arrays(self): data = Dataset({'x': np.array(['foo', 'zzzz'], dtype='S')}) with self.roundtrip(data) as actual: assert_identical(data, actual) - self.assertEqual(actual['x'].dtype, np.dtype('S4')) + assert actual['x'].dtype == np.dtype('S4') def test_open_encodings(self): # Create a netCDF file with explicit time units @@ -1040,15 +1038,15 @@ def test_open_encodings(self): actual_encoding = dict((k, v) for k, v in iteritems(actual['time'].encoding) if k in expected['time'].encoding) - self.assertDictEqual(actual_encoding, - expected['time'].encoding) + assert actual_encoding == \ + expected['time'].encoding def test_dump_encodings(self): # regression test for #709 ds = Dataset({'x': ('y', np.arange(10.0))}) kwargs = dict(encoding={'x': {'zlib': True}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertTrue(actual.x.encoding['zlib']) + assert actual.x.encoding['zlib'] def test_dump_and_open_encodings(self): # Create a netCDF file with explicit time units @@ -1066,8 +1064,7 @@ def test_dump_and_open_encodings(self): with create_tmp_file() as tmp_file2: xarray_dataset.to_netcdf(tmp_file2) with nc4.Dataset(tmp_file2, 'r') as ds: - self.assertEqual( - ds.variables['time'].getncattr('units'), units) + assert ds.variables['time'].getncattr('units') == units assert_array_equal( ds.variables['time'], np.arange(10) + 4) @@ -1080,7 +1077,7 @@ def test_compression_encoding(self): 'original_shape': data.var2.shape}) with self.roundtrip(data) as actual: for k, v in iteritems(data['var2'].encoding): - self.assertEqual(v, actual['var2'].encoding[k]) + assert v == actual['var2'].encoding[k] # regression test for #156 expected = data.isel(dim1=0) @@ -1095,14 +1092,14 @@ def test_encoding_kwarg_compression(self): with self.roundtrip(ds, save_kwargs=kwargs) as 
actual: assert_equal(actual, ds) - self.assertEqual(actual.x.encoding['dtype'], 'f4') - self.assertEqual(actual.x.encoding['zlib'], True) - self.assertEqual(actual.x.encoding['complevel'], 9) - self.assertEqual(actual.x.encoding['fletcher32'], True) - self.assertEqual(actual.x.encoding['chunksizes'], (5,)) - self.assertEqual(actual.x.encoding['shuffle'], True) + assert actual.x.encoding['dtype'] == 'f4' + assert actual.x.encoding['zlib'] + assert actual.x.encoding['complevel'] == 9 + assert actual.x.encoding['fletcher32'] + assert actual.x.encoding['chunksizes'] == (5,) + assert actual.x.encoding['shuffle'] - self.assertEqual(ds.x.encoding, {}) + assert ds.x.encoding == {} def test_encoding_chunksizes_unlimited(self): # regression test for GH1225 @@ -1183,7 +1180,7 @@ def test_read_variable_len_strings(self): @requires_netCDF4 -class NetCDF4DataTest(BaseNetCDF4Test, TestCase): +class NetCDF4DataTest(BaseNetCDF4Test): autoclose = False @contextlib.contextmanager @@ -1201,7 +1198,7 @@ def test_variable_order(self): ds.coords['c'] = 4 with self.roundtrip(ds) as actual: - self.assertEqual(list(ds.variables), list(actual.variables)) + assert list(ds.variables) == list(actual.variables) def test_unsorted_index_raises(self): # should be fixed in netcdf4 v1.2.1 @@ -1220,7 +1217,7 @@ def test_unsorted_index_raises(self): try: ds2.randovar.values except IndexError as err: - self.assertIn('first by calling .load', str(err)) + assert 'first by calling .load' in str(err) def test_88_character_filename_segmentation_fault(self): # should be fixed in netcdf4 v1.3.1 @@ -1335,17 +1332,17 @@ def test_auto_chunk(self): original, open_kwargs={'auto_chunk': False}) as actual: for k, v in actual.variables.items(): # only index variables should be in memory - self.assertEqual(v._in_memory, k in actual.dims) + assert v._in_memory == (k in actual.dims) # there should be no chunks - self.assertEqual(v.chunks, None) + assert v.chunks is None with self.roundtrip( original, 
open_kwargs={'auto_chunk': True}) as actual: for k, v in actual.variables.items(): # only index variables should be in memory - self.assertEqual(v._in_memory, k in actual.dims) + assert v._in_memory == (k in actual.dims) # chunk size should be the same as original - self.assertEqual(v.chunks, original[k].chunks) + assert v.chunks == original[k].chunks def test_write_uneven_dask_chunks(self): # regression for GH#2225 @@ -1365,7 +1362,7 @@ def test_chunk_encoding(self): data['var2'].encoding.update({'chunks': chunks}) with self.roundtrip(data) as actual: - self.assertEqual(chunks, actual['var2'].encoding['chunks']) + assert chunks == actual['var2'].encoding['chunks'] # expect an error with non-integer chunks data['var2'].encoding.update({'chunks': (5, 4.5)}) @@ -1382,7 +1379,7 @@ def test_chunk_encoding_with_dask(self): # zarr automatically gets chunk information from dask chunks ds_chunk4 = ds.chunk({'x': 4}) with self.roundtrip(ds_chunk4) as actual: - self.assertEqual((4,), actual['var1'].encoding['chunks']) + assert (4,) == actual['var1'].encoding['chunks'] # should fail if dask_chunks are irregular... ds_chunk_irreg = ds.chunk({'x': (5, 4, 3)}) @@ -1395,14 +1392,14 @@ def test_chunk_encoding_with_dask(self): # ... 
except if the last chunk is smaller than the first ds_chunk_irreg = ds.chunk({'x': (5, 5, 2)}) with self.roundtrip(ds_chunk_irreg) as actual: - self.assertEqual((5,), actual['var1'].encoding['chunks']) + assert (5,) == actual['var1'].encoding['chunks'] # - encoding specified - # specify compatible encodings for chunk_enc in 4, (4, ): ds_chunk4['var1'].encoding.update({'chunks': chunk_enc}) with self.roundtrip(ds_chunk4) as actual: - self.assertEqual((4,), actual['var1'].encoding['chunks']) + assert (4,) == actual['var1'].encoding['chunks'] # TODO: remove this failure once syncronized overlapping writes are # supported by xarray @@ -1532,14 +1529,14 @@ def test_encoding_chunksizes(self): @requires_zarr -class ZarrDictStoreTest(BaseZarrTest, TestCase): +class ZarrDictStoreTest(BaseZarrTest): @contextlib.contextmanager def create_zarr_target(self): yield {} @requires_zarr -class ZarrDirectoryStoreTest(BaseZarrTest, TestCase): +class ZarrDirectoryStoreTest(BaseZarrTest): @contextlib.contextmanager def create_zarr_target(self): with create_tmp_file(suffix='.zarr') as tmp: @@ -1562,7 +1559,7 @@ def test_append_overwrite_values(self): @requires_scipy -class ScipyInMemoryDataTest(ScipyWriteTest, TestCase): +class ScipyInMemoryDataTest(ScipyWriteTest): engine = 'scipy' @contextlib.contextmanager @@ -1588,7 +1585,7 @@ class ScipyInMemoryDataTestAutocloseTrue(ScipyInMemoryDataTest): @requires_scipy -class ScipyFileObjectTest(ScipyWriteTest, TestCase): +class ScipyFileObjectTest(ScipyWriteTest): engine = 'scipy' @contextlib.contextmanager @@ -1616,7 +1613,7 @@ def test_pickle_dataarray(self): @requires_scipy -class ScipyFilePathTest(ScipyWriteTest, TestCase): +class ScipyFilePathTest(ScipyWriteTest): engine = 'scipy' @contextlib.contextmanager @@ -1640,7 +1637,7 @@ def test_netcdf3_endianness(self): # regression test for GH416 expected = open_example_dataset('bears.nc', engine='scipy') for var in expected.variables.values(): - self.assertTrue(var.dtype.isnative) + assert 
var.dtype.isnative @requires_netCDF4 def test_nc4_scipy(self): @@ -1657,7 +1654,7 @@ class ScipyFilePathTestAutocloseTrue(ScipyFilePathTest): @requires_netCDF4 -class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, NetCDF3Only, TestCase): +class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, NetCDF3Only): engine = 'netcdf4' file_format = 'NETCDF3_CLASSIC' @@ -1682,7 +1679,7 @@ class NetCDF3ViaNetCDF4DataTestAutocloseTrue(NetCDF3ViaNetCDF4DataTest): @requires_netCDF4 class NetCDF4ClassicViaNetCDF4DataTest(CFEncodedDataTest, NetCDF3Only, - TestCase): + object): engine = 'netcdf4' file_format = 'NETCDF4_CLASSIC' @@ -1700,7 +1697,7 @@ class NetCDF4ClassicViaNetCDF4DataTestAutocloseTrue( @requires_scipy_or_netCDF4 -class GenericNetCDFDataTest(CFEncodedDataTest, NetCDF3Only, TestCase): +class GenericNetCDFDataTest(CFEncodedDataTest, NetCDF3Only): # verify that we can read and write netCDF3 files as long as we have scipy # or netCDF4-python installed file_format = 'netcdf3_64bit' @@ -1754,24 +1751,24 @@ def test_encoding_unlimited_dims(self): ds = Dataset({'x': ('y', np.arange(10.0))}) with self.roundtrip(ds, save_kwargs=dict(unlimited_dims=['y'])) as actual: - self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) # Regression test for https://github.com/pydata/xarray/issues/2134 with self.roundtrip(ds, save_kwargs=dict(unlimited_dims='y')) as actual: - self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) ds.encoding = {'unlimited_dims': ['y']} with self.roundtrip(ds) as actual: - self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) # Regression test for https://github.com/pydata/xarray/issues/2134 ds.encoding = {'unlimited_dims': 'y'} with self.roundtrip(ds) as actual: - 
self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) @@ -1781,7 +1778,7 @@ class GenericNetCDFDataTestAutocloseTrue(GenericNetCDFDataTest): @requires_h5netcdf @requires_netCDF4 -class H5NetCDFDataTest(BaseNetCDF4Test, TestCase): +class H5NetCDFDataTest(BaseNetCDF4Test): engine = 'h5netcdf' @contextlib.contextmanager @@ -1822,11 +1819,11 @@ def test_encoding_unlimited_dims(self): ds = Dataset({'x': ('y', np.arange(10.0))}) with self.roundtrip(ds, save_kwargs=dict(unlimited_dims=['y'])) as actual: - self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) ds.encoding = {'unlimited_dims': ['y']} with self.roundtrip(ds) as actual: - self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) def test_compression_encoding_h5py(self): @@ -1857,7 +1854,7 @@ def test_compression_encoding_h5py(self): compr_out.update(compr_common) with self.roundtrip(data) as actual: for k, v in compr_out.items(): - self.assertEqual(v, actual['var2'].encoding[k]) + assert v == actual['var2'].encoding[k] def test_compression_check_encoding_h5py(self): """When mismatched h5py and NetCDF4-Python encodings are expressed @@ -1898,14 +1895,14 @@ def test_dump_encodings_h5py(self): kwargs = {'encoding': {'x': { 'compression': 'gzip', 'compression_opts': 9}}} with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual.x.encoding['zlib'], True) - self.assertEqual(actual.x.encoding['complevel'], 9) + assert actual.x.encoding['zlib'] + assert actual.x.encoding['complevel'] == 9 kwargs = {'encoding': {'x': { 'compression': 'lzf', 'compression_opts': None}}} with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual.x.encoding['compression'], 'lzf') - self.assertEqual(actual.x.encoding['compression_opts'], 
None) + assert actual.x.encoding['compression'] == 'lzf' + assert actual.x.encoding['compression_opts'] is None # tests pending h5netcdf fix @@ -1985,7 +1982,7 @@ def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel, @requires_scipy_or_netCDF4 -class OpenMFDatasetWithDataVarsAndCoordsKwTest(TestCase): +class OpenMFDatasetWithDataVarsAndCoordsKwTest(object): coord_name = 'lon' var_name = 'v1' @@ -2056,9 +2053,9 @@ def test_common_coord_when_datavars_all(self): var_shape = ds[self.var_name].shape - self.assertEqual(var_shape, coord_shape) - self.assertNotEqual(coord_shape1, coord_shape) - self.assertNotEqual(coord_shape2, coord_shape) + assert var_shape == coord_shape + assert coord_shape1 != coord_shape + assert coord_shape2 != coord_shape def test_common_coord_when_datavars_minimal(self): opt = 'minimal' @@ -2073,9 +2070,9 @@ def test_common_coord_when_datavars_minimal(self): var_shape = ds[self.var_name].shape - self.assertNotEqual(var_shape, coord_shape) - self.assertEqual(coord_shape1, coord_shape) - self.assertEqual(coord_shape2, coord_shape) + assert var_shape != coord_shape + assert coord_shape1 == coord_shape + assert coord_shape2 == coord_shape def test_invalid_data_vars_value_should_fail(self): @@ -2093,7 +2090,7 @@ def test_invalid_data_vars_value_should_fail(self): @requires_dask @requires_scipy @requires_netCDF4 -class DaskTest(TestCase, DatasetIOTestCases): +class DaskTest(DatasetIOTestCases): @contextlib.contextmanager def create_store(self): yield Dataset() @@ -2133,10 +2130,10 @@ def test_roundtrip_cftime_datetime_data_enable_cftimeindex(self): with xr.set_options(enable_cftimeindex=True): with self.roundtrip(expected) as actual: abs_diff = abs(actual.t.values - expected_decoded_t) - self.assertTrue((abs_diff <= np.timedelta64(1, 's')).all()) + assert (abs_diff <= np.timedelta64(1, 's')).all() abs_diff = abs(actual.t0.values - expected_decoded_t0) - self.assertTrue((abs_diff <= np.timedelta64(1, 's')).all()) + assert (abs_diff 
<= np.timedelta64(1, 's')).all() def test_roundtrip_cftime_datetime_data_disable_cftimeindex(self): # Override method in DatasetIOTestCases - remove not applicable @@ -2153,10 +2150,10 @@ def test_roundtrip_cftime_datetime_data_disable_cftimeindex(self): with xr.set_options(enable_cftimeindex=False): with self.roundtrip(expected) as actual: abs_diff = abs(actual.t.values - expected_decoded_t) - self.assertTrue((abs_diff <= np.timedelta64(1, 's')).all()) + assert (abs_diff <= np.timedelta64(1, 's')).all() abs_diff = abs(actual.t0.values - expected_decoded_t0) - self.assertTrue((abs_diff <= np.timedelta64(1, 's')).all()) + assert (abs_diff <= np.timedelta64(1, 's')).all() def test_write_store(self): # Override method in DatasetIOTestCases - not applicable to dask @@ -2177,14 +2174,14 @@ def test_open_mfdataset(self): original.isel(x=slice(5, 10)).to_netcdf(tmp2) with open_mfdataset([tmp1, tmp2], autoclose=self.autoclose) as actual: - self.assertIsInstance(actual.foo.variable.data, da.Array) - self.assertEqual(actual.foo.variable.data.chunks, - ((5, 5),)) + assert isinstance(actual.foo.variable.data, da.Array) + assert actual.foo.variable.data.chunks == \ + ((5, 5),) assert_identical(original, actual) with open_mfdataset([tmp1, tmp2], chunks={'x': 3}, autoclose=self.autoclose) as actual: - self.assertEqual(actual.foo.variable.data.chunks, - ((3, 2, 3, 2),)) + assert actual.foo.variable.data.chunks == \ + ((3, 2, 3, 2),) with raises_regex(IOError, 'no files to open'): open_mfdataset('foo-bar-baz-*.nc', autoclose=self.autoclose) @@ -2218,7 +2215,7 @@ def test_attrs_mfdataset(self): with open_mfdataset([tmp1, tmp2]) as actual: # presumes that attributes inherited from # first dataset loaded - self.assertEqual(actual.test1, ds1.test1) + assert actual.test1 == ds1.test1 # attributes from ds2 are not retained, e.g., with raises_regex(AttributeError, 'no attribute'): @@ -2298,13 +2295,13 @@ def test_open_dataset(self): with create_tmp_file() as tmp: original.to_netcdf(tmp) 
with open_dataset(tmp, chunks={'x': 5}) as actual: - self.assertIsInstance(actual.foo.variable.data, da.Array) - self.assertEqual(actual.foo.variable.data.chunks, ((5, 5),)) + assert isinstance(actual.foo.variable.data, da.Array) + assert actual.foo.variable.data.chunks == ((5, 5),) assert_identical(original, actual) with open_dataset(tmp, chunks=5) as actual: assert_identical(original, actual) with open_dataset(tmp) as actual: - self.assertIsInstance(actual.foo.variable.data, np.ndarray) + assert isinstance(actual.foo.variable.data, np.ndarray) assert_identical(original, actual) def test_open_single_dataset(self): @@ -2344,9 +2341,9 @@ def test_deterministic_names(self): repeat_names = dict((k, v.data.name) for k, v in ds.data_vars.items()) for var_name, dask_name in original_names.items(): - self.assertIn(var_name, dask_name) - self.assertEqual(dask_name[:13], 'open_dataset-') - self.assertEqual(original_names, repeat_names) + assert var_name in dask_name + assert dask_name[:13] == 'open_dataset-' + assert original_names == repeat_names def test_dataarray_compute(self): # Test DataArray.compute() on dask backend. 
@@ -2354,8 +2351,8 @@ def test_dataarray_compute(self): # however dask is the only tested backend which supports DataArrays actual = DataArray([1, 2]).chunk() computed = actual.compute() - self.assertFalse(actual._in_memory) - self.assertTrue(computed._in_memory) + assert not actual._in_memory + assert computed._in_memory assert_allclose(actual, computed, decode_bytes=False) def test_to_netcdf_compute_false_roundtrip(self): @@ -2395,7 +2392,7 @@ class DaskTestAutocloseTrue(DaskTest): @requires_scipy_or_netCDF4 @requires_pydap -class PydapTest(TestCase): +class PydapTest(object): def convert_to_pydap_dataset(self, original): from pydap.model import GridType, BaseType, DatasetType ds = DatasetType('bears', **original.attrs) @@ -2427,8 +2424,8 @@ def test_cmp_local_file(self): assert_equal(actual, expected) # global attributes should be global attributes on the dataset - self.assertNotIn('NC_GLOBAL', actual.attrs) - self.assertIn('history', actual.attrs) + assert 'NC_GLOBAL' not in actual.attrs + assert 'history' in actual.attrs # we don't check attributes exactly with assertDatasetIdentical() # because the test DAP server seems to insert some extra @@ -2436,8 +2433,7 @@ def test_cmp_local_file(self): assert actual.attrs.keys() == expected.attrs.keys() with self.create_datasets() as (actual, expected): - assert_equal( - actual.isel(l=2), expected.isel(l=2)) # noqa: E741 + assert_equal(actual.isel(l=2), expected.isel(l=2)) # noqa with self.create_datasets() as (actual, expected): assert_equal(actual.isel(i=0, j=-1), @@ -2497,7 +2493,7 @@ def test_session(self): @requires_scipy @requires_pynio -class PyNioTest(ScipyWriteTest, TestCase): +class PyNioTest(ScipyWriteTest): def test_write_store(self): # pynio is read-only for now pass @@ -2529,7 +2525,7 @@ class PyNioTestAutocloseTrue(PyNioTest): @requires_pseudonetcdf @pytest.mark.filterwarnings('ignore:IOAPI_ISPH is assumed to be 6370000') -class PseudoNetCDFFormatTest(TestCase): +class PseudoNetCDFFormatTest(object): 
autoclose = True def open(self, path, **kwargs): @@ -2792,7 +2788,7 @@ def create_tmp_geotiff(nx=4, ny=3, nz=3, @requires_rasterio -class TestRasterio(TestCase): +class TestRasterio(object): @requires_scipy_or_netCDF4 def test_serialization(self): @@ -2837,7 +2833,8 @@ def test_non_rectilinear(self): assert len(rioda.attrs['transform']) == 6 # See if a warning is raised if we force it - with self.assertWarns("transformation isn't rectilinear"): + with pytest.warns(Warning, + match="transformation isn't rectilinear"): with xr.open_rasterio(tmp_file, parse_coordinates=True) as rioda: assert 'x' not in rioda.coords @@ -3024,7 +3021,7 @@ def test_chunks(self): with xr.open_rasterio(tmp_file, chunks=(1, 2, 2)) as actual: import dask.array as da - self.assertIsInstance(actual.data, da.Array) + assert isinstance(actual.data, da.Array) assert 'open_rasterio' in actual.data.name # do some arithmetic @@ -3105,7 +3102,7 @@ def test_no_mftime(self): with mock.patch('os.path.getmtime', side_effect=OSError): with xr.open_rasterio(tmp_file, chunks=(1, 2, 2)) as actual: import dask.array as da - self.assertIsInstance(actual.data, da.Array) + assert isinstance(actual.data, da.Array) assert_allclose(actual, expected) @network @@ -3118,10 +3115,10 @@ def test_http_url(self): # make sure chunking works with xr.open_rasterio(url, chunks=(1, 256, 256)) as actual: import dask.array as da - self.assertIsInstance(actual.data, da.Array) + assert isinstance(actual.data, da.Array) -class TestEncodingInvalid(TestCase): +class TestEncodingInvalid(object): def test_extract_nc4_variable_encoding(self): var = xr.Variable(('x',), [1, 2, 3], {}, {'foo': 'bar'}) @@ -3130,12 +3127,12 @@ def test_extract_nc4_variable_encoding(self): var = xr.Variable(('x',), [1, 2, 3], {}, {'chunking': (2, 1)}) encoding = _extract_nc4_variable_encoding(var) - self.assertEqual({}, encoding) + assert {} == encoding # regression test var = xr.Variable(('x',), [1, 2, 3], {}, {'shuffle': True}) encoding = 
_extract_nc4_variable_encoding(var, raise_on_invalid=True) - self.assertEqual({'shuffle': True}, encoding) + assert {'shuffle': True} == encoding def test_extract_h5nc_encoding(self): # not supported with h5netcdf (yet) @@ -3150,7 +3147,7 @@ class MiscObject: @requires_netCDF4 -class TestValidateAttrs(TestCase): +class TestValidateAttrs(object): def test_validating_attrs(self): def new_dataset(): return Dataset({'data': ('y', np.arange(10.0))}, @@ -3250,7 +3247,7 @@ def new_dataset_and_coord_attrs(): @requires_scipy_or_netCDF4 -class TestDataArrayToNetCDF(TestCase): +class TestDataArrayToNetCDF(object): def test_dataarray_to_netcdf_no_name(self): original_da = DataArray(np.arange(12).reshape((3, 4))) diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index 482a280b355..2004b1e660f 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -10,12 +10,12 @@ from xarray.core.pycompat import OrderedDict, iteritems from . import ( - InaccessibleArray, TestCase, assert_array_equal, assert_equal, - assert_identical, raises_regex, requires_dask) + InaccessibleArray, assert_array_equal, assert_equal, assert_identical, + raises_regex, requires_dask) from .test_dataset import create_test_data -class TestConcatDataset(TestCase): +class TestConcatDataset(object): def test_concat(self): # TODO: simplify and split this test case @@ -235,7 +235,7 @@ def test_concat_multiindex(self): assert isinstance(actual.x.to_index(), pd.MultiIndex) -class TestConcatDataArray(TestCase): +class TestConcatDataArray(object): def test_concat(self): ds = Dataset({'foo': (['x', 'y'], np.random.random((2, 3))), 'bar': (['x', 'y'], np.random.random((2, 3)))}, @@ -295,7 +295,7 @@ def test_concat_lazy(self): assert combined.dims == ('z', 'x', 'y') -class TestAutoCombine(TestCase): +class TestAutoCombine(object): @requires_dask # only for toolz def test_auto_combine(self): diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index 
5ed482ed2bd..a067d01a308 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -8,20 +8,20 @@ import pandas as pd import pytest -from xarray import (Dataset, Variable, SerializationWarning, coding, - conventions, open_dataset) +from xarray import ( + Dataset, SerializationWarning, Variable, coding, conventions, open_dataset) from xarray.backends.common import WritableCFDataStore from xarray.backends.memory import InMemoryDataStore from xarray.conventions import decode_cf from xarray.testing import assert_identical from . import ( - TestCase, assert_array_equal, raises_regex, requires_netCDF4, - requires_cftime_or_netCDF4, unittest, requires_dask) + assert_array_equal, raises_regex, requires_cftime_or_netCDF4, + requires_dask, requires_netCDF4) from .test_backends import CFEncodedDataTest -class TestBoolTypeArray(TestCase): +class TestBoolTypeArray(object): def test_booltype_array(self): x = np.array([1, 0, 1, 1, 0], dtype='i1') bx = conventions.BoolTypeArray(x) @@ -30,7 +30,7 @@ def test_booltype_array(self): dtype=np.bool)) -class TestNativeEndiannessArray(TestCase): +class TestNativeEndiannessArray(object): def test(self): x = np.arange(5, dtype='>i8') expected = np.arange(5, dtype='int64') @@ -69,7 +69,7 @@ def test_decode_cf_with_conflicting_fill_missing_value(): @requires_cftime_or_netCDF4 -class TestEncodeCFVariable(TestCase): +class TestEncodeCFVariable(object): def test_incompatible_attributes(self): invalid_vars = [ Variable(['t'], pd.date_range('2000-01-01', periods=3), @@ -134,7 +134,7 @@ def test_string_object_warning(self): @requires_cftime_or_netCDF4 -class TestDecodeCF(TestCase): +class TestDecodeCF(object): def test_dataset(self): original = Dataset({ 't': ('t', [0, 1, 2], {'units': 'days since 2000-01-01'}), @@ -255,7 +255,7 @@ def encode_variable(self, var): @requires_netCDF4 -class TestCFEncodedDataStore(CFEncodedDataTest, TestCase): +class TestCFEncodedDataStore(CFEncodedDataTest): @contextlib.contextmanager def 
create_store(self): yield CFEncodedInMemoryStore() @@ -267,9 +267,10 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}, data.dump_to_store(store, **save_kwargs) yield open_dataset(store, **open_kwargs) + @pytest.mark.skip('cannot roundtrip coordinates yet for ' + 'CFEncodedInMemoryStore') def test_roundtrip_coordinates(self): - raise unittest.SkipTest('cannot roundtrip coordinates yet for ' - 'CFEncodedInMemoryStore') + pass def test_invalid_dataarray_names_raise(self): # only relevant for on-disk file formats diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 43fa35473ce..e56f751bef9 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1,8 +1,8 @@ from __future__ import absolute_import, division, print_function import pickle -from textwrap import dedent from distutils.version import LooseVersion +from textwrap import dedent import numpy as np import pandas as pd @@ -15,15 +15,15 @@ from xarray.tests import mock from . import ( - TestCase, assert_allclose, assert_array_equal, assert_equal, - assert_frame_equal, assert_identical, raises_regex) + assert_allclose, assert_array_equal, assert_equal, assert_frame_equal, + assert_identical, raises_regex) dask = pytest.importorskip('dask') da = pytest.importorskip('dask.array') dd = pytest.importorskip('dask.dataframe') -class DaskTestCase(TestCase): +class DaskTestCase(object): def assertLazyAnd(self, expected, actual, test): with (dask.config.set(get=dask.get) @@ -57,6 +57,7 @@ def assertLazyAndIdentical(self, expected, actual): def assertLazyAndAllClose(self, expected, actual): self.assertLazyAnd(expected, actual, assert_allclose) + @pytest.fixture(autouse=True) def setUp(self): self.values = np.random.RandomState(0).randn(4, 6) self.data = da.from_array(self.values, chunks=(2, 2)) @@ -249,6 +250,7 @@ def assertLazyAndAllClose(self, expected, actual): def assertLazyAndEqual(self, expected, actual): self.assertLazyAnd(expected, actual, assert_equal) + 
@pytest.fixture(autouse=True) def setUp(self): self.values = np.random.randn(4, 6) self.data = da.from_array(self.values, chunks=(2, 2)) @@ -581,7 +583,7 @@ def test_from_dask_variable(self): self.assertLazyAndIdentical(self.lazy_array, a) -class TestToDaskDataFrame(TestCase): +class TestToDaskDataFrame(object): def test_to_dask_dataframe(self): # Test conversion of Datasets to dask DataFrames diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index f8b288f4ab0..d15a0bb6081 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1,10 +1,9 @@ from __future__ import absolute_import, division, print_function import pickle +import warnings from copy import deepcopy -from distutils.version import LooseVersion from textwrap import dedent -import warnings import numpy as np import pandas as pd @@ -13,19 +12,20 @@ import xarray as xr from xarray import ( DataArray, Dataset, IndexVariable, Variable, align, broadcast, set_options) -from xarray.convert import from_cdms2 from xarray.coding.times import CFDatetimeCoder, _import_cftime -from xarray.core.common import full_like, ALL_DIMS +from xarray.convert import from_cdms2 +from xarray.core.common import ALL_DIMS, full_like from xarray.core.pycompat import OrderedDict, iteritems from xarray.tests import ( - ReturnItem, TestCase, assert_allclose, assert_array_equal, assert_equal, + ReturnItem, assert_allclose, assert_array_equal, assert_equal, assert_identical, raises_regex, requires_bottleneck, requires_cftime, requires_dask, requires_iris, requires_np113, requires_scipy, - source_ndarray, unittest) + source_ndarray) -class TestDataArray(TestCase): - def setUp(self): +class TestDataArray(object): + @pytest.fixture(autouse=True) + def setup(self): self.attrs = {'attr1': 'value1', 'attr2': 2929} self.x = np.random.random((10, 20)) self.v = Variable(['x', 'y'], self.x) @@ -441,7 +441,7 @@ def test_getitem(self): assert_identical(self.ds['x'], x) assert_identical(self.ds['y'], 
y) - I = ReturnItem() # noqa: E741 # allow ambiguous name + I = ReturnItem() # noqa for i in [I[:], I[...], I[x.values], I[x.variable], I[x], I[x, y], I[x.values > -1], I[x.variable > -1], I[x > -1], I[x > -1, y > -1]]: @@ -1002,7 +1002,7 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False, assert da.dims[0] == renamed_dim da = da.rename({renamed_dim: 'x'}) assert_identical(da.variable, expected_da.variable) - self.assertVariableNotEqual(da['x'], expected_da['x']) + assert not da['x'].equals(expected_da['x']) test_sel(('a', 1, -1), 0) test_sel(('b', 2, -2), -1) @@ -2026,17 +2026,19 @@ def test_groupby_warning(self): with pytest.warns(FutureWarning): grouped.sum() - @pytest.mark.skipif(LooseVersion(xr.__version__) < LooseVersion('0.12'), - reason="not to forget the behavior change") + # Currently disabled due to https://github.com/pydata/xarray/issues/2468 + # @pytest.mark.skipif(LooseVersion(xr.__version__) < LooseVersion('0.12'), + # reason="not to forget the behavior change") + @pytest.mark.skip def test_groupby_sum_default(self): array = self.make_groupby_example_array() grouped = array.groupby('abc') expected_sum_all = Dataset( {'foo': Variable(['x', 'abc'], - np.array([self.x[:, :9].sum(axis=-1), - self.x[:, 10:].sum(axis=-1), - self.x[:, 9:10].sum(axis=-1)]).T), + np.array([self.x[:, :9].sum(axis=-1), + self.x[:, 10:].sum(axis=-1), + self.x[:, 9:10].sum(axis=-1)]).T), 'abc': Variable(['abc'], np.array(['a', 'b', 'c']))})['foo'] assert_allclose(expected_sum_all, grouped.sum()) @@ -2050,7 +2052,7 @@ def test_groupby_count(self): expected = DataArray([1, 1, 2], coords=[('cat', ['a', 'b', 'c'])]) assert_identical(actual, expected) - @unittest.skip('needs to be fixed for shortcut=False, keep_attrs=False') + @pytest.mark.skip('needs to be fixed for shortcut=False, keep_attrs=False') def test_groupby_reduce_attrs(self): array = self.make_groupby_example_array() array.attrs['foo'] = 'bar' @@ -2826,7 +2828,7 @@ def test_to_and_from_series(self): def 
test_series_categorical_index(self): # regression test for GH700 if not hasattr(pd, 'CategoricalIndex'): - raise unittest.SkipTest('requires pandas with CategoricalIndex') + pytest.skip('requires pandas with CategoricalIndex') s = pd.Series(np.arange(5), index=pd.CategoricalIndex(list('aabbc'))) arr = DataArray(s) @@ -2968,7 +2970,7 @@ def test_to_and_from_cdms2_classic(self): actual = original.to_cdms2() assert_array_equal(actual.asma(), original) assert actual.id == original.name - self.assertItemsEqual(actual.getAxisIds(), original.dims) + assert tuple(actual.getAxisIds()) == original.dims for axis, coord in zip(actual.getAxisList(), expected_coords): assert axis.id == coord.name assert_array_equal(axis, coord.values) @@ -2982,8 +2984,8 @@ def test_to_and_from_cdms2_classic(self): assert_identical(original, roundtripped) back = from_cdms2(actual) - self.assertItemsEqual(original.dims, back.dims) - self.assertItemsEqual(original.coords.keys(), back.coords.keys()) + assert original.dims == back.dims + assert original.coords.keys() == back.coords.keys() for coord_name in original.coords.keys(): assert_array_equal(original.coords[coord_name], back.coords[coord_name]) @@ -3004,15 +3006,15 @@ def test_to_and_from_cdms2_sgrid(self): coords=OrderedDict(x=x, y=y, lon=lon, lat=lat), name='sst') actual = original.to_cdms2() - self.assertItemsEqual(actual.getAxisIds(), original.dims) + assert tuple(actual.getAxisIds()) == original.dims assert_array_equal(original.coords['lon'], actual.getLongitude().asma()) assert_array_equal(original.coords['lat'], actual.getLatitude().asma()) back = from_cdms2(actual) - self.assertItemsEqual(original.dims, back.dims) - self.assertItemsEqual(original.coords.keys(), back.coords.keys()) + assert original.dims == back.dims + assert set(original.coords.keys()) == set(back.coords.keys()) assert_array_equal(original.coords['lat'], back.coords['lat']) assert_array_equal(original.coords['lon'], back.coords['lon']) @@ -3026,15 +3028,15 @@ def 
test_to_and_from_cdms2_ugrid(self): original = DataArray(np.arange(5), dims=['cell'], coords={'lon': lon, 'lat': lat, 'cell': cell}) actual = original.to_cdms2() - self.assertItemsEqual(actual.getAxisIds(), original.dims) + assert tuple(actual.getAxisIds()) == original.dims assert_array_equal(original.coords['lon'], actual.getLongitude().getValue()) assert_array_equal(original.coords['lat'], actual.getLatitude().getValue()) back = from_cdms2(actual) - self.assertItemsEqual(original.dims, back.dims) - self.assertItemsEqual(original.coords.keys(), back.coords.keys()) + assert set(original.dims) == set(back.dims) + assert set(original.coords.keys()) == set(back.coords.keys()) assert_array_equal(original.coords['lat'], back.coords['lat']) assert_array_equal(original.coords['lon'], back.coords['lon']) @@ -3127,17 +3129,17 @@ def test_coordinate_diff(self): actual = lon.diff('lon') assert_equal(expected, actual) - def test_shift(self): + @pytest.mark.parametrize('offset', [-5, -2, -1, 0, 1, 2, 5]) + def test_shift(self, offset): arr = DataArray([1, 2, 3], dims='x') actual = arr.shift(x=1) expected = DataArray([np.nan, 1, 2], dims='x') assert_identical(expected, actual) arr = DataArray([1, 2, 3], [('x', ['a', 'b', 'c'])]) - for offset in [-5, -2, -1, 0, 1, 2, 5]: - expected = DataArray(arr.to_pandas().shift(offset)) - actual = arr.shift(x=offset) - assert_identical(expected, actual) + expected = DataArray(arr.to_pandas().shift(offset)) + actual = arr.shift(x=offset) + assert_identical(expected, actual) def test_roll_coords(self): arr = DataArray([1, 2, 3], coords={'x': range(3)}, dims='x') diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 2c964b81b98..9bee965392b 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1,11 +1,11 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, division, print_function +import sys +import warnings from copy import copy, deepcopy from io import StringIO from textwrap 
import dedent -import warnings -import sys import numpy as np import pandas as pd @@ -13,17 +13,17 @@ import xarray as xr from xarray import ( - DataArray, Dataset, IndexVariable, MergeError, Variable, align, backends, - broadcast, open_dataset, set_options, ALL_DIMS) + ALL_DIMS, DataArray, Dataset, IndexVariable, MergeError, Variable, align, + backends, broadcast, open_dataset, set_options) from xarray.core import indexing, npcompat, utils from xarray.core.common import full_like from xarray.core.pycompat import ( OrderedDict, integer_types, iteritems, unicode_type) from . import ( - InaccessibleArray, TestCase, UnexpectedDataAccess, assert_allclose, - assert_array_equal, assert_equal, assert_identical, has_cftime, - has_dask, raises_regex, requires_bottleneck, requires_dask, requires_scipy, + InaccessibleArray, UnexpectedDataAccess, assert_allclose, + assert_array_equal, assert_equal, assert_identical, has_cftime, has_dask, + raises_regex, requires_bottleneck, requires_dask, requires_scipy, source_ndarray) try: @@ -86,7 +86,7 @@ def lazy_inaccessible(k, v): k, v in iteritems(self._variables)) -class TestDataset(TestCase): +class TestDataset(object): def test_repr(self): data = create_test_data(seed=123) data.attrs['foo'] = 'bar' @@ -399,7 +399,7 @@ def test_constructor_with_coords(self): ds = Dataset({}, {'a': ('x', [1])}) assert not ds.data_vars - self.assertItemsEqual(ds.coords.keys(), ['a']) + assert list(ds.coords.keys()) == ['a'] mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2]], names=('level_1', 'level_2')) @@ -421,9 +421,9 @@ def test_properties(self): assert type(ds.dims.mapping.mapping) is dict # noqa with pytest.warns(FutureWarning): - self.assertItemsEqual(ds, list(ds.variables)) + assert list(ds) == list(ds.variables) with pytest.warns(FutureWarning): - self.assertItemsEqual(ds.keys(), list(ds.variables)) + assert list(ds.keys()) == list(ds.variables) assert 'aasldfjalskdfj' not in ds.variables assert 'dim1' in repr(ds.variables) with 
pytest.warns(FutureWarning): @@ -431,18 +431,18 @@ def test_properties(self): with pytest.warns(FutureWarning): assert bool(ds) - self.assertItemsEqual(ds.data_vars, ['var1', 'var2', 'var3']) - self.assertItemsEqual(ds.data_vars.keys(), ['var1', 'var2', 'var3']) + assert list(ds.data_vars) == ['var1', 'var2', 'var3'] + assert list(ds.data_vars.keys()) == ['var1', 'var2', 'var3'] assert 'var1' in ds.data_vars assert 'dim1' not in ds.data_vars assert 'numbers' not in ds.data_vars assert len(ds.data_vars) == 3 - self.assertItemsEqual(ds.indexes, ['dim2', 'dim3', 'time']) + assert set(ds.indexes) == {'dim2', 'dim3', 'time'} assert len(ds.indexes) == 3 assert 'dim2' in repr(ds.indexes) - self.assertItemsEqual(ds.coords, ['time', 'dim2', 'dim3', 'numbers']) + assert list(ds.coords) == ['time', 'dim2', 'dim3', 'numbers'] assert 'dim2' in ds.coords assert 'numbers' in ds.coords assert 'var1' not in ds.coords @@ -535,7 +535,7 @@ def test_coords_properties(self): assert 4 == len(data.coords) - self.assertItemsEqual(['x', 'y', 'a', 'b'], list(data.coords)) + assert ['x', 'y', 'a', 'b'] == list(data.coords) assert_identical(data.coords['x'].variable, data['x'].variable) assert_identical(data.coords['y'].variable, data['y'].variable) @@ -831,7 +831,7 @@ def test_isel(self): ret = data.isel(**slicers) # Verify that only the specified dimension was altered - self.assertItemsEqual(data.dims, ret.dims) + assert list(data.dims) == list(ret.dims) for d in data.dims: if d in slicers: assert ret.dims[d] == \ @@ -857,21 +857,21 @@ def test_isel(self): ret = data.isel(dim1=0) assert {'time': 20, 'dim2': 9, 'dim3': 10} == ret.dims - self.assertItemsEqual(data.data_vars, ret.data_vars) - self.assertItemsEqual(data.coords, ret.coords) - self.assertItemsEqual(data.indexes, ret.indexes) + assert set(data.data_vars) == set(ret.data_vars) + assert set(data.coords) == set(ret.coords) + assert set(data.indexes) == set(ret.indexes) ret = data.isel(time=slice(2), dim1=0, dim2=slice(5)) assert 
{'time': 2, 'dim2': 5, 'dim3': 10} == ret.dims - self.assertItemsEqual(data.data_vars, ret.data_vars) - self.assertItemsEqual(data.coords, ret.coords) - self.assertItemsEqual(data.indexes, ret.indexes) + assert set(data.data_vars) == set(ret.data_vars) + assert set(data.coords) == set(ret.coords) + assert set(data.indexes) == set(ret.indexes) ret = data.isel(time=0, dim1=0, dim2=slice(5)) - self.assertItemsEqual({'dim2': 5, 'dim3': 10}, ret.dims) - self.assertItemsEqual(data.data_vars, ret.data_vars) - self.assertItemsEqual(data.coords, ret.coords) - self.assertItemsEqual(data.indexes, list(ret.indexes) + ['time']) + assert {'dim2': 5, 'dim3': 10} == ret.dims + assert set(data.data_vars) == set(ret.data_vars) + assert set(data.coords) == set(ret.coords) + assert set(data.indexes) == set(list(ret.indexes) + ['time']) def test_isel_fancy(self): # isel with fancy indexing. @@ -1482,7 +1482,7 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False, ds = ds.rename({renamed_dim: 'x'}) assert_identical(ds['var'].variable, expected_ds['var'].variable) - self.assertVariableNotEqual(ds['x'], expected_ds['x']) + assert not ds['x'].equals(expected_ds['x']) test_sel(('a', 1, -1), 0) test_sel(('b', 2, -2), -1) @@ -2546,12 +2546,11 @@ def test_setitem_multiindex_level(self): def test_delitem(self): data = create_test_data() all_items = set(data.variables) - self.assertItemsEqual(data.variables, all_items) + assert set(data.variables) == all_items del data['var1'] - self.assertItemsEqual(data.variables, all_items - set(['var1'])) + assert set(data.variables) == all_items - set(['var1']) del data['numbers'] - self.assertItemsEqual(data.variables, - all_items - set(['var1', 'numbers'])) + assert set(data.variables) == all_items - set(['var1', 'numbers']) assert 'numbers' not in data.coords def test_squeeze(self): @@ -3425,8 +3424,8 @@ def test_reduce(self): (['dim2', 'time'], ['dim1', 'dim3']), (('dim2', 'time'), ['dim1', 'dim3']), ((), ['dim1', 'dim2', 'dim3', 'time'])]: - 
actual = data.min(dim=reduct).dims - self.assertItemsEqual(actual, expected) + actual = list(data.min(dim=reduct).dims) + assert actual == expected assert_equal(data.mean(dim=[]), data) @@ -3480,7 +3479,7 @@ def test_reduce_cumsum_test_dims(self): ('time', ['dim1', 'dim2', 'dim3']) ]: actual = getattr(data, cumfunc)(dim=reduct).dims - self.assertItemsEqual(actual, expected) + assert list(actual) == expected def test_reduce_non_numeric(self): data1 = create_test_data(seed=44) @@ -3618,14 +3617,14 @@ def test_rank(self): ds = create_test_data(seed=1234) # only ds.var3 depends on dim3 z = ds.rank('dim3') - self.assertItemsEqual(['var3'], list(z.data_vars)) + assert ['var3'] == list(z.data_vars) # same as dataarray version x = z.var3 y = ds.var3.rank('dim3') assert_equal(x, y) # coordinates stick - self.assertItemsEqual(list(z.coords), list(ds.coords)) - self.assertItemsEqual(list(x.coords), list(y.coords)) + assert list(z.coords) == list(ds.coords) + assert list(x.coords) == list(y.coords) # invalid dim with raises_regex(ValueError, 'does not contain'): x.rank('invalid_dim') @@ -3948,10 +3947,10 @@ def test_roll_coords_none(self): def test_roll_multidim(self): # regression test for 2445 arr = xr.DataArray( - [[1, 2, 3],[4, 5, 6]], coords={'x': range(3), 'y': range(2)}, - dims=('y','x')) + [[1, 2, 3], [4, 5, 6]], coords={'x': range(3), 'y': range(2)}, + dims=('y', 'x')) actual = arr.roll(x=1, roll_coords=True) - expected = xr.DataArray([[3, 1, 2],[6, 4, 5]], + expected = xr.DataArray([[3, 1, 2], [6, 4, 5]], coords=[('y', [0, 1]), ('x', [2, 0, 1])]) assert_identical(expected, actual) diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index aab5d305a82..5ea5b3d2a42 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -1,16 +1,16 @@ from __future__ import absolute_import, division, print_function +import warnings from distutils.version import LooseVersion +from textwrap import dedent import numpy 
as np import pandas as pd import pytest -from textwrap import dedent from numpy import array, nan -import warnings from xarray import DataArray, Dataset, concat -from xarray.core import duck_array_ops, dtypes +from xarray.core import dtypes, duck_array_ops from xarray.core.duck_array_ops import ( array_notnull_equiv, concatenate, count, first, gradient, last, mean, rolling_window, stack, where) @@ -18,12 +18,12 @@ from xarray.testing import assert_allclose, assert_equal from . import ( - TestCase, assert_array_equal, has_dask, has_np113, raises_regex, - requires_dask) + assert_array_equal, has_dask, has_np113, raises_regex, requires_dask) -class TestOps(TestCase): +class TestOps(object): + @pytest.fixture(autouse=True) def setUp(self): self.x = array([[[nan, nan, 2., nan], [nan, 5., 6., nan], diff --git a/xarray/tests/test_extensions.py b/xarray/tests/test_extensions.py index 24b710ae223..ffefa78aa34 100644 --- a/xarray/tests/test_extensions.py +++ b/xarray/tests/test_extensions.py @@ -4,7 +4,7 @@ import xarray as xr -from . import TestCase, raises_regex +from . import raises_regex try: import cPickle as pickle @@ -21,7 +21,7 @@ def __init__(self, xarray_obj): self.obj = xarray_obj -class TestAccessor(TestCase): +class TestAccessor(object): def test_register(self): @xr.register_dataset_accessor('demo') diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 8a1003f1ced..024c669bed9 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -7,10 +7,10 @@ from xarray.core import formatting from xarray.core.pycompat import PY3 -from . import TestCase, raises_regex +from . 
import raises_regex -class TestFormatting(TestCase): +class TestFormatting(object): def test_get_indexer_at_least_n_items(self): cases = [ @@ -45,7 +45,7 @@ def test_first_n_items(self): for n in [3, 10, 13, 100, 200]: actual = formatting.first_n_items(array, n) expected = array.flat[:n] - self.assertItemsEqual(expected, actual) + assert (expected == actual).all() with raises_regex(ValueError, 'at least one item'): formatting.first_n_items(array, 0) @@ -55,7 +55,7 @@ def test_last_n_items(self): for n in [3, 10, 13, 100, 200]: actual = formatting.last_n_items(array, n) expected = array.flat[-n:] - self.assertItemsEqual(expected, actual) + assert (expected == actual).all() with raises_regex(ValueError, 'at least one item'): formatting.first_n_items(array, 0) diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 0d1045d35c0..701eefcb462 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -10,13 +10,12 @@ from xarray.core import indexing, nputils from xarray.core.pycompat import native_int_types -from . import ( - IndexerMaker, ReturnItem, TestCase, assert_array_equal, raises_regex) +from . 
import IndexerMaker, ReturnItem, assert_array_equal, raises_regex B = IndexerMaker(indexing.BasicIndexer) -class TestIndexers(TestCase): +class TestIndexers(object): def set_to_zero(self, x, i): x = x.copy() x[i] = 0 @@ -25,7 +24,7 @@ def set_to_zero(self, x, i): def test_expanded_indexer(self): x = np.random.randn(10, 11, 12, 13, 14) y = np.arange(5) - I = ReturnItem() # noqa: E741 # allow ambiguous name + I = ReturnItem() # noqa for i in [I[:], I[...], I[0, :, 10], I[..., 10], I[:5, ..., 0], I[..., 0, :], I[y], I[y, y], I[..., y, y], I[..., 0, 1, 2, 3, 4]]: @@ -133,7 +132,7 @@ def test_indexer(data, x, expected_pos, expected_idx=None): pd.MultiIndex.from_product([[1, 2], [-1, -2]])) -class TestLazyArray(TestCase): +class TestLazyArray(object): def test_slice_slice(self): I = ReturnItem() # noqa: E741 # allow ambiguous name for size in [100, 99]: @@ -248,7 +247,7 @@ def check_indexing(v_eager, v_lazy, indexers): check_indexing(v_eager, v_lazy, indexers) -class TestCopyOnWriteArray(TestCase): +class TestCopyOnWriteArray(object): def test_setitem(self): original = np.arange(10) wrapped = indexing.CopyOnWriteArray(original) @@ -272,7 +271,7 @@ def test_index_scalar(self): assert np.array(x[B[0]][B[()]]) == 'foo' -class TestMemoryCachedArray(TestCase): +class TestMemoryCachedArray(object): def test_wrapper(self): original = indexing.LazilyOuterIndexedArray(np.arange(10)) wrapped = indexing.MemoryCachedArray(original) @@ -385,8 +384,9 @@ def test_vectorized_indexer(): np.arange(5, dtype=np.int64))) -class Test_vectorized_indexer(TestCase): - def setUp(self): +class Test_vectorized_indexer(object): + @pytest.fixture(autouse=True) + def setup(self): self.data = indexing.NumpyIndexingAdapter(np.random.randn(10, 12, 13)) self.indexers = [np.array([[0, 3, 2], ]), np.array([[0, 3, 3], [4, 6, 7]]), diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index 4d89be8ce55..300c490cff6 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ 
-6,11 +6,11 @@ import xarray as xr from xarray.core import merge -from . import TestCase, raises_regex +from . import raises_regex from .test_dataset import create_test_data -class TestMergeInternals(TestCase): +class TestMergeInternals(object): def test_broadcast_dimension_size(self): actual = merge.broadcast_dimension_size( [xr.Variable('x', [1]), xr.Variable('y', [2, 1])]) @@ -25,7 +25,7 @@ def test_broadcast_dimension_size(self): [xr.Variable(('x', 'y'), [[1, 2]]), xr.Variable('y', [2])]) -class TestMergeFunction(TestCase): +class TestMergeFunction(object): def test_merge_arrays(self): data = create_test_data() actual = xr.merge([data.var1, data.var2]) @@ -130,7 +130,7 @@ def test_merge_no_conflicts_broadcast(self): assert expected.identical(actual) -class TestMergeMethod(TestCase): +class TestMergeMethod(object): def test_merge(self): data = create_test_data() @@ -195,7 +195,7 @@ def test_merge_compat(self): with pytest.raises(xr.MergeError): ds1.merge(ds2, compat='identical') - with raises_regex(ValueError, 'compat=\S+ invalid'): + with raises_regex(ValueError, 'compat=.* invalid'): ds1.merge(ds2, compat='foobar') def test_merge_auto_align(self): diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 98265149122..01303202c93 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -5,9 +5,9 @@ import numpy as np import pandas as pd -import xarray as xr import pytest +import xarray as xr import xarray.plot as xplt from xarray import DataArray from xarray.coding.times import _import_cftime @@ -17,9 +17,8 @@ import_seaborn, label_from_attrs) from . 
import ( - TestCase, assert_array_equal, assert_equal, raises_regex, - requires_matplotlib, requires_matplotlib2, requires_seaborn, - requires_cftime) + assert_array_equal, assert_equal, raises_regex, requires_cftime, + requires_matplotlib, requires_matplotlib2, requires_seaborn) # import mpl and change the backend before other mpl imports try: @@ -65,8 +64,10 @@ def easy_array(shape, start=0, stop=1): @requires_matplotlib -class PlotTestCase(TestCase): - def tearDown(self): +class PlotTestCase(object): + @pytest.fixture(autouse=True) + def setup(self): + yield # Remove all matplotlib figures plt.close('all') @@ -88,7 +89,8 @@ def contourf_called(self, plotmethod): class TestPlot(PlotTestCase): - def setUp(self): + @pytest.fixture(autouse=True) + def setup_array(self): self.darray = DataArray(easy_array((2, 3, 4))) def test_label_from_attrs(self): @@ -160,8 +162,8 @@ def test_2d_line_accepts_legend_kw(self): self.darray[:, :, 0].plot.line(x='dim_0', add_legend=True) assert plt.gca().get_legend() # check whether legend title is set - assert plt.gca().get_legend().get_title().get_text() \ - == 'dim_1' + assert (plt.gca().get_legend().get_title().get_text() + == 'dim_1') def test_2d_line_accepts_x_kw(self): self.darray[:, :, 0].plot.line(x='dim_0') @@ -172,12 +174,12 @@ def test_2d_line_accepts_x_kw(self): def test_2d_line_accepts_hue_kw(self): self.darray[:, :, 0].plot.line(hue='dim_0') - assert plt.gca().get_legend().get_title().get_text() \ - == 'dim_0' + assert (plt.gca().get_legend().get_title().get_text() + == 'dim_0') plt.cla() self.darray[:, :, 0].plot.line(hue='dim_1') - assert plt.gca().get_legend().get_title().get_text() \ - == 'dim_1' + assert (plt.gca().get_legend().get_title().get_text() + == 'dim_1') def test_2d_before_squeeze(self): a = DataArray(easy_array((1, 5))) @@ -345,6 +347,7 @@ def test_convenient_facetgrid_4d(self): class TestPlot1D(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): d = [0, 1.1, 0, 2] self.darray = DataArray( @@ 
-357,7 +360,7 @@ def test_xlabel_is_index_name(self): def test_no_label_name_on_x_axis(self): self.darray.plot(y='period') - self.assertEqual('', plt.gca().get_xlabel()) + assert '' == plt.gca().get_xlabel() def test_no_label_name_on_y_axis(self): self.darray.plot() @@ -417,6 +420,7 @@ def test_slice_in_title(self): class TestPlotHistogram(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): self.darray = DataArray(easy_array((2, 3, 4))) @@ -452,7 +456,8 @@ def test_plot_nans(self): @requires_matplotlib -class TestDetermineCmapParams(TestCase): +class TestDetermineCmapParams(object): + @pytest.fixture(autouse=True) def setUp(self): self.data = np.linspace(0, 1, num=100) @@ -625,7 +630,8 @@ def test_divergentcontrol(self): @requires_matplotlib -class TestDiscreteColorMap(TestCase): +class TestDiscreteColorMap(object): + @pytest.fixture(autouse=True) def setUp(self): x = np.arange(start=0, stop=10, step=2) y = np.arange(start=9, stop=-7, step=-3) @@ -706,7 +712,7 @@ def test_discrete_colormap_list_levels_and_vmin_or_vmax(self): assert primitive.norm.vmin == min(levels) -class Common2dMixin: +class Common2dMixin(object): """ Common tests for 2d plotting go here. @@ -714,6 +720,7 @@ class Common2dMixin: Should have the same name as the method. 
""" + @pytest.fixture(autouse=True) def setUp(self): da = DataArray(easy_array((10, 15), start=-1), dims=['y', 'x'], @@ -1145,18 +1152,18 @@ def _color_as_tuple(c): assert artist.cmap.colors[0] == 'k' artist = self.plotmethod(colors=['k', 'b']) - assert _color_as_tuple(artist.cmap.colors[1]) == \ - (0.0, 0.0, 1.0) + assert (_color_as_tuple(artist.cmap.colors[1]) == + (0.0, 0.0, 1.0)) artist = self.darray.plot.contour( levels=[-0.5, 0., 0.5, 1.], colors=['k', 'r', 'w', 'b']) - assert _color_as_tuple(artist.cmap.colors[1]) == \ - (1.0, 0.0, 0.0) - assert _color_as_tuple(artist.cmap.colors[2]) == \ - (1.0, 1.0, 1.0) + assert (_color_as_tuple(artist.cmap.colors[1]) == + (1.0, 0.0, 0.0)) + assert (_color_as_tuple(artist.cmap.colors[2]) == + (1.0, 1.0, 1.0)) # the last color is now under "over" - assert _color_as_tuple(artist.cmap._rgba_over) == \ - (0.0, 0.0, 1.0) + assert (_color_as_tuple(artist.cmap._rgba_over) == + (0.0, 0.0, 1.0)) def test_cmap_and_color_both(self): with pytest.raises(ValueError): @@ -1352,6 +1359,7 @@ def test_origin_overrides_xyincrease(self): class TestFacetGrid(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): d = easy_array((10, 15, 3)) self.darray = DataArray( @@ -1581,6 +1589,7 @@ def test_facetgrid_polar(self): @pytest.mark.filterwarnings('ignore:tight_layout cannot') class TestFacetGrid4d(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): a = easy_array((10, 15, 3, 2)) darray = DataArray(a, dims=['y', 'x', 'col', 'row']) @@ -1609,6 +1618,7 @@ def test_default_labels(self): @pytest.mark.filterwarnings('ignore:tight_layout cannot') class TestFacetedLinePlots(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): self.darray = DataArray(np.random.randn(10, 6, 3, 4), dims=['hue', 'x', 'col', 'row'], @@ -1689,6 +1699,7 @@ def test_wrong_num_of_dimensions(self): class TestDatetimePlot(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): ''' Create a DataArray with a time-axis that contains datetime 
objects. diff --git a/xarray/tests/test_tutorial.py b/xarray/tests/test_tutorial.py index d550a85e8ce..083ec5ee72f 100644 --- a/xarray/tests/test_tutorial.py +++ b/xarray/tests/test_tutorial.py @@ -2,15 +2,17 @@ import os +import pytest + from xarray import DataArray, tutorial from xarray.core.pycompat import suppress -from . import TestCase, assert_identical, network +from . import assert_identical, network @network -class TestLoadDataset(TestCase): - +class TestLoadDataset(object): + @pytest.fixture(autouse=True) def setUp(self): self.testfile = 'tiny' self.testfilepath = os.path.expanduser(os.sep.join( diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index 0c0e0f3f744..34f401dd243 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -5,8 +5,8 @@ import numpy as np import pandas as pd import pytest -import xarray as xr +import xarray as xr from xarray.coding.cftimeindex import CFTimeIndex from xarray.core import duck_array_ops, utils from xarray.core.options import set_options @@ -15,12 +15,12 @@ from xarray.testing import assert_identical from . 
import ( - TestCase, assert_array_equal, has_cftime, has_cftime_or_netCDF4, - requires_dask, requires_cftime) + assert_array_equal, has_cftime, has_cftime_or_netCDF4, requires_cftime, + requires_dask) from .test_coding_times import _all_cftime_date_types -class TestAlias(TestCase): +class TestAlias(object): def test(self): def new_method(): pass @@ -98,7 +98,7 @@ def test_multiindex_from_product_levels_non_unique(): np.testing.assert_array_equal(result.levels[1], [1, 2]) -class TestArrayEquiv(TestCase): +class TestArrayEquiv(object): def test_0d(self): # verify our work around for pd.isnull not working for 0-dimensional # object arrays @@ -108,8 +108,9 @@ def test_0d(self): assert not duck_array_ops.array_equiv(0, np.array(1, dtype=object)) -class TestDictionaries(TestCase): - def setUp(self): +class TestDictionaries(object): + @pytest.fixture(autouse=True) + def setup(self): self.x = {'a': 'A', 'b': 'B'} self.y = {'c': 'C', 'b': 'B'} self.z = {'a': 'Z'} @@ -176,7 +177,7 @@ def test_frozen(self): def test_sorted_keys_dict(self): x = {'a': 1, 'b': 2, 'c': 3} y = utils.SortedKeysDict(x) - self.assertItemsEqual(y, ['a', 'b', 'c']) + assert list(y) == ['a', 'b', 'c'] assert repr(utils.SortedKeysDict()) == \ "SortedKeysDict({})" @@ -191,7 +192,7 @@ def test_chain_map(self): m['x'] = 100 assert m['x'] == 100 assert m.maps[0]['x'] == 100 - self.assertItemsEqual(['x', 'y', 'z'], m) + assert set(m) == {'x', 'y', 'z'} def test_repr_object(): @@ -199,7 +200,7 @@ def test_repr_object(): assert repr(obj) == 'foo' -class Test_is_uniform_and_sorted(TestCase): +class Test_is_uniform_and_sorted(object): def test_sorted_uniform(self): assert utils.is_uniform_spaced(np.arange(5)) @@ -220,7 +221,7 @@ def test_relative_tolerance(self): assert utils.is_uniform_spaced([0, 0.97, 2], rtol=0.1) -class Test_hashable(TestCase): +class Test_hashable(object): def test_hashable(self): for v in [False, 1, (2, ), (3, 4), 'four']: diff --git a/xarray/tests/test_variable.py 
b/xarray/tests/test_variable.py index 1263ac1df9e..52289a15d72 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1,11 +1,11 @@ from __future__ import absolute_import, division, print_function +import warnings from copy import copy, deepcopy from datetime import datetime, timedelta from distutils.version import LooseVersion from textwrap import dedent -import warnings import numpy as np import pandas as pd @@ -25,11 +25,11 @@ from xarray.tests import requires_bottleneck from . import ( - TestCase, assert_allclose, assert_array_equal, assert_equal, - assert_identical, raises_regex, requires_dask, source_ndarray) + assert_allclose, assert_array_equal, assert_equal, assert_identical, + raises_regex, requires_dask, source_ndarray) -class VariableSubclassTestCases(object): +class VariableSubclassobjects(object): def test_properties(self): data = 0.5 * np.arange(10) v = self.cls(['time'], data, {'foo': 'bar'}) @@ -479,20 +479,20 @@ def test_concat_mixed_dtypes(self): assert_identical(expected, actual) assert actual.dtype == object - def test_copy(self): + @pytest.mark.parametrize('deep', [True, False]) + def test_copy(self, deep): v = self.cls('x', 0.5 * np.arange(10), {'foo': 'bar'}) - for deep in [True, False]: - w = v.copy(deep=deep) - assert type(v) is type(w) - assert_identical(v, w) - assert v.dtype == w.dtype - if self.cls is Variable: - if deep: - assert source_ndarray(v.values) is not \ - source_ndarray(w.values) - else: - assert source_ndarray(v.values) is \ - source_ndarray(w.values) + w = v.copy(deep=deep) + assert type(v) is type(w) + assert_identical(v, w) + assert v.dtype == w.dtype + if self.cls is Variable: + if deep: + assert (source_ndarray(v.values) is not + source_ndarray(w.values)) + else: + assert (source_ndarray(v.values) is + source_ndarray(w.values)) assert_identical(v, copy(v)) def test_copy_index(self): @@ -814,10 +814,11 @@ def test_rolling_window(self): v_loaded[0] = 1.0 -class TestVariable(TestCase, 
VariableSubclassTestCases): +class TestVariable(VariableSubclassobjects): cls = staticmethod(Variable) - def setUp(self): + @pytest.fixture(autouse=True) + def setup(self): self.d = np.random.random((10, 3)).astype(np.float64) def test_data_and_values(self): @@ -1651,7 +1652,7 @@ def assert_assigned_2d(array, key_x, key_y, values): @requires_dask -class TestVariableWithDask(TestCase, VariableSubclassTestCases): +class TestVariableWithDask(VariableSubclassobjects): cls = staticmethod(lambda *args: Variable(*args).chunk()) @pytest.mark.xfail @@ -1691,7 +1692,7 @@ def test_getitem_with_mask_nd_indexer(self): self.cls(('x', 'y'), [[0, -1], [-1, 2]])) -class TestIndexVariable(TestCase, VariableSubclassTestCases): +class TestIndexVariable(VariableSubclassobjects): cls = staticmethod(IndexVariable) def test_init(self): @@ -1804,7 +1805,7 @@ def test_rolling_window(self): super(TestIndexVariable, self).test_rolling_window() -class TestAsCompatibleData(TestCase): +class TestAsCompatibleData(object): def test_unchanged_types(self): types = (np.asarray, PandasIndexAdapter, LazilyOuterIndexedArray) for t in types: @@ -1945,9 +1946,10 @@ def test_raise_no_warning_for_nan_in_binary_ops(): assert len(record) == 0 -class TestBackendIndexing(TestCase): +class TestBackendIndexing(object): """ Make sure all the array wrappers can be indexed. 
""" + @pytest.fixture(autouse=True) def setUp(self): self.d = np.random.random((10, 3)).astype(np.float64) From 515324062cf6f182d20c1aad210e8627b0b4013f Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sun, 7 Oct 2018 18:31:16 -0400 Subject: [PATCH 46/51] tests shoudn't need to pass for a PR (#2471) --- .github/PULL_REQUEST_TEMPLATE.md | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 5e9aa06f507..d1c79953a9b 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,4 +1,3 @@ - [ ] Closes #xxxx (remove if there is no corresponding issue, which should only be the case for minor changes) - [ ] Tests added (for all bug fixes or enhancements) - - [ ] Tests passed (for all non-documentation changes) - [ ] Fully documented, including `whats-new.rst` for all changes and `api.rst` for new API (remove if this change should not be visible to users, e.g., if it is an internal clean-up, or if this is part of a larger project that will be documented later) From 3d65f02de7c0328029dd6c580f42ebeb7381579f Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sun, 7 Oct 2018 18:39:14 -0400 Subject: [PATCH 47/51] isort (#2469) --- asv_bench/benchmarks/dataset_io.py | 2 +- asv_bench/benchmarks/unstacking.py | 1 + properties/test_encode_decode.py | 4 ++-- setup.py | 4 +--- versioneer.py | 10 ++++++---- xarray/backends/netCDF4_.py | 3 +-- xarray/backends/pseudonetcdf_.py | 11 ++++------- xarray/backends/rasterio_.py | 2 +- xarray/coding/cftime_offsets.py | 5 ++--- xarray/coding/cftimeindex.py | 1 + xarray/coding/strings.py | 4 ++-- xarray/coding/times.py | 2 +- xarray/conventions.py | 4 ++-- xarray/core/accessors.py | 2 +- xarray/core/combine.py | 2 +- xarray/core/common.py | 5 ++--- xarray/core/computation.py | 8 ++++---- xarray/core/dask_array_compat.py | 2 +- xarray/core/dask_array_ops.py 
| 4 ++-- xarray/core/dataset.py | 8 ++++---- xarray/core/formatting.py | 3 +-- xarray/core/missing.py | 7 +++---- xarray/core/nanops.py | 7 +++---- xarray/core/npcompat.py | 1 + xarray/core/resample.py | 2 +- xarray/plot/facetgrid.py | 1 + xarray/plot/utils.py | 1 - xarray/tests/test_cftime_offsets.py | 12 +++++------- xarray/tests/test_cftimeindex.py | 11 +++++------ xarray/tests/test_coding_strings.py | 8 ++++---- xarray/tests/test_coding_times.py | 10 +++++----- xarray/tests/test_computation.py | 2 +- xarray/tests/test_groupby.py | 3 ++- xarray/tests/test_interp.py | 8 ++++---- xarray/tests/test_ufuncs.py | 6 +++--- 35 files changed, 79 insertions(+), 87 deletions(-) diff --git a/asv_bench/benchmarks/dataset_io.py b/asv_bench/benchmarks/dataset_io.py index 54ed9ac9fa2..0b918e58eab 100644 --- a/asv_bench/benchmarks/dataset_io.py +++ b/asv_bench/benchmarks/dataset_io.py @@ -5,7 +5,7 @@ import xarray as xr -from . import randn, randint, requires_dask +from . import randint, randn, requires_dask try: import dask diff --git a/asv_bench/benchmarks/unstacking.py b/asv_bench/benchmarks/unstacking.py index aa304d4eb40..54436b422e9 100644 --- a/asv_bench/benchmarks/unstacking.py +++ b/asv_bench/benchmarks/unstacking.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division, print_function import numpy as np + import xarray as xr from . 
import requires_dask diff --git a/properties/test_encode_decode.py b/properties/test_encode_decode.py index 7b3e75fbf0c..13f63f259cf 100644 --- a/properties/test_encode_decode.py +++ b/properties/test_encode_decode.py @@ -6,9 +6,9 @@ """ from __future__ import absolute_import, division, print_function -from hypothesis import given, settings -import hypothesis.strategies as st import hypothesis.extra.numpy as npst +import hypothesis.strategies as st +from hypothesis import given, settings import xarray as xr diff --git a/setup.py b/setup.py index 68798bdf219..a7519bac6da 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,8 @@ #!/usr/bin/env python import sys -from setuptools import find_packages, setup - import versioneer - +from setuptools import find_packages, setup DISTNAME = 'xarray' LICENSE = 'Apache' diff --git a/versioneer.py b/versioneer.py index 64fea1c8927..dffd66b69a6 100644 --- a/versioneer.py +++ b/versioneer.py @@ -277,10 +277,7 @@ """ from __future__ import print_function -try: - import configparser -except ImportError: - import ConfigParser as configparser + import errno import json import os @@ -288,6 +285,11 @@ import subprocess import sys +try: + import configparser +except ImportError: + import ConfigParser as configparser + class VersioneerConfig: """Container for Versioneer configuration parameters.""" diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 5c6d82fd126..aa19633020b 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -10,8 +10,7 @@ from .. 
import Variable, coding from ..coding.variables import pop_to from ..core import indexing -from ..core.pycompat import ( - PY3, OrderedDict, basestring, iteritems, suppress) +from ..core.pycompat import PY3, OrderedDict, basestring, iteritems, suppress from ..core.utils import FrozenOrderedDict, close_on_error, is_remote_uri from .common import ( HDF5_LOCK, BackendArray, DataStorePickleMixin, WritableCFDataStore, diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index d946c6fa927..3d846916740 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -1,17 +1,14 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function import functools import numpy as np from .. import Variable -from ..core.pycompat import OrderedDict -from ..core.utils import (FrozenOrderedDict, Frozen) from ..core import indexing - -from .common import AbstractDataStore, DataStorePickleMixin, BackendArray +from ..core.pycompat import OrderedDict +from ..core.utils import Frozen, FrozenOrderedDict +from .common import AbstractDataStore, BackendArray, DataStorePickleMixin class PncArrayWrapper(BackendArray): diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 5221cf0e913..9cd5a889abc 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -1,7 +1,7 @@ import os +import warnings from collections import OrderedDict from distutils.version import LooseVersion -import warnings import numpy as np diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py index 3fbb44f4ed3..83e8c7a7e4b 100644 --- a/xarray/coding/cftime_offsets.py +++ b/xarray/coding/cftime_offsets.py @@ -41,15 +41,14 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import re - from datetime import timedelta from functools import partial import numpy as np -from .cftimeindex import _parse_iso8601_with_reso, CFTimeIndex -from .times import format_cftime_datetime from ..core.pycompat import basestring +from .cftimeindex import CFTimeIndex, _parse_iso8601_with_reso +from .times import format_cftime_datetime def get_date_type(calendar): diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index 75a1fc9bd1a..dea896c199a 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -40,6 +40,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from __future__ import absolute_import + import re from datetime import timedelta diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py index 87b17d9175e..3502fd773d7 100644 --- a/xarray/coding/strings.py +++ b/xarray/coding/strings.py @@ -9,8 +9,8 @@ from ..core.pycompat import bytes_type, dask_array_type, unicode_type from ..core.variable import Variable from .variables import ( - VariableCoder, lazy_elemwise_func, pop_to, - safe_setitem, unpack_for_decoding, unpack_for_encoding) + VariableCoder, lazy_elemwise_func, pop_to, safe_setitem, + unpack_for_decoding, unpack_for_encoding) def create_vlen_dtype(element_type): diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 6edbedce54c..dff7e75bdcf 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -9,8 +9,8 @@ import numpy as np import pandas as pd -from ..core.common import contains_cftime_datetimes from ..core import indexing +from ..core.common import contains_cftime_datetimes from ..core.formatting import first_n_items, format_timestamp, last_item from ..core.options import OPTIONS from ..core.pycompat import PY3 diff --git a/xarray/conventions.py b/xarray/conventions.py index 67dcb8d6d4e..f60ee6b2c15 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -6,11 +6,11 @@ import numpy as np import pandas as pd -from .coding import 
times, strings, variables +from .coding import strings, times, variables from .coding.variables import SerializationWarning from .core import duck_array_ops, indexing from .core.pycompat import ( - OrderedDict, basestring, bytes_type, iteritems, dask_array_type, + OrderedDict, basestring, bytes_type, dask_array_type, iteritems, unicode_type) from .core.variable import IndexVariable, Variable, as_variable diff --git a/xarray/core/accessors.py b/xarray/core/accessors.py index 81af0532d93..72791ed73ec 100644 --- a/xarray/core/accessors.py +++ b/xarray/core/accessors.py @@ -3,7 +3,7 @@ import numpy as np import pandas as pd -from .common import is_np_datetime_like, _contains_datetime_like_objects +from .common import _contains_datetime_like_objects, is_np_datetime_like from .pycompat import dask_array_type diff --git a/xarray/core/combine.py b/xarray/core/combine.py index f0cc025dc7e..6853939c02d 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -8,8 +8,8 @@ from .alignment import align from .merge import merge from .pycompat import OrderedDict, basestring, iteritems -from .variable import concat as concat_vars from .variable import IndexVariable, Variable, as_variable +from .variable import concat as concat_vars def concat(objs, dim=None, data_vars='all', coords='different', diff --git a/xarray/core/common.py b/xarray/core/common.py index 41e4fec2982..c74b1fa080b 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -7,11 +7,10 @@ import numpy as np import pandas as pd -from . import duck_array_ops, dtypes, formatting, ops +from . 
import dtypes, duck_array_ops, formatting, ops from .arithmetic import SupportsArithmetic from .pycompat import OrderedDict, basestring, dask_array_type, suppress -from .utils import either_dict_or_kwargs, Frozen, SortedKeysDict, ReprObject - +from .utils import Frozen, ReprObject, SortedKeysDict, either_dict_or_kwargs # Used as a sentinel value to indicate a all dimensions ALL_DIMS = ReprObject('') diff --git a/xarray/core/computation.py b/xarray/core/computation.py index bdba72cb48a..7998cc4f72f 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -2,19 +2,19 @@ Functions for applying functions that act on arrays to xarray's labeled data. """ from __future__ import absolute_import, division, print_function -from distutils.version import LooseVersion + import functools import itertools import operator from collections import Counter +from distutils.version import LooseVersion import numpy as np -from . import duck_array_ops -from . import utils +from . import duck_array_ops, utils from .alignment import deep_align from .merge import expand_and_merge_variables -from .pycompat import OrderedDict, dask_array_type, basestring +from .pycompat import OrderedDict, basestring, dask_array_type from .utils import is_dict_like _DEFAULT_FROZEN_SET = frozenset() diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py index 2196dba7f86..6b53dcffe6e 100644 --- a/xarray/core/dask_array_compat.py +++ b/xarray/core/dask_array_compat.py @@ -2,9 +2,9 @@ from distutils.version import LooseVersion +import dask.array as da import numpy as np from dask import __version__ as dask_version -import dask.array as da try: from dask.array import isin diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 423a65aa3c2..25c572edd54 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -1,10 +1,10 @@ from __future__ import absolute_import, division, print_function + from distutils.version import 
LooseVersion import numpy as np -from . import nputils -from . import dtypes +from . import dtypes, nputils try: import dask diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 5e787c1587b..4ade15825c6 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -16,6 +16,7 @@ alignment, computation, duck_array_ops, formatting, groupby, indexing, ops, resample, rolling, utils) from .. import conventions +from ..coding.cftimeindex import _parse_array_of_cftime_strings from .alignment import align from .common import ( ALL_DIMS, DataWithCoords, ImplementsDatasetReduce, @@ -31,12 +32,11 @@ from .pycompat import ( OrderedDict, basestring, dask_array_type, integer_types, iteritems, range) from .utils import ( - Frozen, SortedKeysDict, either_dict_or_kwargs, decode_numpy_dict_values, - ensure_us_time_resolution, hashable, maybe_wrap_array, datetime_to_numeric) + Frozen, SortedKeysDict, datetime_to_numeric, decode_numpy_dict_values, + either_dict_or_kwargs, ensure_us_time_resolution, hashable, + maybe_wrap_array) from .variable import IndexVariable, Variable, as_variable, broadcast_variables -from ..coding.cftimeindex import _parse_array_of_cftime_strings - # list of attributes of pd.DatetimeIndex that are ndarrays of time info _DATETIMEINDEX_COMPONENTS = ['year', 'month', 'day', 'hour', 'minute', 'second', 'microsecond', 'nanosecond', 'date', diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 042c8c5324d..5dd3cf06025 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -15,8 +15,7 @@ from .options import OPTIONS from .pycompat import ( - PY2, bytes_type, dask_array_type, unicode_type, zip_longest, -) + PY2, bytes_type, dask_array_type, unicode_type, zip_longest) try: from pandas.errors import OutOfBoundsDatetime diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 0b560c277ae..3f4e0fc3ac9 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -1,20 +1,19 @@ from __future__ import 
absolute_import, division, print_function +import warnings from collections import Iterable from functools import partial -import warnings - import numpy as np import pandas as pd from . import rolling from .common import _contains_datetime_like_objects from .computation import apply_ufunc +from .duck_array_ops import dask_array_type from .pycompat import iteritems -from .utils import is_scalar, OrderedSet, datetime_to_numeric +from .utils import OrderedSet, datetime_to_numeric, is_scalar from .variable import Variable, broadcast_variables -from .duck_array_ops import dask_array_type class BaseInterpolator(object): diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 9549c8e77b9..4d3f03c899e 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -2,11 +2,10 @@ import numpy as np -from . import dtypes +from . import dtypes, nputils +from .duck_array_ops import ( + _dask_or_eager_func, count, fillna, isnull, where_method) from .pycompat import dask_array_type -from . duck_array_ops import (count, isnull, fillna, where_method, - _dask_or_eager_func) -from . import nputils try: import dask.array as dask_array diff --git a/xarray/core/npcompat.py b/xarray/core/npcompat.py index 22dff44acf8..efa68c8bad5 100644 --- a/xarray/core/npcompat.py +++ b/xarray/core/npcompat.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division, print_function from distutils.version import LooseVersion + import numpy as np try: diff --git a/xarray/core/resample.py b/xarray/core/resample.py index 25c149c51af..bd84e04487e 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, division, print_function from . 
import ops -from .groupby import DataArrayGroupBy, DatasetGroupBy, DEFAULT_DIMS +from .groupby import DEFAULT_DIMS, DataArrayGroupBy, DatasetGroupBy from .pycompat import OrderedDict, dask_array_type RESAMPLE_DIM = '__resample_dim__' diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 79a3993e23b..32a954a3fcd 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -5,6 +5,7 @@ import warnings import numpy as np + from ..core.formatting import format_item from ..core.pycompat import getargspec from .utils import ( diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 455d27c3987..a284c186937 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -8,7 +8,6 @@ from ..core.options import OPTIONS from ..core.pycompat import basestring from ..core.utils import is_scalar -from ..core.options import OPTIONS ROBUST_PERCENTILE = 2.0 diff --git a/xarray/tests/test_cftime_offsets.py b/xarray/tests/test_cftime_offsets.py index 6d7990689ed..7acd764cab3 100644 --- a/xarray/tests/test_cftime_offsets.py +++ b/xarray/tests/test_cftime_offsets.py @@ -1,15 +1,13 @@ -import pytest - from itertools import product import numpy as np +import pytest -from xarray.coding.cftime_offsets import ( - BaseCFTimeOffset, YearBegin, YearEnd, MonthBegin, MonthEnd, - Day, Hour, Minute, Second, _days_in_month, - to_offset, get_date_type, _MONTH_ABBREVIATIONS, to_cftime_datetime, - cftime_range) from xarray import CFTimeIndex +from xarray.coding.cftime_offsets import ( + _MONTH_ABBREVIATIONS, BaseCFTimeOffset, Day, Hour, Minute, MonthBegin, + MonthEnd, Second, YearBegin, YearEnd, _days_in_month, cftime_range, + get_date_type, to_cftime_datetime, to_offset) cftime = pytest.importorskip('cftime') diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index 33bf2cbce0d..d1726ab3313 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -1,16 +1,15 @@ from __future__ import absolute_import -import 
pytest +from datetime import timedelta import numpy as np import pandas as pd -import xarray as xr +import pytest -from datetime import timedelta +import xarray as xr from xarray.coding.cftimeindex import ( - parse_iso8601, CFTimeIndex, assert_all_valid_date_type, - _parsed_string_to_bounds, _parse_iso8601_with_reso, - _parse_array_of_cftime_strings) + CFTimeIndex, _parse_array_of_cftime_strings, _parse_iso8601_with_reso, + _parsed_string_to_bounds, assert_all_valid_date_type, parse_iso8601) from xarray.tests import assert_array_equal, assert_identical from . import has_cftime, has_cftime_or_netCDF4, requires_cftime diff --git a/xarray/tests/test_coding_strings.py b/xarray/tests/test_coding_strings.py index 53d028e164b..ca138ca8362 100644 --- a/xarray/tests/test_coding_strings.py +++ b/xarray/tests/test_coding_strings.py @@ -5,13 +5,13 @@ import pytest from xarray import Variable -from xarray.core.pycompat import bytes_type, unicode_type, suppress from xarray.coding import strings from xarray.core import indexing +from xarray.core.pycompat import bytes_type, suppress, unicode_type -from . import (IndexerMaker, assert_array_equal, assert_identical, - raises_regex, requires_dask) - +from . 
import ( + IndexerMaker, assert_array_equal, assert_identical, raises_regex, + requires_dask) with suppress(ImportError): import dask.array as da diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 7d3a4930b44..10a1a956b27 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -1,20 +1,20 @@ from __future__ import absolute_import, division, print_function -from itertools import product import warnings +from itertools import product import numpy as np import pandas as pd import pytest -from xarray import Variable, coding, set_options, DataArray, decode_cf +from xarray import DataArray, Variable, coding, decode_cf, set_options from xarray.coding.times import _import_cftime from xarray.coding.variables import SerializationWarning from xarray.core.common import contains_cftime_datetimes -from . import (assert_array_equal, has_cftime_or_netCDF4, - requires_cftime_or_netCDF4, has_cftime, has_dask) - +from . import ( + assert_array_equal, has_cftime, has_cftime_or_netCDF4, has_dask, + requires_cftime_or_netCDF4) _NON_STANDARD_CALENDARS_SET = {'noleap', '365_day', '360_day', 'julian', 'all_leap', '366_day'} diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index ca8e4e59737..1003c531018 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -15,7 +15,7 @@ join_dict_keys, ordered_set_intersection, ordered_set_union, result_name, unified_dim_sizes) -from . import raises_regex, requires_dask, has_dask +from . import has_dask, raises_regex, requires_dask def assert_identical(a, b): diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 6dd14f5d6ad..8ace55be66b 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -5,9 +5,10 @@ import pytest import xarray as xr -from . import assert_identical from xarray.core.groupby import _consolidate_slices +from . 
import assert_identical + def test_consolidate_slices(): diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 0778a1ff128..624879cce1f 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -5,12 +5,12 @@ import pytest import xarray as xr -from xarray.tests import (assert_allclose, assert_equal, requires_cftime, - requires_scipy) -from . import has_dask, has_scipy -from .test_dataset import create_test_data +from xarray.tests import ( + assert_allclose, assert_equal, requires_cftime, requires_scipy) +from . import has_dask, has_scipy from ..coding.cftimeindex import _parse_array_of_cftime_strings +from .test_dataset import create_test_data try: import scipy diff --git a/xarray/tests/test_ufuncs.py b/xarray/tests/test_ufuncs.py index 195bb36e36e..6941efb1c6e 100644 --- a/xarray/tests/test_ufuncs.py +++ b/xarray/tests/test_ufuncs.py @@ -8,9 +8,9 @@ import xarray as xr import xarray.ufuncs as xu -from . import ( - assert_array_equal, assert_identical as assert_identical_, mock, - raises_regex, requires_np113) +from . import assert_array_equal +from . import assert_identical as assert_identical_ +from . 
import mock, raises_regex, requires_np113 def assert_identical(a, b): From cf1e6c73d0366124485c1d767b89ac1cc301705b Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Mon, 8 Oct 2018 00:40:07 +0200 Subject: [PATCH 48/51] pep8speaks (#2462) * yml for pep8speaks * Updated yml * Intentionally added badly styled scripts * experimentally removed yml * .yml file taken from pandas-dev/pandas/.pep8speaks.yml * Undo intended pep8 violation --- .pep8speaks.yml | 11 +++++++++++ .stickler.yml | 11 ----------- 2 files changed, 11 insertions(+), 11 deletions(-) create mode 100644 .pep8speaks.yml delete mode 100644 .stickler.yml diff --git a/.pep8speaks.yml b/.pep8speaks.yml new file mode 100644 index 00000000000..cd610907007 --- /dev/null +++ b/.pep8speaks.yml @@ -0,0 +1,11 @@ +# File : .pep8speaks.yml + +scanner: + diff_only: True # If True, errors caused by only the patch are shown + +pycodestyle: + max-line-length: 79 + ignore: # Errors and warnings to ignore + - E402, # module level import not at top of file + - E731, # do not assign a lambda expression, use a def + - W503 # line break before binary operator diff --git a/.stickler.yml b/.stickler.yml deleted file mode 100644 index 79d8b7fb717..00000000000 --- a/.stickler.yml +++ /dev/null @@ -1,11 +0,0 @@ -linters: - flake8: - max-line-length: 79 - fixer: false - ignore: I002 - # stickler doesn't support 'exclude' for flake8 properly, so we disable it - # below with files.ignore: - # https://github.com/markstory/lint-review/issues/184 -files: - ignore: - - doc/**/*.py From 5f09deb96ac2041e2a2e5affcc8e693bea9a5d73 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 8 Oct 2018 14:23:34 +0900 Subject: [PATCH 49/51] Properly support user-provided norm. (#2443) * Properly support user-provided norm. Fixes #2381 * remove top level mpl import. * More accurate error message. * whats-new fixes. 
--- doc/whats-new.rst | 13 ++++++---- xarray/plot/plot.py | 12 ++++++---- xarray/plot/utils.py | 33 +++++++++++++++++++++++--- xarray/tests/test_plot.py | 50 ++++++++++++++++++++++++++++++++------- 4 files changed, 87 insertions(+), 21 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e9c223ff801..85e9f2313d6 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -40,15 +40,15 @@ Breaking changes Documentation ~~~~~~~~~~~~~ + Enhancements ~~~~~~~~~~~~ - Added support for Python 3.7. (:issue:`2271`). By `Joe Hamman `_. - - Added :py:meth:`~xarray.CFTimeIndex.shift` for shifting the values of a - CFTimeIndex by a specified frequency. (:issue:`2244`). By `Spencer Clark - `_. + CFTimeIndex by a specified frequency. (:issue:`2244`). + By `Spencer Clark `_. - Added support for using ``cftime.datetime`` coordinates with :py:meth:`~xarray.DataArray.differentiate`, :py:meth:`~xarray.Dataset.differentiate`, @@ -60,11 +60,14 @@ Bug fixes ~~~~~~~~~ - Addition and subtraction operators used with a CFTimeIndex now preserve the - index's type. (:issue:`2244`). By `Spencer Clark `_. + index's type. (:issue:`2244`). + By `Spencer Clark `_. - ``xarray.DataArray.roll`` correctly handles multidimensional arrays. (:issue:`2445`) By `Keisuke Fujii `_. - +- ``xarray.plot()`` now properly accepts a ``norm`` argument and does not override + the norm's ``vmin`` and ``vmax``. (:issue:`2381`) + By `Deepak Cherian `_. - ``xarray.DataArray.std()`` now correctly accepts ``ddof`` keyword argument. (:issue:`2240`) By `Keisuke Fujii `_. diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 3f9f1090c70..b44ae7b3856 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -562,6 +562,9 @@ def _plot2d(plotfunc): Adds colorbar to axis add_labels : Boolean, optional Use xarray metadata to label axes + norm : ``matplotlib.colors.Normalize`` instance, optional + If the ``norm`` has vmin or vmax specified, the corresponding kwarg + must be None. 
vmin, vmax : floats, optional Values to anchor the colormap, otherwise they are inferred from the data and other keyword arguments. When a diverging dataset is inferred, @@ -630,7 +633,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, levels=None, infer_intervals=None, colors=None, subplot_kws=None, cbar_ax=None, cbar_kwargs=None, xscale=None, yscale=None, xticks=None, yticks=None, - xlim=None, ylim=None, **kwargs): + xlim=None, ylim=None, norm=None, **kwargs): # All 2d plots in xarray share this function signature. # Method signature below should be consistent. @@ -727,6 +730,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, 'extend': extend, 'levels': levels, 'filled': plotfunc.__name__ != 'contour', + 'norm': norm, } cmap_params = _determine_cmap_params(**cmap_kwargs) @@ -746,9 +750,6 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, if 'pcolormesh' == plotfunc.__name__: kwargs['infer_intervals'] = infer_intervals - # This allows the user to pass in a custom norm coming via kwargs - kwargs.setdefault('norm', cmap_params['norm']) - if 'imshow' == plotfunc.__name__ and isinstance(aspect, basestring): # forbid usage of mpl strings raise ValueError("plt.imshow's `aspect` kwarg is not available " @@ -758,6 +759,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, primitive = plotfunc(xval, yval, zval, ax=ax, cmap=cmap_params['cmap'], vmin=cmap_params['vmin'], vmax=cmap_params['vmax'], + norm=cmap_params['norm'], **kwargs) # Label the plot with metadata @@ -809,7 +811,7 @@ def plotmethod(_PlotMethods_obj, x=None, y=None, figsize=None, size=None, levels=None, infer_intervals=None, subplot_kws=None, cbar_ax=None, cbar_kwargs=None, xscale=None, yscale=None, xticks=None, yticks=None, - xlim=None, ylim=None, **kwargs): + xlim=None, ylim=None, norm=None, **kwargs): """ The method should have the same signature as the function. 
diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index a284c186937..be38a6d7a4c 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -172,6 +172,10 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, # vlim might be computed below vlim = None + # save state; needed later + vmin_was_none = vmin is None + vmax_was_none = vmax is None + if vmin is None: if robust: vmin = np.percentile(calc_data, ROBUST_PERCENTILE) @@ -204,6 +208,28 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, vmin += center vmax += center + # now check norm and harmonize with vmin, vmax + if norm is not None: + if norm.vmin is None: + norm.vmin = vmin + else: + if not vmin_was_none and vmin != norm.vmin: + raise ValueError('Cannot supply vmin and a norm' + + ' with a different vmin.') + vmin = norm.vmin + + if norm.vmax is None: + norm.vmax = vmax + else: + if not vmax_was_none and vmax != norm.vmax: + raise ValueError('Cannot supply vmax and a norm' + + ' with a different vmax.') + vmax = norm.vmax + + # if BoundaryNorm, then set levels + if isinstance(norm, mpl.colors.BoundaryNorm): + levels = norm.boundaries + # Choose default colormaps if not provided if cmap is None: if divergent: @@ -212,7 +238,7 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, cmap = OPTIONS['cmap_sequential'] # Handle discrete levels - if levels is not None: + if levels is not None and norm is None: if is_scalar(levels): if user_minmax: levels = np.linspace(vmin, vmax, levels) @@ -227,8 +253,9 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, if extend is None: extend = _determine_extend(calc_data, vmin, vmax) - if levels is not None: - cmap, norm = _build_discrete_cmap(cmap, levels, extend, filled) + if levels is not None or isinstance(norm, mpl.colors.BoundaryNorm): + cmap, newnorm = _build_discrete_cmap(cmap, levels, extend, filled) + norm = newnorm if norm is None else norm return dict(vmin=vmin, 
vmax=vmax, cmap=cmap, extend=extend, levels=levels, norm=norm) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 01303202c93..53f6077ee4f 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -628,6 +628,26 @@ def test_divergentcontrol(self): assert cmap_params['vmax'] == 0.6 assert cmap_params['cmap'] == "viridis" + def test_norm_sets_vmin_vmax(self): + vmin = self.data.min() + vmax = self.data.max() + + for norm, extend in zip([mpl.colors.LogNorm(), + mpl.colors.LogNorm(vmin + 1, vmax - 1), + mpl.colors.LogNorm(None, vmax - 1), + mpl.colors.LogNorm(vmin + 1, None)], + ['neither', 'both', 'max', 'min']): + + test_min = vmin if norm.vmin is None else norm.vmin + test_max = vmax if norm.vmax is None else norm.vmax + + cmap_params = _determine_cmap_params(self.data, norm=norm) + + assert cmap_params['vmin'] == test_min + assert cmap_params['vmax'] == test_max + assert cmap_params['extend'] == extend + assert cmap_params['norm'] == norm + @requires_matplotlib class TestDiscreteColorMap(object): @@ -665,10 +685,10 @@ def test_build_discrete_cmap(self): @pytest.mark.slow def test_discrete_colormap_list_of_levels(self): - for extend, levels in [('max', [-1, 2, 4, 8, 10]), ('both', - [2, 5, 10, 11]), - ('neither', [0, 5, 10, 15]), ('min', - [2, 5, 10, 15])]: + for extend, levels in [('max', [-1, 2, 4, 8, 10]), + ('both', [2, 5, 10, 11]), + ('neither', [0, 5, 10, 15]), + ('min', [2, 5, 10, 15])]: for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']: primitive = getattr(self.darray.plot, kind)(levels=levels) assert_array_equal(levels, primitive.norm.boundaries) @@ -682,10 +702,10 @@ def test_discrete_colormap_list_of_levels(self): @pytest.mark.slow def test_discrete_colormap_int_levels(self): - for extend, levels, vmin, vmax in [('neither', 7, None, - None), ('neither', 7, None, 20), - ('both', 7, 4, 8), ('min', 10, 4, - 15)]: + for extend, levels, vmin, vmax in [('neither', 7, None, None), + ('neither', 7, None, 20), + 
('both', 7, 4, 8), + ('min', 10, 4, 15)]: for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']: primitive = getattr(self.darray.plot, kind)( levels=levels, vmin=vmin, vmax=vmax) @@ -711,6 +731,11 @@ def test_discrete_colormap_list_levels_and_vmin_or_vmax(self): assert primitive.norm.vmax == max(levels) assert primitive.norm.vmin == min(levels) + def test_discrete_colormap_provided_boundary_norm(self): + norm = mpl.colors.BoundaryNorm([0, 5, 10, 15], 4) + primitive = self.darray.plot.contourf(norm=norm) + np.testing.assert_allclose(primitive.levels, norm.boundaries) + class Common2dMixin(object): """ @@ -1085,6 +1110,15 @@ def test_cmap_and_color_both(self): with pytest.raises(ValueError): self.plotmethod(colors='k', cmap='RdBu') + def test_colormap_error_norm_and_vmin_vmax(self): + norm = mpl.colors.LogNorm(0.1, 1e1) + + with pytest.raises(ValueError): + self.darray.plot(norm=norm, vmin=2) + + with pytest.raises(ValueError): + self.darray.plot(norm=norm, vmax=2) + @pytest.mark.slow class TestContourf(Common2dMixin, PlotTestCase): From 5b4d160d9a714c2cc83ff5788e2d73af92129713 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 8 Oct 2018 11:17:01 -0700 Subject: [PATCH 50/51] Fix indexing error for data loaded with open_rasterio (#2456) xref GH2454 --- doc/whats-new.rst | 6 +++++- xarray/backends/rasterio_.py | 2 +- xarray/tests/test_backends.py | 4 ++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 85e9f2313d6..7cdb1685f5f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -76,6 +76,10 @@ Bug fixes By `Deepak Cherian `_. +- Fix a bug that caused some indexing operations on arrays opened with + ``open_rasterio`` to error (:issue:`2454`). + By `Stephan Hoyer `_. + .. _whats-new.0.10.9: v0.10.9 (21 September 2018) @@ -86,7 +90,7 @@ This minor release contains a number of backwards compatible enhancements. Announcements of note: - Xarray is now a NumFOCUS fiscally sponsored project! 
Read - `the anouncment `_ + `the anouncement `_ for more details. - We have a new :doc:`roadmap` that outlines our future development plans. diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 9cd5a889abc..44cca9aaaf8 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -95,7 +95,7 @@ def _get_indexer(self, key): if isinstance(key[1], np.ndarray) and isinstance(key[2], np.ndarray): # do outer-style indexing - np_inds[1:] = np.ix_(*np_inds[1:]) + np_inds[-2:] = np.ix_(*np_inds[-2:]) return band_key, tuple(window), tuple(squeeze_axis), tuple(np_inds) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index a2e1cb4c0fa..0d97ed70fa3 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2925,6 +2925,10 @@ def test_indexing(self): assert_allclose(expected.isel(**ind), actual.isel(**ind)) assert not actual.variable._in_memory + ind = {'band': 0, 'x': np.array([0, 0]), 'y': np.array([1, 1, 1])} + assert_allclose(expected.isel(**ind), actual.isel(**ind)) + assert not actual.variable._in_memory + # minus-stepped slice ind = {'band': np.array([2, 1, 0]), 'x': slice(-1, None, -1), 'y': 0} From 289b377129b18e7dc6da8336e958a85be868acbe Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 8 Oct 2018 21:13:41 -0700 Subject: [PATCH 51/51] xarray.backends refactor (#2261) * WIP: xarray.backends.file_manager for managing file objects. This is intended to replace both PickleByReconstructionWrapper and DataStorePickleMixin with something more compartmentalized. 
xref GH2121 * Switch rasterio to use FileManager * lint fixes * WIP: rewrite FileManager to always use an LRUCache * Test coverage * Don't use move_to_end * minor clarification * Switch FileManager.acquire() to a method * Python 2 compat * Update xarray.set_options() to add file_cache_maxsize and validation * Add assert for FILE_CACHE.maxsize * More docstring for FileManager * Add accidentally omitted tests for LRUCache * Adapt scipy backend to use FileManager * Stickler fix * Fix failure on Python 2.7 * Finish adjusting backends to use FileManager * Fix bad import * WIP on distributed * More WIP * Fix distributed write tests * Fixes * Minor fixup * whats new * More refactoring: remove state from backends entirely * Cleanup * Fix failing in-memory datastore tests * Fix inaccessible datastore * fix autoclose warnings * Fix PyNIO failures * No longer disable HDF5 file locking We no longer need to explicitly set HDF5_USE_FILE_LOCKING='FALSE' because we properly close open files. * whats new and default file cache size * Whats new tweak * Refactor default lock logic to backend classes * Rename get_resource_lock -> get_write_lock * Don't acquire unnecessary locks in __getitem__ * Fix bad merge * Fix import * Remove unreachable code --- asv_bench/asv.conf.json | 1 + asv_bench/benchmarks/dataset_io.py | 41 ++++ doc/api.rst | 3 + doc/whats-new.rst | 19 +- xarray/backends/__init__.py | 4 + xarray/backends/api.py | 250 ++++++++++----------- xarray/backends/common.py | 215 +----------------- xarray/backends/file_manager.py | 206 +++++++++++++++++ xarray/backends/h5netcdf_.py | 169 +++++++------- xarray/backends/locks.py | 191 ++++++++++++++++ xarray/backends/lru_cache.py | 91 ++++++++ xarray/backends/memory.py | 3 +- xarray/backends/netCDF4_.py | 231 ++++++++++--------- xarray/backends/pseudonetcdf_.py | 79 ++++--- xarray/backends/pynio_.py | 53 ++--- xarray/backends/rasterio_.py | 94 ++++---- xarray/backends/scipy_.py | 129 +++++------ xarray/backends/zarr.py | 21 +- 
xarray/core/dataset.py | 25 +-- xarray/core/options.py | 71 ++++-- xarray/core/pycompat.py | 9 +- xarray/tests/test_backends.py | 209 +++++------------ xarray/tests/test_backends_file_manager.py | 114 ++++++++++ xarray/tests/test_backends_locks.py | 13 ++ xarray/tests/test_backends_lru_cache.py | 91 ++++++++ xarray/tests/test_dataset.py | 4 +- xarray/tests/test_distributed.py | 110 +++++---- xarray/tests/test_options.py | 33 +++ 28 files changed, 1496 insertions(+), 983 deletions(-) create mode 100644 xarray/backends/file_manager.py create mode 100644 xarray/backends/locks.py create mode 100644 xarray/backends/lru_cache.py create mode 100644 xarray/tests/test_backends_file_manager.py create mode 100644 xarray/tests/test_backends_locks.py create mode 100644 xarray/tests/test_backends_lru_cache.py diff --git a/asv_bench/asv.conf.json b/asv_bench/asv.conf.json index b5953436387..e3933b400e6 100644 --- a/asv_bench/asv.conf.json +++ b/asv_bench/asv.conf.json @@ -64,6 +64,7 @@ "scipy": [""], "bottleneck": ["", null], "dask": [""], + "distributed": [""], }, diff --git a/asv_bench/benchmarks/dataset_io.py b/asv_bench/benchmarks/dataset_io.py index 0b918e58eab..da18d541a16 100644 --- a/asv_bench/benchmarks/dataset_io.py +++ b/asv_bench/benchmarks/dataset_io.py @@ -1,5 +1,7 @@ from __future__ import absolute_import, division, print_function +import os + import numpy as np import pandas as pd @@ -14,6 +16,9 @@ pass +os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE' + + class IOSingleNetCDF(object): """ A few examples that benchmark reading/writing a single netCDF file with @@ -405,3 +410,39 @@ def time_open_dataset_scipy_with_time_chunks(self): with dask.set_options(get=dask.multiprocessing.get): xr.open_mfdataset(self.filenames_list, engine='scipy', chunks=self.time_chunks) + + +def create_delayed_write(): + import dask.array as da + vals = da.random.random(300, chunks=(1,)) + ds = xr.Dataset({'vals': (['a'], vals)}) + return ds.to_netcdf('file.nc', engine='netcdf4', 
compute=False) + + +class IOWriteNetCDFDask(object): + timeout = 60 + repeat = 1 + number = 5 + + def setup(self): + requires_dask() + self.write = create_delayed_write() + + def time_write(self): + self.write.compute() + + +class IOWriteNetCDFDaskDistributed(object): + def setup(self): + try: + import distributed + except ImportError: + raise NotImplementedError + self.client = distributed.Client() + self.write = create_delayed_write() + + def cleanup(self): + self.client.shutdown() + + def time_write(self): + self.write.compute() diff --git a/doc/api.rst b/doc/api.rst index d204fab3539..662ef567710 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -624,3 +624,6 @@ arguments for the ``from_store`` and ``dump_to_store`` Dataset methods: backends.H5NetCDFStore backends.PydapDataStore backends.ScipyDataStore + backends.FileManager + backends.CachingFileManager + backends.DummyFileManager diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 7cdb1685f5f..d0fec7b0778 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -33,14 +33,27 @@ v0.11.0 (unreleased) Breaking changes ~~~~~~~~~~~~~~~~ +- Xarray's storage backends now automatically open and close files when + necessary, rather than requiring opening a file with ``autoclose=True``. A + global least-recently-used cache is used to store open files; the default + limit of 128 open files should suffice in most cases, but can be adjusted if + necessary with + ``xarray.set_options(file_cache_maxsize=...)``. The ``autoclose`` argument + to ``open_dataset`` and related functions has been deprecated and is now a + no-op. + + This change, along with an internal refactor of xarray's storage backends, + should significantly improve performance when reading and writing + netCDF files with Dask, especially when working with many files or using + Dask Distributed. 
By `Stephan Hoyer `_ + +Documentation +~~~~~~~~~~~~~ - Reduction of :py:meth:`DataArray.groupby` and :py:meth:`DataArray.resample` without dimension argument will change in the next release. Now we warn a FutureWarning. By `Keisuke Fujii `_. -Documentation -~~~~~~~~~~~~~ - Enhancements ~~~~~~~~~~~~ diff --git a/xarray/backends/__init__.py b/xarray/backends/__init__.py index 47a2011a3af..a2f0d79a6d1 100644 --- a/xarray/backends/__init__.py +++ b/xarray/backends/__init__.py @@ -4,6 +4,7 @@ formats. They should not be used directly, but rather through Dataset objects. """ from .common import AbstractDataStore +from .file_manager import FileManager, CachingFileManager, DummyFileManager from .memory import InMemoryDataStore from .netCDF4_ import NetCDF4DataStore from .pydap_ import PydapDataStore @@ -15,6 +16,9 @@ __all__ = [ 'AbstractDataStore', + 'FileManager', + 'CachingFileManager', + 'DummyFileManager', 'InMemoryDataStore', 'NetCDF4DataStore', 'PydapDataStore', diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 2bf13011bd1..65112527045 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -4,6 +4,7 @@ from glob import glob from io import BytesIO from numbers import Number +import warnings import numpy as np @@ -12,8 +13,9 @@ from ..core.combine import auto_combine from ..core.pycompat import basestring, path_type from ..core.utils import close_on_error, is_remote_uri -from .common import ( - HDF5_LOCK, ArrayWriter, CombinedLock, _get_scheduler, _get_scheduler_lock) +from .common import ArrayWriter +from .locks import _get_scheduler + DATAARRAY_NAME = '__xarray_dataarray_name__' DATAARRAY_VARIABLE = '__xarray_dataarray_variable__' @@ -52,27 +54,6 @@ def _normalize_path(path): return os.path.abspath(os.path.expanduser(path)) -def _default_lock(filename, engine): - if filename.endswith('.gz'): - lock = False - else: - if engine is None: - engine = _get_default_engine(filename, allow_remote=True) - - if engine == 'netcdf4': - if 
is_remote_uri(filename): - lock = False - else: - # TODO: identify netcdf3 files and don't use the global lock - # for them - lock = HDF5_LOCK - elif engine in {'h5netcdf', 'pynio'}: - lock = HDF5_LOCK - else: - lock = False - return lock - - def _validate_dataset_names(dataset): """DataArray.name and Dataset keys must be a string or None""" def check_name(name): @@ -130,29 +111,14 @@ def _protect_dataset_variables_inplace(dataset, cache): variable.data = data -def _get_lock(engine, scheduler, format, path_or_file): - """ Get the lock(s) that apply to a particular scheduler/engine/format""" - - locks = [] - if format in ['NETCDF4', None] and engine in ['h5netcdf', 'netcdf4']: - locks.append(HDF5_LOCK) - locks.append(_get_scheduler_lock(scheduler, path_or_file)) - - # When we have more than one lock, use the CombinedLock wrapper class - lock = CombinedLock(locks) if len(locks) > 1 else locks[0] - - return lock - - def _finalize_store(write, store): """ Finalize this store by explicitly syncing and closing""" del write # ensure writing is done first - store.sync() store.close() def open_dataset(filename_or_obj, group=None, decode_cf=True, - mask_and_scale=None, decode_times=True, autoclose=False, + mask_and_scale=None, decode_times=True, autoclose=None, concat_characters=True, decode_coords=True, engine=None, chunks=None, lock=None, cache=None, drop_variables=None, backend_kwargs=None): @@ -204,12 +170,11 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, If chunks is provided, it used to load the new dataset into dask arrays. ``chunks={}`` loads the dataset with dask using a single chunk for all arrays. - lock : False, True or threading.Lock, optional - If chunks is provided, this argument is passed on to - :py:func:`dask.array.from_array`. By default, a global lock is - used when reading data from netCDF files with the netcdf4 and h5netcdf - engines to avoid issues with concurrent access when using dask's - multithreaded backend. 
+ lock : False or duck threading.Lock, optional + Resource lock to use when reading data from disk. Only relevant when + using dask or another form of parallelism. By default, appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. cache : bool, optional If True, cache data loaded from the underlying datastore in memory as NumPy arrays when accessed to avoid reading from the underlying data- @@ -235,6 +200,14 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, -------- open_mfdataset """ + if autoclose is not None: + warnings.warn( + 'The autoclose argument is no longer used by ' + 'xarray.open_dataset() and is now ignored; it will be removed in ' + 'xarray v0.12. If necessary, you can control the maximum number ' + 'of simultaneous open files with ' + 'xarray.set_options(file_cache_maxsize=...).', + FutureWarning, stacklevel=2) if mask_and_scale is None: mask_and_scale = not engine == 'pseudonetcdf' @@ -272,18 +245,11 @@ def maybe_decode_store(store, lock=False): mask_and_scale, decode_times, concat_characters, decode_coords, engine, chunks, drop_variables) name_prefix = 'open_dataset-%s' % token - ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token, - lock=lock) + ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token) ds2._file_obj = ds._file_obj else: ds2 = ds - # protect so that dataset store isn't necessarily closed, e.g., - # streams like BytesIO can't be reopened - # datastore backend is responsible for determining this capability - if store._autoclose: - store.close() - return ds2 if isinstance(filename_or_obj, path_type): @@ -314,36 +280,28 @@ def maybe_decode_store(store, lock=False): engine = _get_default_engine(filename_or_obj, allow_remote=True) if engine == 'netcdf4': - store = backends.NetCDF4DataStore.open(filename_or_obj, - group=group, - autoclose=autoclose, - **backend_kwargs) + store = backends.NetCDF4DataStore.open( + filename_or_obj, group=group, lock=lock, 
**backend_kwargs) elif engine == 'scipy': - store = backends.ScipyDataStore(filename_or_obj, - autoclose=autoclose, - **backend_kwargs) + store = backends.ScipyDataStore(filename_or_obj, **backend_kwargs) elif engine == 'pydap': - store = backends.PydapDataStore.open(filename_or_obj, - **backend_kwargs) + store = backends.PydapDataStore.open( + filename_or_obj, **backend_kwargs) elif engine == 'h5netcdf': - store = backends.H5NetCDFStore(filename_or_obj, group=group, - autoclose=autoclose, - **backend_kwargs) + store = backends.H5NetCDFStore( + filename_or_obj, group=group, lock=lock, **backend_kwargs) elif engine == 'pynio': - store = backends.NioDataStore(filename_or_obj, - autoclose=autoclose, - **backend_kwargs) + store = backends.NioDataStore( + filename_or_obj, lock=lock, **backend_kwargs) elif engine == 'pseudonetcdf': store = backends.PseudoNetCDFDataStore.open( - filename_or_obj, autoclose=autoclose, **backend_kwargs) + filename_or_obj, lock=lock, **backend_kwargs) else: raise ValueError('unrecognized engine for open_dataset: %r' % engine) - if lock is None: - lock = _default_lock(filename_or_obj, engine) with close_on_error(store): - return maybe_decode_store(store, lock) + return maybe_decode_store(store) else: if engine is not None and engine != 'scipy': raise ValueError('can only read file-like objects with ' @@ -355,7 +313,7 @@ def maybe_decode_store(store, lock=False): def open_dataarray(filename_or_obj, group=None, decode_cf=True, - mask_and_scale=None, decode_times=True, autoclose=False, + mask_and_scale=None, decode_times=True, autoclose=None, concat_characters=True, decode_coords=True, engine=None, chunks=None, lock=None, cache=None, drop_variables=None, backend_kwargs=None): @@ -390,10 +348,6 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, decode_times : bool, optional If True, decode times encoded in the standard NetCDF datetime format into datetime objects. Otherwise, leave them encoded as numbers. 
- autoclose : bool, optional - If True, automatically close files to avoid OS Error of too many files - being open. However, this option doesn't work with streams, e.g., - BytesIO. concat_characters : bool, optional If True, concatenate along the last dimension of character arrays to form string arrays. Dimensions will only be concatenated over (and @@ -409,12 +363,11 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, chunks : int or dict, optional If chunks is provided, it used to load the new dataset into dask arrays. - lock : False, True or threading.Lock, optional - If chunks is provided, this argument is passed on to - :py:func:`dask.array.from_array`. By default, a global lock is - used when reading data from netCDF files with the netcdf4 and h5netcdf - engines to avoid issues with concurrent access when using dask's - multithreaded backend. + lock : False or duck threading.Lock, optional + Resource lock to use when reading data from disk. Only relevant when + using dask or another form of parallelism. By default, appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. cache : bool, optional If True, cache data loaded from the underlying datastore in memory as NumPy arrays when accessed to avoid reading from the underlying data- @@ -490,7 +443,7 @@ def close(self): def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, compat='no_conflicts', preprocess=None, engine=None, lock=None, data_vars='all', coords='different', - autoclose=False, parallel=False, **kwargs): + autoclose=None, parallel=False, **kwargs): """Open multiple files as a single dataset. Requires dask to be installed. See documentation for details on dask [1]. @@ -537,15 +490,11 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for 'netcdf4'. 
- autoclose : bool, optional - If True, automatically close files to avoid OS Error of too many files - being open. However, this option doesn't work with streams, e.g., - BytesIO. - lock : False, True or threading.Lock, optional - This argument is passed on to :py:func:`dask.array.from_array`. By - default, a per-variable lock is used when reading data from netCDF - files with the netcdf4 and h5netcdf engines to avoid issues with - concurrent access when using dask's multithreaded backend. + lock : False or duck threading.Lock, optional + Resource lock to use when reading data from disk. Only relevant when + using dask or another form of parallelism. By default, appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. data_vars : {'minimal', 'different', 'all' or list of str}, optional These data variables will be concatenated together: * 'minimal': Only data variables in which the dimension already @@ -604,9 +553,6 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, if not paths: raise IOError('no files to open') - if lock is None: - lock = _default_lock(paths[0], engine) - open_kwargs = dict(engine=engine, chunks=chunks or {}, lock=lock, autoclose=autoclose, **kwargs) @@ -656,19 +602,21 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, - engine=None, writer=None, encoding=None, unlimited_dims=None, - compute=True): + engine=None, encoding=None, unlimited_dims=None, compute=True, + multifile=False): """This function creates an appropriate datastore for writing a dataset to disk as a netCDF file See `Dataset.to_netcdf` for full API docs. - The ``writer`` argument is only for the private use of save_mfdataset. + The ``multifile`` argument is only for the private use of save_mfdataset. 
""" if isinstance(path_or_file, path_type): path_or_file = str(path_or_file) + if encoding is None: encoding = {} + if path_or_file is None: if engine is None: engine = 'scipy' @@ -676,6 +624,10 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, raise ValueError('invalid engine for creating bytes with ' 'to_netcdf: %r. Only the default engine ' "or engine='scipy' is supported" % engine) + if not compute: + raise NotImplementedError( + 'to_netcdf() with compute=False is not yet implemented when ' + 'returning bytes') elif isinstance(path_or_file, basestring): if engine is None: engine = _get_default_engine(path_or_file) @@ -695,45 +647,78 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, if format is not None: format = format.upper() - # if a writer is provided, store asynchronously - sync = writer is None - # handle scheduler specific logic scheduler = _get_scheduler() have_chunks = any(v.chunks for v in dataset.variables.values()) - if (have_chunks and scheduler in ['distributed', 'multiprocessing'] and - engine != 'netcdf4'): + + autoclose = have_chunks and scheduler in ['distributed', 'multiprocessing'] + if autoclose and engine == 'scipy': raise NotImplementedError("Writing netCDF files with the %s backend " "is not currently supported with dask's %s " "scheduler" % (engine, scheduler)) - lock = _get_lock(engine, scheduler, format, path_or_file) - autoclose = (have_chunks and - scheduler in ['distributed', 'multiprocessing']) target = path_or_file if path_or_file is not None else BytesIO() - store = store_open(target, mode, format, group, writer, - autoclose=autoclose, lock=lock) + kwargs = dict(autoclose=True) if autoclose else {} + store = store_open(target, mode, format, group, **kwargs) if unlimited_dims is None: unlimited_dims = dataset.encoding.get('unlimited_dims', None) if isinstance(unlimited_dims, basestring): unlimited_dims = [unlimited_dims] + writer = ArrayWriter() + + # TODO: figure out 
how to refactor this logic (here and in save_mfdataset) + # to avoid this mess of conditionals try: - dataset.dump_to_store(store, sync=sync, encoding=encoding, - unlimited_dims=unlimited_dims, compute=compute) + # TODO: allow this work (setting up the file for writing array data) + # to be parallelized with dask + dump_to_store(dataset, store, writer, encoding=encoding, + unlimited_dims=unlimited_dims) + if autoclose: + store.close() + + if multifile: + return writer, store + + writes = writer.sync(compute=compute) + if path_or_file is None: + store.sync() return target.getvalue() finally: - if sync and isinstance(path_or_file, basestring): + if not multifile and compute: store.close() if not compute: import dask - return dask.delayed(_finalize_store)(store.delayed_store, store) + return dask.delayed(_finalize_store)(writes, store) + + +def dump_to_store(dataset, store, writer=None, encoder=None, + encoding=None, unlimited_dims=None): + """Store dataset contents to a backends.*DataStore object.""" + if writer is None: + writer = ArrayWriter() + + if encoding is None: + encoding = {} + + variables, attrs = conventions.encode_dataset_coordinates(dataset) + + check_encoding = set() + for k, enc in encoding.items(): + # no need to shallow copy the variable again; that already happened + # in encode_dataset_coordinates + variables[k].encoding = enc + check_encoding.add(k) + + if encoder: + variables, attrs = encoder(variables, attrs) + + store.store(variables, attrs, check_encoding, writer, + unlimited_dims=unlimited_dims) - if not sync: - return store def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, engine=None, compute=True): @@ -816,22 +801,22 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, 'datasets, paths and groups arguments to ' 'save_mfdataset') - writer = ArrayWriter() if compute else None - stores = [to_netcdf(ds, path, mode, format, group, engine, writer, - compute=compute) - for ds, path, group in 
zip(datasets, paths, groups)] - - if not compute: - import dask - return dask.delayed(stores) + writers, stores = zip(*[ + to_netcdf(ds, path, mode, format, group, engine, compute=compute, + multifile=True) + for ds, path, group in zip(datasets, paths, groups)]) try: - delayed = writer.sync(compute=compute) - for store in stores: - store.sync() + writes = [w.sync(compute=compute) for w in writers] finally: - for store in stores: - store.close() + if compute: + for store in stores: + store.close() + + if not compute: + import dask + return dask.delayed([dask.delayed(_finalize_store)(w, s) + for w, s in zip(writes, stores)]) def to_zarr(dataset, store=None, mode='w-', synchronizer=None, group=None, @@ -852,13 +837,14 @@ def to_zarr(dataset, store=None, mode='w-', synchronizer=None, group=None, store = backends.ZarrStore.open_group(store=store, mode=mode, synchronizer=synchronizer, - group=group, writer=None) + group=group) - # I think zarr stores should always be sync'd immediately + writer = ArrayWriter() # TODO: figure out how to properly handle unlimited_dims - dataset.dump_to_store(store, sync=True, encoding=encoding, compute=compute) + dump_to_store(dataset, store, writer, encoding=encoding) + writes = writer.sync(compute=compute) if not compute: import dask - return dask.delayed(_finalize_store)(store.delayed_store, store) + return dask.delayed(_finalize_store)(writes, store) return store diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 99f7698ee92..405d989f4af 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -1,14 +1,10 @@ from __future__ import absolute_import, division, print_function -import contextlib import logging -import multiprocessing -import threading import time import traceback import warnings from collections import Mapping, OrderedDict -from functools import partial import numpy as np @@ -17,13 +13,6 @@ from ..core.pycompat import dask_array_type, iteritems from ..core.utils import 
FrozenOrderedDict, NdimSizeLenMixin -# Import default lock -try: - from dask.utils import SerializableLock - HDF5_LOCK = SerializableLock() -except ImportError: - HDF5_LOCK = threading.Lock() - # Create a logger object, but don't add any handlers. Leave that to user code. logger = logging.getLogger(__name__) @@ -31,62 +20,6 @@ NONE_VAR_NAME = '__values__' -def _get_scheduler(get=None, collection=None): - """ Determine the dask scheduler that is being used. - - None is returned if not dask scheduler is active. - - See also - -------- - dask.base.get_scheduler - """ - try: - # dask 0.18.1 and later - from dask.base import get_scheduler - actual_get = get_scheduler(get, collection) - except ImportError: - try: - from dask.utils import effective_get - actual_get = effective_get(get, collection) - except ImportError: - return None - - try: - from dask.distributed import Client - if isinstance(actual_get.__self__, Client): - return 'distributed' - except (ImportError, AttributeError): - try: - import dask.multiprocessing - if actual_get == dask.multiprocessing.get: - return 'multiprocessing' - else: - return 'threaded' - except ImportError: - return 'threaded' - - -def _get_scheduler_lock(scheduler, path_or_file=None): - """ Get the appropriate lock for a certain situation based onthe dask - scheduler used. - - See Also - -------- - dask.utils.get_scheduler_lock - """ - - if scheduler == 'distributed': - from dask.distributed import Lock - return Lock(path_or_file) - elif scheduler == 'multiprocessing': - return multiprocessing.Lock() - elif scheduler == 'threaded': - from dask.utils import SerializableLock - return SerializableLock() - else: - return threading.Lock() - - def _encode_variable_name(name): if name is None: name = NONE_VAR_NAME @@ -133,39 +66,6 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, time.sleep(1e-3 * next_delay) -class CombinedLock(object): - """A combination of multiple locks. 
- - Like a locked door, a CombinedLock is locked if any of its constituent - locks are locked. - """ - - def __init__(self, locks): - self.locks = tuple(set(locks)) # remove duplicates - - def acquire(self, *args): - return all(lock.acquire(*args) for lock in self.locks) - - def release(self, *args): - for lock in self.locks: - lock.release(*args) - - def __enter__(self): - for lock in self.locks: - lock.__enter__() - - def __exit__(self, *args): - for lock in self.locks: - lock.__exit__(*args) - - @property - def locked(self): - return any(lock.locked for lock in self.locks) - - def __repr__(self): - return "CombinedLock(%r)" % list(self.locks) - - class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed): def __array__(self, dtype=None): @@ -174,9 +74,6 @@ def __array__(self, dtype=None): class AbstractDataStore(Mapping): - _autoclose = None - _ds = None - _isopen = False def __iter__(self): return iter(self.variables) @@ -259,7 +156,7 @@ def __exit__(self, exception_type, exception_value, traceback): class ArrayWriter(object): - def __init__(self, lock=HDF5_LOCK): + def __init__(self, lock=None): self.sources = [] self.targets = [] self.lock = lock @@ -274,6 +171,9 @@ def add(self, source, target): def sync(self, compute=True): if self.sources: import dask.array as da + # TODO: consider wrapping targets with dask.delayed, if this makes + # for any discernable difference in perforance, e.g., + # targets = [dask.delayed(t) for t in self.targets] delayed_store = da.store(self.sources, self.targets, lock=self.lock, compute=compute, flush=True) @@ -283,11 +183,6 @@ def sync(self, compute=True): class AbstractWritableDataStore(AbstractDataStore): - def __init__(self, writer=None, lock=HDF5_LOCK): - if writer is None: - writer = ArrayWriter(lock=lock) - self.writer = writer - self.delayed_store = None def encode(self, variables, attributes): """ @@ -329,12 +224,6 @@ def set_attribute(self, k, v): # pragma: no cover def set_variable(self, k, v): # pragma: no 
cover raise NotImplementedError - def sync(self, compute=True): - if self._isopen and self._autoclose: - # datastore will be reopened during write - self.close() - self.delayed_store = self.writer.sync(compute=compute) - def store_dataset(self, dataset): """ in stores, variables are all variables AND coordinates @@ -345,7 +234,7 @@ def store_dataset(self, dataset): self.store(dataset, dataset.attrs) def store(self, variables, attributes, check_encoding_set=frozenset(), - unlimited_dims=None): + writer=None, unlimited_dims=None): """ Top level method for putting data on this store, this method: - encodes variables/attributes @@ -361,16 +250,19 @@ def store(self, variables, attributes, check_encoding_set=frozenset(), check_encoding_set : list-like List of variables that should be checked for invalid encoding values + writer : ArrayWriter unlimited_dims : list-like List of dimension names that should be treated as unlimited dimensions. """ + if writer is None: + writer = ArrayWriter() variables, attributes = self.encode(variables, attributes) self.set_attributes(attributes) self.set_dimensions(variables, unlimited_dims=unlimited_dims) - self.set_variables(variables, check_encoding_set, + self.set_variables(variables, check_encoding_set, writer, unlimited_dims=unlimited_dims) def set_attributes(self, attributes): @@ -386,7 +278,7 @@ def set_attributes(self, attributes): for k, v in iteritems(attributes): self.set_attribute(k, v) - def set_variables(self, variables, check_encoding_set, + def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None): """ This provides a centralized method to set the variables on the data @@ -399,6 +291,7 @@ def set_variables(self, variables, check_encoding_set, check_encoding_set : list-like List of variables that should be checked for invalid encoding values + writer : ArrayWriter unlimited_dims : list-like List of dimension names that should be treated as unlimited dimensions. 
@@ -410,7 +303,7 @@ def set_variables(self, variables, check_encoding_set, target, source = self.prepare_variable( name, v, check, unlimited_dims=unlimited_dims) - self.writer.add(source, target) + writer.add(source, target) def set_dimensions(self, variables, unlimited_dims=None): """ @@ -457,87 +350,3 @@ def encode(self, variables, attributes): attributes = OrderedDict([(k, self.encode_attribute(v)) for k, v in attributes.items()]) return variables, attributes - - -class DataStorePickleMixin(object): - """Subclasses must define `ds`, `_opener` and `_mode` attributes. - - Do not subclass this class: it is not part of xarray's external API. - """ - - def __getstate__(self): - state = self.__dict__.copy() - del state['_ds'] - del state['_isopen'] - if self._mode == 'w': - # file has already been created, don't override when restoring - state['_mode'] = 'a' - return state - - def __setstate__(self, state): - self.__dict__.update(state) - self._ds = None - self._isopen = False - - @property - def ds(self): - if self._ds is not None and self._isopen: - return self._ds - ds = self._opener(mode=self._mode) - self._isopen = True - return ds - - @contextlib.contextmanager - def ensure_open(self, autoclose=None): - """ - Helper function to make sure datasets are closed and opened - at appropriate times to avoid too many open file errors. - - Use requires `autoclose=True` argument to `open_mfdataset`. 
- """ - - if autoclose is None: - autoclose = self._autoclose - - if not self._isopen: - try: - self._ds = self._opener() - self._isopen = True - yield - finally: - if autoclose: - self.close() - else: - yield - - def assert_open(self): - if not self._isopen: - raise AssertionError('internal failure: file must be open ' - 'if `autoclose=True` is used.') - - -class PickleByReconstructionWrapper(object): - - def __init__(self, opener, file, mode='r', **kwargs): - self.opener = partial(opener, file, mode=mode, **kwargs) - self.mode = mode - self._ds = None - - @property - def value(self): - self._ds = self.opener() - return self._ds - - def __getstate__(self): - state = self.__dict__.copy() - del state['_ds'] - if self.mode == 'w': - # file has already been created, don't override when restoring - state['mode'] = 'a' - return state - - def __setstate__(self, state): - self.__dict__.update(state) - - def close(self): - self._ds.close() diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py new file mode 100644 index 00000000000..a93285370b2 --- /dev/null +++ b/xarray/backends/file_manager.py @@ -0,0 +1,206 @@ +import threading + +from ..core import utils +from ..core.options import OPTIONS +from .lru_cache import LRUCache + + +# Global cache for storing open files. +FILE_CACHE = LRUCache( + OPTIONS['file_cache_maxsize'], on_evict=lambda k, v: v.close()) +assert FILE_CACHE.maxsize, 'file cache must be at least size one' + + +_DEFAULT_MODE = utils.ReprObject('') + + +class FileManager(object): + """Manager for acquiring and closing a file object. + + Use FileManager subclasses (CachingFileManager in particular) on backend + storage classes to automatically handle issues related to keeping track of + many open files and transferring them between multiple processes. 
+ """ + + def acquire(self): + """Acquire the file object from this manager.""" + raise NotImplementedError + + def close(self, needs_lock=True): + """Close the file object associated with this manager, if needed.""" + raise NotImplementedError + + +class CachingFileManager(FileManager): + """Wrapper for automatically opening and closing file objects. + + Unlike files, CachingFileManager objects can be safely pickled and passed + between processes. They should be explicitly closed to release resources, + but a per-process least-recently-used cache for open files ensures that you + can safely create arbitrarily large numbers of FileManager objects. + + Don't directly close files acquired from a FileManager. Instead, call + FileManager.close(), which ensures that closed files are removed from the + cache as well. + + Example usage: + + manager = FileManager(open, 'example.txt', mode='w') + f = manager.acquire() + f.write(...) + manager.close() # ensures file is closed + + Note that as long as previous files are still cached, acquiring a file + multiple times from the same FileManager is essentially free: + + f1 = manager.acquire() + f2 = manager.acquire() + assert f1 is f2 + + """ + + def __init__(self, opener, *args, **keywords): + """Initialize a FileManager. + + Parameters + ---------- + opener : callable + Function that when called like ``opener(*args, **kwargs)`` returns + an open file object. The file object must implement a ``close()`` + method. + *args + Positional arguments for opener. A ``mode`` argument should be + provided as a keyword argument (see below). All arguments must be + hashable. + mode : optional + If provided, passed as a keyword argument to ``opener`` along with + ``**kwargs``. ``mode='w' `` has special treatment: after the first + call it is replaced by ``mode='a'`` in all subsequent function to + avoid overriding the newly created file. + kwargs : dict, optional + Keyword arguments for opener, excluding ``mode``. 
All values must + be hashable. + lock : duck-compatible threading.Lock, optional + Lock to use when modifying the cache inside acquire() and close(). + By default, uses a new threading.Lock() object. If set, this object + should be pickleable. + cache : MutableMapping, optional + Mapping to use as a cache for open files. By default, uses xarray's + global LRU file cache. Because ``cache`` typically points to a + global variable and contains non-picklable file objects, an + unpickled FileManager objects will be restored with the default + cache. + """ + # TODO: replace with real keyword arguments when we drop Python 2 + # support + mode = keywords.pop('mode', _DEFAULT_MODE) + kwargs = keywords.pop('kwargs', None) + lock = keywords.pop('lock', None) + cache = keywords.pop('cache', FILE_CACHE) + if keywords: + raise TypeError('FileManager() got unexpected keyword arguments: ' + '%s' % list(keywords)) + + self._opener = opener + self._args = args + self._mode = mode + self._kwargs = {} if kwargs is None else dict(kwargs) + self._default_lock = lock is None or lock is False + self._lock = threading.Lock() if self._default_lock else lock + self._cache = cache + self._key = self._make_key() + + def _make_key(self): + """Make a key for caching files in the LRU cache.""" + value = (self._opener, + self._args, + self._mode, + tuple(sorted(self._kwargs.items()))) + return _HashedSequence(value) + + def acquire(self): + """Acquiring a file object from the manager. + + A new file is only opened if it has expired from the + least-recently-used cache. + + This method uses a reentrant lock, which ensures that it is + thread-safe. You can safely acquire a file in multiple threads at the + same time, as long as the underlying file object is thread-safe. + + Returns + ------- + An open file object, as returned by ``opener(*args, **kwargs)``. 
+ """ + with self._lock: + try: + file = self._cache[self._key] + except KeyError: + kwargs = self._kwargs + if self._mode is not _DEFAULT_MODE: + kwargs = kwargs.copy() + kwargs['mode'] = self._mode + file = self._opener(*self._args, **kwargs) + if self._mode == 'w': + # ensure file doesn't get overriden when opened again + self._mode = 'a' + self._key = self._make_key() + self._cache[self._key] = file + return file + + def _close(self): + default = None + file = self._cache.pop(self._key, default) + if file is not None: + file.close() + + def close(self, needs_lock=True): + """Explicitly close any associated file object (if necessary).""" + # TODO: remove needs_lock if/when we have a reentrant lock in + # dask.distributed: https://github.com/dask/dask/issues/3832 + if needs_lock: + with self._lock: + self._close() + else: + self._close() + + def __getstate__(self): + """State for pickling.""" + lock = None if self._default_lock else self._lock + return (self._opener, self._args, self._mode, self._kwargs, lock) + + def __setstate__(self, state): + """Restore from a pickle.""" + opener, args, mode, kwargs, lock = state + self.__init__(opener, *args, mode=mode, kwargs=kwargs, lock=lock) + + +class _HashedSequence(list): + """Speedup repeated look-ups by caching hash values. + + Based on what Python uses internally in functools.lru_cache. + + Python doesn't perform this optimization automatically: + https://bugs.python.org/issue1462796 + """ + + def __init__(self, tuple_value): + self[:] = tuple_value + self.hashvalue = hash(tuple_value) + + def __hash__(self): + return self.hashvalue + + +class DummyFileManager(FileManager): + """FileManager that simply wraps an open file in the FileManager interface. 
+ """ + def __init__(self, value): + self._value = value + + def acquire(self): + return self._value + + def close(self, needs_lock=True): + del needs_lock # ignored + self._value.close() diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 959cd221734..59cd4e84793 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -8,11 +8,12 @@ from ..core import indexing from ..core.pycompat import OrderedDict, bytes_type, iteritems, unicode_type from ..core.utils import FrozenOrderedDict, close_on_error -from .common import ( - HDF5_LOCK, DataStorePickleMixin, WritableCFDataStore, find_root) +from .common import WritableCFDataStore +from .file_manager import CachingFileManager +from .locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock from .netCDF4_ import ( - BaseNetCDF4Array, _encode_nc4_variable, _extract_nc4_variable_encoding, - _get_datatype, _nc4_require_group) + BaseNetCDF4Array, GroupWrapper, _encode_nc4_variable, + _extract_nc4_variable_encoding, _get_datatype, _nc4_require_group) class H5NetCDFArrayWrapper(BaseNetCDF4Array): @@ -25,8 +26,9 @@ def _getitem(self, key): # h5py requires using lists for fancy indexing: # https://github.com/h5py/h5py/issues/992 key = tuple(list(k) if isinstance(k, np.ndarray) else k for k in key) - with self.datastore.ensure_open(autoclose=True): - return self.get_array()[key] + array = self.get_array() + with self.datastore.lock: + return array[key] def maybe_decode_bytes(txt): @@ -61,104 +63,102 @@ def _open_h5netcdf_group(filename, mode, group): import h5netcdf ds = h5netcdf.File(filename, mode=mode) with close_on_error(ds): - return _nc4_require_group( + ds = _nc4_require_group( ds, group, mode, create_group=_h5netcdf_create_group) + return GroupWrapper(ds) -class H5NetCDFStore(WritableCFDataStore, DataStorePickleMixin): +class H5NetCDFStore(WritableCFDataStore): """Store for reading and writing data via h5netcdf """ def __init__(self, filename, mode='r', format=None, 
group=None, - writer=None, autoclose=False, lock=HDF5_LOCK): + lock=None, autoclose=False): if format not in [None, 'NETCDF4']: raise ValueError('invalid format for h5netcdf backend') - opener = functools.partial(_open_h5netcdf_group, filename, mode=mode, - group=group) - self._ds = opener() - if autoclose: - raise NotImplementedError('autoclose=True is not implemented ' - 'for the h5netcdf backend pending ' - 'further exploration, e.g., bug fixes ' - '(in h5netcdf?)') - self._autoclose = False - self._isopen = True + self._manager = CachingFileManager( + _open_h5netcdf_group, filename, mode=mode, + kwargs=dict(group=group)) + + if lock is None: + if mode == 'r': + lock = HDF5_LOCK + else: + lock = combine_locks([HDF5_LOCK, get_write_lock(filename)]) + self.format = format - self._opener = opener self._filename = filename self._mode = mode - super(H5NetCDFStore, self).__init__(writer, lock=lock) + self.lock = ensure_lock(lock) + self.autoclose = autoclose + + @property + def ds(self): + return self._manager.acquire().value def open_store_variable(self, name, var): import h5py - with self.ensure_open(autoclose=False): - dimensions = var.dimensions - data = indexing.LazilyOuterIndexedArray( - H5NetCDFArrayWrapper(name, self)) - attrs = _read_attributes(var) - - # netCDF4 specific encoding - encoding = { - 'chunksizes': var.chunks, - 'fletcher32': var.fletcher32, - 'shuffle': var.shuffle, - } - # Convert h5py-style compression options to NetCDF4-Python - # style, if possible - if var.compression == 'gzip': - encoding['zlib'] = True - encoding['complevel'] = var.compression_opts - elif var.compression is not None: - encoding['compression'] = var.compression - encoding['compression_opts'] = var.compression_opts - - # save source so __repr__ can detect if it's local or not - encoding['source'] = self._filename - encoding['original_shape'] = var.shape - - vlen_dtype = h5py.check_dtype(vlen=var.dtype) - if vlen_dtype is unicode_type: - encoding['dtype'] = str - elif 
vlen_dtype is not None: # pragma: no cover - # xarray doesn't support writing arbitrary vlen dtypes yet. - pass - else: - encoding['dtype'] = var.dtype + dimensions = var.dimensions + data = indexing.LazilyOuterIndexedArray( + H5NetCDFArrayWrapper(name, self)) + attrs = _read_attributes(var) + + # netCDF4 specific encoding + encoding = { + 'chunksizes': var.chunks, + 'fletcher32': var.fletcher32, + 'shuffle': var.shuffle, + } + # Convert h5py-style compression options to NetCDF4-Python + # style, if possible + if var.compression == 'gzip': + encoding['zlib'] = True + encoding['complevel'] = var.compression_opts + elif var.compression is not None: + encoding['compression'] = var.compression + encoding['compression_opts'] = var.compression_opts + + # save source so __repr__ can detect if it's local or not + encoding['source'] = self._filename + encoding['original_shape'] = var.shape + + vlen_dtype = h5py.check_dtype(vlen=var.dtype) + if vlen_dtype is unicode_type: + encoding['dtype'] = str + elif vlen_dtype is not None: # pragma: no cover + # xarray doesn't support writing arbitrary vlen dtypes yet. 
+ pass + else: + encoding['dtype'] = var.dtype return Variable(dimensions, data, attrs, encoding) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in iteritems(self.ds.variables)) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in iteritems(self.ds.variables)) def get_attrs(self): - with self.ensure_open(autoclose=True): - return FrozenOrderedDict(_read_attributes(self.ds)) + return FrozenOrderedDict(_read_attributes(self.ds)) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return self.ds.dimensions + return self.ds.dimensions def get_encoding(self): - with self.ensure_open(autoclose=True): - encoding = {} - encoding['unlimited_dims'] = { - k for k, v in self.ds.dimensions.items() if v is None} + encoding = {} + encoding['unlimited_dims'] = { + k for k, v in self.ds.dimensions.items() if v is None} return encoding def set_dimension(self, name, length, is_unlimited=False): - with self.ensure_open(autoclose=False): - if is_unlimited: - self.ds.dimensions[name] = None - self.ds.resize_dimension(name, length) - else: - self.ds.dimensions[name] = length + if is_unlimited: + self.ds.dimensions[name] = None + self.ds.resize_dimension(name, length) + else: + self.ds.dimensions[name] = length def set_attribute(self, key, value): - with self.ensure_open(autoclose=False): - self.ds.attrs[key] = value + self.ds.attrs[key] = value def encode_variable(self, variable): return _encode_nc4_variable(variable) @@ -226,18 +226,11 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, variable.data - def sync(self, compute=True): - if not compute: - raise NotImplementedError( - 'compute=False is not supported for the h5netcdf backend yet') - with self.ensure_open(autoclose=True): - super(H5NetCDFStore, self).sync(compute=compute) - self.ds.sync() - - def close(self): - if self._isopen: - # netCDF4 only allows closing the 
root group - ds = find_root(self.ds) - if not ds._closed: - ds.close() - self._isopen = False + def sync(self): + self.ds.sync() + # if self.autoclose: + # self.close() + # super(H5NetCDFStore, self).sync(compute=compute) + + def close(self, **kwargs): + self._manager.close(**kwargs) diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py new file mode 100644 index 00000000000..f633280ef1d --- /dev/null +++ b/xarray/backends/locks.py @@ -0,0 +1,191 @@ +import multiprocessing +import threading +import weakref + +try: + from dask.utils import SerializableLock +except ImportError: + # no need to worry about serializing the lock + SerializableLock = threading.Lock + + +# Locks used by multiple backends. +# Neither HDF5 nor the netCDF-C library are thread-safe. +HDF5_LOCK = SerializableLock() +NETCDFC_LOCK = SerializableLock() + + +_FILE_LOCKS = weakref.WeakValueDictionary() + + +def _get_threaded_lock(key): + try: + lock = _FILE_LOCKS[key] + except KeyError: + lock = _FILE_LOCKS[key] = threading.Lock() + return lock + + +def _get_multiprocessing_lock(key): + # TODO: make use of the key -- maybe use locket.py? + # https://github.com/mwilliamson/locket.py + del key # unused + return multiprocessing.Lock() + + +def _get_distributed_lock(key): + from dask.distributed import Lock + return Lock(key) + + +_LOCK_MAKERS = { + None: _get_threaded_lock, + 'threaded': _get_threaded_lock, + 'multiprocessing': _get_multiprocessing_lock, + 'distributed': _get_distributed_lock, +} + + +def _get_lock_maker(scheduler=None): + """Returns an appropriate function for creating resource locks. + + Parameters + ---------- + scheduler : str or None + Dask scheduler being used. + + See Also + -------- + dask.utils.get_scheduler_lock + """ + return _LOCK_MAKERS[scheduler] + + +def _get_scheduler(get=None, collection=None): + """Determine the dask scheduler that is being used. + + None is returned if no dask scheduler is active. 
+ + See also + -------- + dask.base.get_scheduler + """ + try: + # dask 0.18.1 and later + from dask.base import get_scheduler + actual_get = get_scheduler(get, collection) + except ImportError: + try: + from dask.utils import effective_get + actual_get = effective_get(get, collection) + except ImportError: + return None + + try: + from dask.distributed import Client + if isinstance(actual_get.__self__, Client): + return 'distributed' + except (ImportError, AttributeError): + try: + import dask.multiprocessing + if actual_get == dask.multiprocessing.get: + return 'multiprocessing' + else: + return 'threaded' + except ImportError: + return 'threaded' + + +def get_write_lock(key): + """Get a scheduler appropriate lock for writing to the given resource. + + Parameters + ---------- + key : str + Name of the resource for which to acquire a lock. Typically a filename. + + Returns + ------- + Lock object that can be used like a threading.Lock object. + """ + scheduler = _get_scheduler() + lock_maker = _get_lock_maker(scheduler) + return lock_maker(key) + + +class CombinedLock(object): + """A combination of multiple locks. + + Like a locked door, a CombinedLock is locked if any of its constituent + locks are locked. 
+ """ + + def __init__(self, locks): + self.locks = tuple(set(locks)) # remove duplicates + + def acquire(self, *args): + return all(lock.acquire(*args) for lock in self.locks) + + def release(self, *args): + for lock in self.locks: + lock.release(*args) + + def __enter__(self): + for lock in self.locks: + lock.__enter__() + + def __exit__(self, *args): + for lock in self.locks: + lock.__exit__(*args) + + @property + def locked(self): + return any(lock.locked for lock in self.locks) + + def __repr__(self): + return "CombinedLock(%r)" % list(self.locks) + + +class DummyLock(object): + """DummyLock provides the lock API without any actual locking.""" + + def acquire(self, *args): + pass + + def release(self, *args): + pass + + def __enter__(self): + pass + + def __exit__(self, *args): + pass + + @property + def locked(self): + return False + + +def combine_locks(locks): + """Combine a sequence of locks into a single lock.""" + all_locks = [] + for lock in locks: + if isinstance(lock, CombinedLock): + all_locks.extend(lock.locks) + elif lock is not None: + all_locks.append(lock) + + num_locks = len(all_locks) + if num_locks > 1: + return CombinedLock(all_locks) + elif num_locks == 1: + return all_locks[0] + else: + return DummyLock() + + +def ensure_lock(lock): + """Ensure that the given object is a lock.""" + if lock is None or lock is False: + return DummyLock() + return lock diff --git a/xarray/backends/lru_cache.py b/xarray/backends/lru_cache.py new file mode 100644 index 00000000000..321a1ca4da4 --- /dev/null +++ b/xarray/backends/lru_cache.py @@ -0,0 +1,91 @@ +import collections +import threading + +from ..core.pycompat import move_to_end + + +class LRUCache(collections.MutableMapping): + """Thread-safe LRUCache based on an OrderedDict. + + All dict operations (__getitem__, __setitem__, __contains__) update the + priority of the relevant key and take O(1) time. 
The dict is iterated over + in order from the oldest to newest key, which means that a complete pass + over the dict should not affect the order of any entries. + + When a new item is set and the maximum size of the cache is exceeded, the + oldest item is dropped and called with ``on_evict(key, value)``. + + The ``maxsize`` property can be used to view or adjust the capacity of + the cache, e.g., ``cache.maxsize = new_size``. + """ + def __init__(self, maxsize, on_evict=None): + """ + Parameters + ---------- + maxsize : int + Integer maximum number of items to hold in the cache. + on_evict: callable, optional + Function to call like ``on_evict(key, value)`` when items are + evicted. + """ + if not isinstance(maxsize, int): + raise TypeError('maxsize must be an integer') + if maxsize < 0: + raise ValueError('maxsize must be non-negative') + self._maxsize = maxsize + self._on_evict = on_evict + self._cache = collections.OrderedDict() + self._lock = threading.RLock() + + def __getitem__(self, key): + # record recent use of the key by moving it to the front of the list + with self._lock: + value = self._cache[key] + move_to_end(self._cache, key) + return value + + def _enforce_size_limit(self, capacity): + """Shrink the cache if necessary, evicting the oldest items.""" + while len(self._cache) > capacity: + key, value = self._cache.popitem(last=False) + if self._on_evict is not None: + self._on_evict(key, value) + + def __setitem__(self, key, value): + with self._lock: + if key in self._cache: + # insert the new value at the end + del self._cache[key] + self._cache[key] = value + elif self._maxsize: + # make room if necessary + self._enforce_size_limit(self._maxsize - 1) + self._cache[key] = value + elif self._on_evict is not None: + # not saving, immediately evict + self._on_evict(key, value) + + def __delitem__(self, key): + del self._cache[key] + + def __iter__(self): + # create a list, so accessing the cache during iteration cannot change + # the iteration order + 
return iter(list(self._cache)) + + def __len__(self): + return len(self._cache) + + @property + def maxsize(self): + """Maximum number of items can be held in the cache.""" + return self._maxsize + + @maxsize.setter + def maxsize(self, size): + """Resize the cache, evicting the oldest items if necessary.""" + if size < 0: + raise ValueError('maxsize must be non-negative') + with self._lock: + self._enforce_size_limit(size) + self._maxsize = size diff --git a/xarray/backends/memory.py b/xarray/backends/memory.py index dcf092557b8..195d4647534 100644 --- a/xarray/backends/memory.py +++ b/xarray/backends/memory.py @@ -17,10 +17,9 @@ class InMemoryDataStore(AbstractWritableDataStore): This store exists purely for internal testing purposes. """ - def __init__(self, variables=None, attributes=None, writer=None): + def __init__(self, variables=None, attributes=None): self._variables = OrderedDict() if variables is None else variables self._attributes = OrderedDict() if attributes is None else attributes - super(InMemoryDataStore, self).__init__(writer) def get_attrs(self): return self._attributes diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index aa19633020b..08ba085b77e 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -13,8 +13,10 @@ from ..core.pycompat import PY3, OrderedDict, basestring, iteritems, suppress from ..core.utils import FrozenOrderedDict, close_on_error, is_remote_uri from .common import ( - HDF5_LOCK, BackendArray, DataStorePickleMixin, WritableCFDataStore, - find_root, robust_getitem) + BackendArray, WritableCFDataStore, find_root, robust_getitem) +from .locks import (NETCDFC_LOCK, HDF5_LOCK, + combine_locks, ensure_lock, get_write_lock) +from .file_manager import CachingFileManager, DummyFileManager from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable # This lookup table maps from dtype.byteorder to a readable endian @@ -25,6 +27,9 @@ '|': 'native'} +NETCDF4_PYTHON_LOCK = 
combine_locks([NETCDFC_LOCK, HDF5_LOCK]) + + class BaseNetCDF4Array(BackendArray): def __init__(self, variable_name, datastore): self.datastore = datastore @@ -42,12 +47,13 @@ def __init__(self, variable_name, datastore): self.dtype = dtype def __setitem__(self, key, value): - with self.datastore.ensure_open(autoclose=True): + with self.datastore.lock: data = self.get_array() data[key] = value + if self.datastore.autoclose: + self.datastore.close(needs_lock=False) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name] @@ -63,20 +69,22 @@ def _getitem(self, key): else: getitem = operator.getitem - with self.datastore.ensure_open(autoclose=True): - try: - array = getitem(self.get_array(), key) - except IndexError: - # Catch IndexError in netCDF4 and return a more informative - # error message. This is most often called when an unsorted - # indexer is used before the data is loaded from disk. - msg = ('The indexing operation you are attempting to perform ' - 'is not valid on netCDF4.Variable object. Try loading ' - 'your data into memory first by calling .load().') - if not PY3: - import traceback - msg += '\n\nOriginal traceback:\n' + traceback.format_exc() - raise IndexError(msg) + original_array = self.get_array() + + try: + with self.datastore.lock: + array = getitem(original_array, key) + except IndexError: + # Catch IndexError in netCDF4 and return a more informative + # error message. This is most often called when an unsorted + # indexer is used before the data is loaded from disk. + msg = ('The indexing operation you are attempting to perform ' + 'is not valid on netCDF4.Variable object. 
Try loading ' + 'your data into memory first by calling .load().') + if not PY3: + import traceback + msg += '\n\nOriginal traceback:\n' + traceback.format_exc() + raise IndexError(msg) return array @@ -223,7 +231,17 @@ def _extract_nc4_variable_encoding(variable, raise_on_invalid=False, return encoding -def _open_netcdf4_group(filename, mode, group=None, **kwargs): +class GroupWrapper(object): + """Wrap netCDF4.Group objects so closing them closes the root group.""" + def __init__(self, value): + self.value = value + + def close(self): + # netCDF4 only allows closing the root group + find_root(self.value).close() + + +def _open_netcdf4_group(filename, lock, mode, group=None, **kwargs): import netCDF4 as nc4 ds = nc4.Dataset(filename, mode=mode, **kwargs) @@ -233,7 +251,7 @@ def _open_netcdf4_group(filename, mode, group=None, **kwargs): _disable_auto_decode_group(ds) - return ds + return GroupWrapper(ds) def _disable_auto_decode_variable(var): @@ -279,40 +297,33 @@ def _set_nc_attribute(obj, key, value): obj.setncattr(key, value) -class NetCDF4DataStore(WritableCFDataStore, DataStorePickleMixin): +class NetCDF4DataStore(WritableCFDataStore): """Store for reading and writing data via the Python-NetCDF4 library. This store supports NetCDF3, NetCDF4 and OpenDAP datasets. 
""" - def __init__(self, netcdf4_dataset, mode='r', writer=None, opener=None, - autoclose=False, lock=HDF5_LOCK): - - if autoclose and opener is None: - raise ValueError('autoclose requires an opener') + def __init__(self, manager, lock=NETCDF4_PYTHON_LOCK, autoclose=False): + import netCDF4 - _disable_auto_decode_group(netcdf4_dataset) + if isinstance(manager, netCDF4.Dataset): + _disable_auto_decode_group(manager) + manager = DummyFileManager(GroupWrapper(manager)) - self._ds = netcdf4_dataset - self._autoclose = autoclose - self._isopen = True + self._manager = manager self.format = self.ds.data_model self._filename = self.ds.filepath() self.is_remote = is_remote_uri(self._filename) - self._mode = mode = 'a' if mode == 'w' else mode - if opener: - self._opener = functools.partial(opener, mode=self._mode) - else: - self._opener = opener - super(NetCDF4DataStore, self).__init__(writer, lock=lock) + self.lock = ensure_lock(lock) + self.autoclose = autoclose @classmethod def open(cls, filename, mode='r', format='NETCDF4', group=None, - writer=None, clobber=True, diskless=False, persist=False, - autoclose=False, lock=HDF5_LOCK): - import netCDF4 as nc4 + clobber=True, diskless=False, persist=False, + lock=None, lock_maker=None, autoclose=False): + import netCDF4 if (len(filename) == 88 and - LooseVersion(nc4.__version__) < "1.3.1"): + LooseVersion(netCDF4.__version__) < "1.3.1"): warnings.warn( 'A segmentation fault may occur when the ' 'file path has exactly 88 characters as it does ' @@ -323,86 +334,91 @@ def open(cls, filename, mode='r', format='NETCDF4', group=None, 'https://github.com/pydata/xarray/issues/1745') if format is None: format = 'NETCDF4' - opener = functools.partial(_open_netcdf4_group, filename, mode=mode, - group=group, clobber=clobber, - diskless=diskless, persist=persist, - format=format) - ds = opener() - return cls(ds, mode=mode, writer=writer, opener=opener, - autoclose=autoclose, lock=lock) - def open_store_variable(self, name, var): - with 
self.ensure_open(autoclose=False): - dimensions = var.dimensions - data = indexing.LazilyOuterIndexedArray( - NetCDF4ArrayWrapper(name, self)) - attributes = OrderedDict((k, var.getncattr(k)) - for k in var.ncattrs()) - _ensure_fill_value_valid(data, attributes) - # netCDF4 specific encoding; save _FillValue for later - encoding = {} - filters = var.filters() - if filters is not None: - encoding.update(filters) - chunking = var.chunking() - if chunking is not None: - if chunking == 'contiguous': - encoding['contiguous'] = True - encoding['chunksizes'] = None + if lock is None: + if mode == 'r': + if is_remote_uri(filename): + lock = NETCDFC_LOCK + else: + lock = NETCDF4_PYTHON_LOCK + else: + if format is None or format.startswith('NETCDF4'): + base_lock = NETCDF4_PYTHON_LOCK else: - encoding['contiguous'] = False - encoding['chunksizes'] = tuple(chunking) - # TODO: figure out how to round-trip "endian-ness" without raising - # warnings from netCDF4 - # encoding['endian'] = var.endian() - pop_to(attributes, encoding, 'least_significant_digit') - # save source so __repr__ can detect if it's local or not - encoding['source'] = self._filename - encoding['original_shape'] = var.shape - encoding['dtype'] = var.dtype + base_lock = NETCDFC_LOCK + lock = combine_locks([base_lock, get_write_lock(filename)]) + + manager = CachingFileManager( + _open_netcdf4_group, filename, lock, mode=mode, + kwargs=dict(group=group, clobber=clobber, diskless=diskless, + persist=persist, format=format)) + return cls(manager, lock=lock, autoclose=autoclose) + + @property + def ds(self): + return self._manager.acquire().value + + def open_store_variable(self, name, var): + dimensions = var.dimensions + data = indexing.LazilyOuterIndexedArray( + NetCDF4ArrayWrapper(name, self)) + attributes = OrderedDict((k, var.getncattr(k)) + for k in var.ncattrs()) + _ensure_fill_value_valid(data, attributes) + # netCDF4 specific encoding; save _FillValue for later + encoding = {} + filters = var.filters() + 
if filters is not None: + encoding.update(filters) + chunking = var.chunking() + if chunking is not None: + if chunking == 'contiguous': + encoding['contiguous'] = True + encoding['chunksizes'] = None + else: + encoding['contiguous'] = False + encoding['chunksizes'] = tuple(chunking) + # TODO: figure out how to round-trip "endian-ness" without raising + # warnings from netCDF4 + # encoding['endian'] = var.endian() + pop_to(attributes, encoding, 'least_significant_digit') + # save source so __repr__ can detect if it's local or not + encoding['source'] = self._filename + encoding['original_shape'] = var.shape + encoding['dtype'] = var.dtype return Variable(dimensions, data, attributes, encoding) def get_variables(self): - with self.ensure_open(autoclose=False): - dsvars = FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in - iteritems(self.ds.variables)) + dsvars = FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in + iteritems(self.ds.variables)) return dsvars def get_attrs(self): - with self.ensure_open(autoclose=True): - attrs = FrozenOrderedDict((k, self.ds.getncattr(k)) - for k in self.ds.ncattrs()) + attrs = FrozenOrderedDict((k, self.ds.getncattr(k)) + for k in self.ds.ncattrs()) return attrs def get_dimensions(self): - with self.ensure_open(autoclose=True): - dims = FrozenOrderedDict((k, len(v)) - for k, v in iteritems(self.ds.dimensions)) + dims = FrozenOrderedDict((k, len(v)) + for k, v in iteritems(self.ds.dimensions)) return dims def get_encoding(self): - with self.ensure_open(autoclose=True): - encoding = {} - encoding['unlimited_dims'] = { - k for k, v in self.ds.dimensions.items() if v.isunlimited()} + encoding = {} + encoding['unlimited_dims'] = { + k for k, v in self.ds.dimensions.items() if v.isunlimited()} return encoding def set_dimension(self, name, length, is_unlimited=False): - with self.ensure_open(autoclose=False): - dim_length = length if not is_unlimited else None - self.ds.createDimension(name, 
size=dim_length) + dim_length = length if not is_unlimited else None + self.ds.createDimension(name, size=dim_length) def set_attribute(self, key, value): - with self.ensure_open(autoclose=False): - if self.format != 'NETCDF4': - value = encode_nc3_attr_value(value) - _set_nc_attribute(self.ds, key, value) - - def set_variables(self, *args, **kwargs): - with self.ensure_open(autoclose=False): - super(NetCDF4DataStore, self).set_variables(*args, **kwargs) + if self.format != 'NETCDF4': + value = encode_nc3_attr_value(value) + _set_nc_attribute(self.ds, key, value) def encode_variable(self, variable): variable = _force_native_endianness(variable) @@ -460,15 +476,8 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, variable.data - def sync(self, compute=True): - with self.ensure_open(autoclose=True): - super(NetCDF4DataStore, self).sync(compute=compute) - self.ds.sync() + def sync(self): + self.ds.sync() - def close(self): - if self._isopen: - # netCDF4 only allows closing the root group - ds = find_root(self.ds) - if ds._isopen: - ds.close() - self._isopen = False + def close(self, **kwargs): + self._manager.close(**kwargs) diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index 3d846916740..e4691d1f7e1 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -1,14 +1,18 @@ from __future__ import absolute_import, division, print_function -import functools - import numpy as np from .. 
import Variable from ..core import indexing from ..core.pycompat import OrderedDict -from ..core.utils import Frozen, FrozenOrderedDict -from .common import AbstractDataStore, BackendArray, DataStorePickleMixin +from ..core.utils import Frozen +from .common import AbstractDataStore, BackendArray +from .file_manager import CachingFileManager +from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock + + +# psuedonetcdf can invoke netCDF libraries internally +PNETCDF_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK]) class PncArrayWrapper(BackendArray): @@ -21,7 +25,6 @@ def __init__(self, variable_name, datastore): self.dtype = np.dtype(array.dtype) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name] def __getitem__(self, key): @@ -30,57 +33,55 @@ def __getitem__(self, key): self._getitem) def _getitem(self, key): - with self.datastore.ensure_open(autoclose=True): - return self.get_array()[key] + array = self.get_array() + with self.datastore.lock: + return array[key] -class PseudoNetCDFDataStore(AbstractDataStore, DataStorePickleMixin): +class PseudoNetCDFDataStore(AbstractDataStore): """Store for accessing datasets via PseudoNetCDF """ @classmethod - def open(cls, filename, format=None, writer=None, - autoclose=False, **format_kwds): + def open(cls, filename, lock=None, **format_kwds): from PseudoNetCDF import pncopen - opener = functools.partial(pncopen, filename, **format_kwds) - ds = opener() - mode = format_kwds.get('mode', 'r') - return cls(ds, mode=mode, writer=writer, opener=opener, - autoclose=autoclose) - def __init__(self, pnc_dataset, mode='r', writer=None, opener=None, - autoclose=False): + keywords = dict(kwargs=format_kwds) + # only include mode if explicitly passed + mode = format_kwds.pop('mode', None) + if mode is not None: + keywords['mode'] = mode + + if lock is None: + lock = PNETCDF_LOCK + + manager = CachingFileManager(pncopen, filename, lock=lock, **keywords) + return 
cls(manager, lock) - if autoclose and opener is None: - raise ValueError('autoclose requires an opener') + def __init__(self, manager, lock=None): + self._manager = manager + self.lock = ensure_lock(lock) - self._ds = pnc_dataset - self._autoclose = autoclose - self._isopen = True - self._opener = opener - self._mode = mode - super(PseudoNetCDFDataStore, self).__init__() + @property + def ds(self): + return self._manager.acquire() def open_store_variable(self, name, var): - with self.ensure_open(autoclose=False): - data = indexing.LazilyOuterIndexedArray( - PncArrayWrapper(name, self) - ) + data = indexing.LazilyOuterIndexedArray( + PncArrayWrapper(name, self) + ) attrs = OrderedDict((k, getattr(var, k)) for k in var.ncattrs()) return Variable(var.dimensions, data, attrs) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in self.ds.variables.items()) + return ((k, self.open_store_variable(k, v)) + for k, v in self.ds.variables.items()) def get_attrs(self): - with self.ensure_open(autoclose=True): - return Frozen(dict([(k, getattr(self.ds, k)) - for k in self.ds.ncattrs()])) + return Frozen(dict([(k, getattr(self.ds, k)) + for k in self.ds.ncattrs()])) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.dimensions) + return Frozen(self.ds.dimensions) def get_encoding(self): encoding = {} @@ -90,6 +91,4 @@ def get_encoding(self): return encoding def close(self): - if self._isopen: - self.ds.close() - self._isopen = False + self._manager.close() diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index 98b76928597..574fff744e3 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -1,13 +1,20 @@ from __future__ import absolute_import, division, print_function -import functools - import numpy as np from .. 
import Variable from ..core import indexing from ..core.utils import Frozen, FrozenOrderedDict -from .common import AbstractDataStore, BackendArray, DataStorePickleMixin +from .common import AbstractDataStore, BackendArray +from .file_manager import CachingFileManager +from .locks import ( + HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock, SerializableLock) + + +# PyNIO can invoke netCDF libraries internally +# Add a dedicated lock just in case NCL as well isn't thread-safe. +NCL_LOCK = SerializableLock() +PYNIO_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK, NCL_LOCK]) class NioArrayWrapper(BackendArray): @@ -20,7 +27,6 @@ def __init__(self, variable_name, datastore): self.dtype = np.dtype(array.typecode()) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name] def __getitem__(self, key): @@ -28,46 +34,45 @@ def __getitem__(self, key): key, self.shape, indexing.IndexingSupport.BASIC, self._getitem) def _getitem(self, key): - with self.datastore.ensure_open(autoclose=True): - array = self.get_array() + array = self.get_array() + with self.datastore.lock: if key == () and self.ndim == 0: return array.get_value() - return array[key] -class NioDataStore(AbstractDataStore, DataStorePickleMixin): +class NioDataStore(AbstractDataStore): """Store for accessing datasets via PyNIO """ - def __init__(self, filename, mode='r', autoclose=False): + def __init__(self, filename, mode='r', lock=None): import Nio - opener = functools.partial(Nio.open_file, filename, mode=mode) - self._ds = opener() - self._autoclose = autoclose - self._isopen = True - self._opener = opener - self._mode = mode + if lock is None: + lock = PYNIO_LOCK + self.lock = ensure_lock(lock) + self._manager = CachingFileManager( + Nio.open_file, filename, lock=lock, mode=mode) # xarray provides its own support for FillValue, # so turn off PyNIO's support for the same. 
self.ds.set_option('MaskedArrayMode', 'MaskedNever') + @property + def ds(self): + return self._manager.acquire() + def open_store_variable(self, name, var): data = indexing.LazilyOuterIndexedArray(NioArrayWrapper(name, self)) return Variable(var.dimensions, data, var.attributes) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in self.ds.variables.items()) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in self.ds.variables.items()) def get_attrs(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.attributes) + return Frozen(self.ds.attributes) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.dimensions) + return Frozen(self.ds.dimensions) def get_encoding(self): encoding = {} @@ -76,6 +81,4 @@ def get_encoding(self): return encoding def close(self): - if self._isopen: - self.ds.close() - self._isopen = False + self._manager.close() diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 44cca9aaaf8..5746b4e748d 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -8,14 +8,13 @@ from .. import DataArray from ..core import indexing from ..core.utils import is_scalar -from .common import BackendArray, PickleByReconstructionWrapper +from .common import BackendArray +from .file_manager import CachingFileManager +from .locks import SerializableLock -try: - from dask.utils import SerializableLock as Lock -except ImportError: - from threading import Lock -RASTERIO_LOCK = Lock() +# TODO: should this be GDAL_LOCK instead? +RASTERIO_LOCK = SerializableLock() _ERROR_MSG = ('The kind of indexing operation you are trying to do is not ' 'valid on rasterio files. 
Try to load your data with ds.load()' @@ -25,18 +24,22 @@ class RasterioArrayWrapper(BackendArray): """A wrapper around rasterio dataset objects""" - def __init__(self, riods): - self.riods = riods - self._shape = (riods.value.count, riods.value.height, - riods.value.width) - self._ndims = len(self.shape) + def __init__(self, manager): + self.manager = manager - @property - def dtype(self): - dtypes = self.riods.value.dtypes + # cannot save riods as an attribute: this would break pickleability + riods = manager.acquire() + + self._shape = (riods.count, riods.height, riods.width) + + dtypes = riods.dtypes if not np.all(np.asarray(dtypes) == dtypes[0]): raise ValueError('All bands should have the same dtype') - return np.dtype(dtypes[0]) + self._dtype = np.dtype(dtypes[0]) + + @property + def dtype(self): + return self._dtype @property def shape(self): @@ -108,7 +111,8 @@ def _getitem(self, key): stop - start for (start, stop) in window) out = np.zeros(shape, dtype=self.dtype) else: - out = self.riods.value.read(band_key, window=window) + riods = self.manager.acquire() + out = riods.read(band_key, window=window) if squeeze_axis: out = np.squeeze(out, axis=squeeze_axis) @@ -203,7 +207,8 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, import rasterio - riods = PickleByReconstructionWrapper(rasterio.open, filename, mode='r') + manager = CachingFileManager(rasterio.open, filename, mode='r') + riods = manager.acquire() if cache is None: cache = chunks is None @@ -211,20 +216,20 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, coords = OrderedDict() # Get bands - if riods.value.count < 1: + if riods.count < 1: raise ValueError('Unknown dims') - coords['band'] = np.asarray(riods.value.indexes) + coords['band'] = np.asarray(riods.indexes) # Get coordinates if LooseVersion(rasterio.__version__) < '1.0': - transform = riods.value.affine + transform = riods.affine else: - transform = riods.value.transform + 
transform = riods.transform if transform.is_rectilinear: # 1d coordinates parse = True if parse_coordinates is None else parse_coordinates if parse: - nx, ny = riods.value.width, riods.value.height + nx, ny = riods.width, riods.height # xarray coordinates are pixel centered x, _ = (np.arange(nx) + 0.5, np.zeros(nx) + 0.5) * transform _, y = (np.zeros(ny) + 0.5, np.arange(ny) + 0.5) * transform @@ -234,57 +239,60 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, # 2d coordinates parse = False if (parse_coordinates is None) else parse_coordinates if parse: - warnings.warn("The file coordinates' transformation isn't " - "rectilinear: xarray won't parse the coordinates " - "in this case. Set `parse_coordinates=False` to " - "suppress this warning.", - RuntimeWarning, stacklevel=3) + warnings.warn( + "The file coordinates' transformation isn't " + "rectilinear: xarray won't parse the coordinates " + "in this case. Set `parse_coordinates=False` to " + "suppress this warning.", + RuntimeWarning, stacklevel=3) # Attributes attrs = dict() # Affine transformation matrix (always available) # This describes coefficients mapping pixel coordinates to CRS # For serialization store as tuple of 6 floats, the last row being - # always (0, 0, 1) per definition (see https://github.com/sgillies/affine) + # always (0, 0, 1) per definition (see + # https://github.com/sgillies/affine) attrs['transform'] = tuple(transform)[:6] - if hasattr(riods.value, 'crs') and riods.value.crs: + if hasattr(riods, 'crs') and riods.crs: # CRS is a dict-like object specific to rasterio # If CRS is not None, we convert it back to a PROJ4 string using # rasterio itself - attrs['crs'] = riods.value.crs.to_string() - if hasattr(riods.value, 'res'): + attrs['crs'] = riods.crs.to_string() + if hasattr(riods, 'res'): # (width, height) tuple of pixels in units of CRS - attrs['res'] = riods.value.res - if hasattr(riods.value, 'is_tiled'): + attrs['res'] = riods.res + if hasattr(riods, 
'is_tiled'): # Is the TIF tiled? (bool) # We cast it to an int for netCDF compatibility - attrs['is_tiled'] = np.uint8(riods.value.is_tiled) - if hasattr(riods.value, 'nodatavals'): + attrs['is_tiled'] = np.uint8(riods.is_tiled) + if hasattr(riods, 'nodatavals'): # The nodata values for the raster bands - attrs['nodatavals'] = tuple([np.nan if nodataval is None else nodataval - for nodataval in riods.value.nodatavals]) + attrs['nodatavals'] = tuple( + np.nan if nodataval is None else nodataval + for nodataval in riods.nodatavals) # Parse extra metadata from tags, if supported parsers = {'ENVI': _parse_envi} - driver = riods.value.driver + driver = riods.driver if driver in parsers: - meta = parsers[driver](riods.value.tags(ns=driver)) + meta = parsers[driver](riods.tags(ns=driver)) for k, v in meta.items(): # Add values as coordinates if they match the band count, # as attributes otherwise if (isinstance(v, (list, np.ndarray)) and - len(v) == riods.value.count): + len(v) == riods.count): coords[k] = ('band', np.asarray(v)) else: attrs[k] = v - data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(riods)) + data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(manager)) # this lets you write arrays loaded with rasterio data = indexing.CopyOnWriteArray(data) - if cache and (chunks is None): + if cache and chunks is None: data = indexing.MemoryCachedArray(data) result = DataArray(data=data, dims=('band', 'y', 'x'), @@ -306,6 +314,6 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, lock=lock) # Make the file closeable - result._file_obj = riods + result._file_obj = manager return result diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index cd84431f6b7..b009342efb6 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -1,6 +1,5 @@ from __future__ import absolute_import, division, print_function -import functools import warnings from distutils.version import LooseVersion from io import 
BytesIO @@ -11,7 +10,9 @@ from ..core.indexing import NumpyIndexingAdapter from ..core.pycompat import OrderedDict, basestring, iteritems from ..core.utils import Frozen, FrozenOrderedDict -from .common import BackendArray, DataStorePickleMixin, WritableCFDataStore +from .common import BackendArray, WritableCFDataStore +from .locks import get_write_lock +from .file_manager import CachingFileManager, DummyFileManager from .netcdf3 import ( encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name) @@ -40,31 +41,26 @@ def __init__(self, variable_name, datastore): str(array.dtype.itemsize)) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name].data def __getitem__(self, key): - with self.datastore.ensure_open(autoclose=True): - data = NumpyIndexingAdapter(self.get_array())[key] - # Copy data if the source file is mmapped. - # This makes things consistent - # with the netCDF4 library by ensuring - # we can safely read arrays even - # after closing associated files. - copy = self.datastore.ds.use_mmap - return np.array(data, dtype=self.dtype, copy=copy) + data = NumpyIndexingAdapter(self.get_array())[key] + # Copy data if the source file is mmapped. This makes things consistent + # with the netCDF4 library by ensuring we can safely read arrays even + # after closing associated files. 
+ copy = self.datastore.ds.use_mmap + return np.array(data, dtype=self.dtype, copy=copy) def __setitem__(self, key, value): - with self.datastore.ensure_open(autoclose=True): - data = self.datastore.ds.variables[self.variable_name] - try: - data[key] = value - except TypeError: - if key is Ellipsis: - # workaround for GH: scipy/scipy#6880 - data[:] = value - else: - raise + data = self.datastore.ds.variables[self.variable_name] + try: + data[key] = value + except TypeError: + if key is Ellipsis: + # workaround for GH: scipy/scipy#6880 + data[:] = value + else: + raise def _open_scipy_netcdf(filename, mode, mmap, version): @@ -106,7 +102,7 @@ def _open_scipy_netcdf(filename, mode, mmap, version): raise -class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin): +class ScipyDataStore(WritableCFDataStore): """Store for reading and writing data via scipy.io.netcdf. This store has the advantage of being able to be initialized with a @@ -116,7 +112,7 @@ class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin): """ def __init__(self, filename_or_obj, mode='r', format=None, group=None, - writer=None, mmap=None, autoclose=False, lock=None): + mmap=None, lock=None): import scipy import scipy.io @@ -140,34 +136,38 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, raise ValueError('invalid format for scipy.io.netcdf backend: %r' % format) - opener = functools.partial(_open_scipy_netcdf, - filename=filename_or_obj, - mode=mode, mmap=mmap, version=version) - self._ds = opener() - self._autoclose = autoclose - self._isopen = True - self._opener = opener - self._mode = mode + if (lock is None and mode != 'r' and + isinstance(filename_or_obj, basestring)): + lock = get_write_lock(filename_or_obj) + + if isinstance(filename_or_obj, basestring): + manager = CachingFileManager( + _open_scipy_netcdf, filename_or_obj, mode=mode, lock=lock, + kwargs=dict(mmap=mmap, version=version)) + else: + scipy_dataset = _open_scipy_netcdf( + filename_or_obj, 
mode=mode, mmap=mmap, version=version) + manager = DummyFileManager(scipy_dataset) + + self._manager = manager - super(ScipyDataStore, self).__init__(writer, lock=lock) + @property + def ds(self): + return self._manager.acquire() def open_store_variable(self, name, var): - with self.ensure_open(autoclose=False): - return Variable(var.dimensions, ScipyArrayWrapper(name, self), - _decode_attrs(var._attributes)) + return Variable(var.dimensions, ScipyArrayWrapper(name, self), + _decode_attrs(var._attributes)) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in iteritems(self.ds.variables)) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in iteritems(self.ds.variables)) def get_attrs(self): - with self.ensure_open(autoclose=True): - return Frozen(_decode_attrs(self.ds._attributes)) + return Frozen(_decode_attrs(self.ds._attributes)) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.dimensions) + return Frozen(self.ds.dimensions) def get_encoding(self): encoding = {} @@ -176,22 +176,20 @@ def get_encoding(self): return encoding def set_dimension(self, name, length, is_unlimited=False): - with self.ensure_open(autoclose=False): - if name in self.ds.dimensions: - raise ValueError('%s does not support modifying dimensions' - % type(self).__name__) - dim_length = length if not is_unlimited else None - self.ds.createDimension(name, dim_length) + if name in self.ds.dimensions: + raise ValueError('%s does not support modifying dimensions' + % type(self).__name__) + dim_length = length if not is_unlimited else None + self.ds.createDimension(name, dim_length) def _validate_attr_key(self, key): if not is_valid_nc3_name(key): raise ValueError("Not a valid attribute name") def set_attribute(self, key, value): - with self.ensure_open(autoclose=False): - self._validate_attr_key(key) - value = encode_nc3_attr_value(value) - 
setattr(self.ds, key, value) + self._validate_attr_key(key) + value = encode_nc3_attr_value(value) + setattr(self.ds, key, value) def encode_variable(self, variable): variable = encode_nc3_variable(variable) @@ -219,27 +217,8 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, data - def sync(self, compute=True): - if not compute: - raise NotImplementedError( - 'compute=False is not supported for the scipy backend yet') - with self.ensure_open(autoclose=True): - super(ScipyDataStore, self).sync(compute=compute) - self.ds.flush() + def sync(self): + self.ds.sync() def close(self): - self.ds.close() - self._isopen = False - - def __exit__(self, type, value, tb): - self.close() - - def __setstate__(self, state): - filename = state['_opener'].keywords['filename'] - if hasattr(filename, 'seek'): - # it's a file-like object - # seek to the start of the file so scipy can read it - filename.seek(0) - super(ScipyDataStore, self).__setstate__(state) - self._ds = None - self._isopen = False + self._manager.close() diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 47b90c8a617..5f19c826289 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -217,8 +217,7 @@ class ZarrStore(AbstractWritableDataStore): """ @classmethod - def open_group(cls, store, mode='r', synchronizer=None, group=None, - writer=None): + def open_group(cls, store, mode='r', synchronizer=None, group=None): import zarr min_zarr = '2.2' @@ -230,24 +229,14 @@ def open_group(cls, store, mode='r', synchronizer=None, group=None, "#installation" % min_zarr) zarr_group = zarr.open_group(store=store, mode=mode, synchronizer=synchronizer, path=group) - return cls(zarr_group, writer=writer) + return cls(zarr_group) - def __init__(self, zarr_group, writer=None): + def __init__(self, zarr_group): self.ds = zarr_group self._read_only = self.ds.read_only self._synchronizer = self.ds.synchronizer self._group = self.ds.path - if writer is None: - # by default, we 
should not need a lock for writing zarr because - # we do not (yet) allow overlapping chunks during write - zarr_writer = ArrayWriter(lock=False) - else: - zarr_writer = writer - - # do we need to define attributes for all of the opener keyword args? - super(ZarrStore, self).__init__(zarr_writer) - def open_store_variable(self, name, zarr_array): data = indexing.LazilyOuterIndexedArray(ZarrArrayWrapper(name, self)) dimensions, attributes = _get_zarr_dims_and_attrs(zarr_array, @@ -334,8 +323,8 @@ def store(self, variables, attributes, *args, **kwargs): AbstractWritableDataStore.store(self, variables, attributes, *args, **kwargs) - def sync(self, compute=True): - self.delayed_store = self.writer.sync(compute=compute) + def sync(self): + pass def open_zarr(store, group=None, synchronizer=None, auto_chunk=True, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4ade15825c6..c8586d1d408 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1161,27 +1161,12 @@ def reset_coords(self, names=None, drop=False, inplace=False): del obj._variables[name] return obj - def dump_to_store(self, store, encoder=None, sync=True, encoding=None, - unlimited_dims=None, compute=True): + def dump_to_store(self, store, **kwargs): """Store dataset contents to a backends.*DataStore object.""" - if encoding is None: - encoding = {} - variables, attrs = conventions.encode_dataset_coordinates(self) - - check_encoding = set() - for k, enc in encoding.items(): - # no need to shallow copy the variable again; that already happened - # in encode_dataset_coordinates - variables[k].encoding = enc - check_encoding.add(k) - - if encoder: - variables, attrs = encoder(variables, attrs) - - store.store(variables, attrs, check_encoding, - unlimited_dims=unlimited_dims) - if sync: - store.sync(compute=compute) + from ..backends.api import dump_to_store + # TODO: rename and/or cleanup this method to make it more consistent + # with to_netcdf() + return dump_to_store(self, store, 
**kwargs) def to_netcdf(self, path=None, mode='w', format=None, group=None, engine=None, encoding=None, unlimited_dims=None, diff --git a/xarray/core/options.py b/xarray/core/options.py index a6118f02ed3..04ea0be7172 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -1,11 +1,43 @@ from __future__ import absolute_import, division, print_function +DISPLAY_WIDTH = 'display_width' +ARITHMETIC_JOIN = 'arithmetic_join' +ENABLE_CFTIMEINDEX = 'enable_cftimeindex' +FILE_CACHE_MAXSIZE = 'file_cache_maxsize' +CMAP_SEQUENTIAL = 'cmap_sequential' +CMAP_DIVERGENT = 'cmap_divergent' + OPTIONS = { - 'display_width': 80, - 'arithmetic_join': 'inner', - 'enable_cftimeindex': False, - 'cmap_sequential': 'viridis', - 'cmap_divergent': 'RdBu_r', + DISPLAY_WIDTH: 80, + ARITHMETIC_JOIN: 'inner', + ENABLE_CFTIMEINDEX: False, + FILE_CACHE_MAXSIZE: 128, + CMAP_SEQUENTIAL: 'viridis', + CMAP_DIVERGENT: 'RdBu_r', +} + +_JOIN_OPTIONS = frozenset(['inner', 'outer', 'left', 'right', 'exact']) + + +def _positive_integer(value): + return isinstance(value, int) and value > 0 + + +_VALIDATORS = { + DISPLAY_WIDTH: _positive_integer, + ARITHMETIC_JOIN: _JOIN_OPTIONS.__contains__, + ENABLE_CFTIMEINDEX: lambda value: isinstance(value, bool), + FILE_CACHE_MAXSIZE: _positive_integer, +} + + +def _set_file_cache_maxsize(value): + from ..backends.file_manager import FILE_CACHE + FILE_CACHE.maxsize = value + + +_SETTERS = { + FILE_CACHE_MAXSIZE: _set_file_cache_maxsize, } @@ -21,6 +53,10 @@ class set_options(object): - ``enable_cftimeindex``: flag to enable using a ``CFTimeIndex`` for time indexes with non-standard calendars or dates outside the Timestamp-valid range. Default: ``False``. + - ``file_cache_maxsize``: maximum number of open files to hold in xarray's + global least-recently-used cache. This should be smaller than your + system's per-process file descriptor limit, e.g., ``ulimit -n`` on Linux. + Default: 128. - ``cmap_sequential``: colormap to use for nondivergent data plots. 
Default: ``viridis``. If string, must be matplotlib built-in colormap. Can also be a Colormap object (e.g. mpl.cm.magma) @@ -28,8 +64,7 @@ class set_options(object): Default: ``RdBu_r``. If string, must be matplotlib built-in colormap. Can also be a Colormap object (e.g. mpl.cm.magma) - - You can use ``set_options`` either as a context manager: + You can use ``set_options`` either as a context manager: >>> ds = xr.Dataset({'x': np.arange(1000)}) >>> with xr.set_options(display_width=40): @@ -47,16 +82,26 @@ class set_options(object): """ def __init__(self, **kwargs): - invalid_options = {k for k in kwargs if k not in OPTIONS} - if invalid_options: - raise ValueError('argument names %r are not in the set of valid ' - 'options %r' % (invalid_options, set(OPTIONS))) self.old = OPTIONS.copy() - OPTIONS.update(kwargs) + for k, v in kwargs.items(): + if k not in OPTIONS: + raise ValueError( + 'argument name %r is not in the set of valid options %r' + % (k, set(OPTIONS))) + if k in _VALIDATORS and not _VALIDATORS[k](v): + raise ValueError( + 'option %r given an invalid value: %r' % (k, v)) + self._apply_update(kwargs) + + def _apply_update(self, options_dict): + for k, v in options_dict.items(): + if k in _SETTERS: + _SETTERS[k](v) + OPTIONS.update(options_dict) def __enter__(self): return def __exit__(self, type, value, traceback): OPTIONS.clear() - OPTIONS.update(self.old) + self._apply_update(self.old) diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index 78c26f1e92f..b980bc279b0 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -28,6 +28,9 @@ def itervalues(d): import builtins from urllib.request import urlretrieve from inspect import getfullargspec as getargspec + + def move_to_end(ordered_dict, key): + ordered_dict.move_to_end(key) else: # pragma: no cover # Python 2 basestring = basestring # noqa @@ -50,6 +53,11 @@ def itervalues(d): from urllib import urlretrieve from inspect import getargspec + def move_to_end(ordered_dict, 
key): + value = ordered_dict[key] + del ordered_dict[key] + ordered_dict[key] = value + integer_types = native_int_types + (np.integer,) try: @@ -76,7 +84,6 @@ def itervalues(d): except ImportError as e: path_type = () - try: from contextlib import suppress except ImportError: diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 0d97ed70fa3..43811942d5f 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -8,7 +8,6 @@ import shutil import sys import tempfile -import unittest import warnings from io import BytesIO @@ -20,13 +19,13 @@ from xarray import ( DataArray, Dataset, backends, open_dataarray, open_dataset, open_mfdataset, save_mfdataset) -from xarray.backends.common import ( - PickleByReconstructionWrapper, robust_getitem) +from xarray.backends.common import robust_getitem from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding from xarray.backends.pydap_ import PydapDataStore from xarray.core import indexing from xarray.core.pycompat import ( - PY2, ExitStack, basestring, dask_array_type, iteritems) + ExitStack, basestring, dask_array_type, iteritems) +from xarray.core.options import set_options from xarray.tests import mock from . import ( @@ -138,7 +137,6 @@ class NetCDF3Only(object): class DatasetIOTestCases(object): - autoclose = False engine = None file_format = None @@ -172,8 +170,7 @@ def save(self, dataset, path, **kwargs): @contextlib.contextmanager def open(self, path, **kwargs): - with open_dataset(path, engine=self.engine, autoclose=self.autoclose, - **kwargs) as ds: + with open_dataset(path, engine=self.engine, **kwargs) as ds: yield ds def test_zero_dimensional_variable(self): @@ -1159,10 +1156,10 @@ def test_already_open_dataset(self): v[...] 
= 42 nc = nc4.Dataset(tmp_file, mode='r') - with backends.NetCDF4DataStore(nc, autoclose=False) as store: - with open_dataset(store) as ds: - expected = Dataset({'x': ((), 42)}) - assert_identical(expected, ds) + store = backends.NetCDF4DataStore(nc) + with open_dataset(store) as ds: + expected = Dataset({'x': ((), 42)}) + assert_identical(expected, ds) def test_read_variable_len_strings(self): with create_tmp_file() as tmp_file: @@ -1181,7 +1178,6 @@ def test_read_variable_len_strings(self): @requires_netCDF4 class NetCDF4DataTest(BaseNetCDF4Test): - autoclose = False @contextlib.contextmanager def create_store(self): @@ -1247,9 +1243,13 @@ def test_setncattr_string(self): totest.attrs['bar']) assert one_string == totest.attrs['baz'] - -class NetCDF4DataStoreAutocloseTrue(NetCDF4DataTest): - autoclose = True + def test_autoclose_future_warning(self): + data = create_test_data() + with create_tmp_file() as tmp_file: + self.save(data, tmp_file) + with pytest.warns(FutureWarning): + with self.open(tmp_file, autoclose=True) as actual: + assert_identical(data, actual) @requires_netCDF4 @@ -1290,10 +1290,6 @@ def test_write_inconsistent_chunks(self): assert actual['y'].encoding['chunksizes'] == (100, 50) -class NetCDF4ViaDaskDataTestAutocloseTrue(NetCDF4ViaDaskDataTest): - autoclose = True - - @requires_zarr class BaseZarrTest(CFEncodedDataTest): @@ -1571,19 +1567,14 @@ def test_to_netcdf_explicit_engine(self): # regression test for GH1321 Dataset({'foo': 42}).to_netcdf(engine='scipy') - @pytest.mark.skipif(PY2, reason='cannot pickle BytesIO on Python 2') - def test_bytesio_pickle(self): + def test_bytes_pickle(self): data = Dataset({'foo': ('x', [1, 2, 3])}) - fobj = BytesIO(data.to_netcdf()) - with open_dataset(fobj, autoclose=self.autoclose) as ds: + fobj = data.to_netcdf() + with self.open(fobj) as ds: unpickled = pickle.loads(pickle.dumps(ds)) assert_identical(unpickled, data) -class ScipyInMemoryDataTestAutocloseTrue(ScipyInMemoryDataTest): - autoclose = True - - 
@requires_scipy class ScipyFileObjectTest(ScipyWriteTest): engine = 'scipy' @@ -1649,10 +1640,6 @@ def test_nc4_scipy(self): open_dataset(tmp_file, engine='scipy') -class ScipyFilePathTestAutocloseTrue(ScipyFilePathTest): - autoclose = True - - @requires_netCDF4 class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, NetCDF3Only): engine = 'netcdf4' @@ -1673,10 +1660,6 @@ def test_encoding_kwarg_vlen_string(self): pass -class NetCDF3ViaNetCDF4DataTestAutocloseTrue(NetCDF3ViaNetCDF4DataTest): - autoclose = True - - @requires_netCDF4 class NetCDF4ClassicViaNetCDF4DataTest(CFEncodedDataTest, NetCDF3Only, object): @@ -1691,11 +1674,6 @@ def create_store(self): yield store -class NetCDF4ClassicViaNetCDF4DataTestAutocloseTrue( - NetCDF4ClassicViaNetCDF4DataTest): - autoclose = True - - @requires_scipy_or_netCDF4 class GenericNetCDFDataTest(CFEncodedDataTest, NetCDF3Only): # verify that we can read and write netCDF3 files as long as we have scipy @@ -1772,10 +1750,6 @@ def test_encoding_unlimited_dims(self): assert_equal(ds, actual) -class GenericNetCDFDataTestAutocloseTrue(GenericNetCDFDataTest): - autoclose = True - - @requires_h5netcdf @requires_netCDF4 class H5NetCDFDataTest(BaseNetCDF4Test): @@ -1789,8 +1763,11 @@ def create_store(self): @pytest.mark.filterwarnings('ignore:complex dtypes are supported by h5py') def test_complex(self): expected = Dataset({'x': ('y', np.ones(5) + 1j * np.ones(5))}) - with self.roundtrip(expected) as actual: - assert_equal(expected, actual) + with pytest.warns(FutureWarning): + # TODO: make it possible to write invalid netCDF files from xarray + # without a warning + with self.roundtrip(expected) as actual: + assert_equal(expected, actual) @pytest.mark.xfail(reason='https://github.com/pydata/xarray/issues/535') def test_cross_engine_read_write_netcdf4(self): @@ -1905,25 +1882,24 @@ def test_dump_encodings_h5py(self): assert actual.x.encoding['compression_opts'] is None -# tests pending h5netcdf fix -@unittest.skip -class 
H5NetCDFDataTestAutocloseTrue(H5NetCDFDataTest): - autoclose = True - - @pytest.fixture(params=['scipy', 'netcdf4', 'h5netcdf', 'pynio']) def readengine(request): return request.param -@pytest.fixture(params=[1, 100]) +@pytest.fixture(params=[1, 20]) def nfiles(request): return request.param -@pytest.fixture(params=[True, False]) -def autoclose(request): - return request.param +@pytest.fixture(params=[5, None]) +def file_cache_maxsize(request): + maxsize = request.param + if maxsize is not None: + with set_options(file_cache_maxsize=maxsize): + yield maxsize + else: + yield maxsize @pytest.fixture(params=[True, False]) @@ -1946,8 +1922,8 @@ def skip_if_not_engine(engine): pytest.importorskip(engine) -def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel, - chunks): +def test_open_mfdataset_manyfiles(readengine, nfiles, parallel, chunks, + file_cache_maxsize): # skip certain combinations skip_if_not_engine(readengine) @@ -1955,9 +1931,6 @@ def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel, if not has_dask and parallel: pytest.skip('parallel requires dask') - if readengine == 'h5netcdf' and autoclose: - pytest.skip('h5netcdf does not support autoclose yet') - if ON_WINDOWS: pytest.skip('Skipping on Windows') @@ -1973,7 +1946,7 @@ def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel, # check that calculation on opened datasets works properly actual = open_mfdataset(tmpfiles, engine=readengine, parallel=parallel, - autoclose=autoclose, chunks=chunks) + chunks=chunks) # check that using open_mfdataset returns dask arrays for variables assert isinstance(actual['foo'].data, dask_array_type) @@ -2172,22 +2145,20 @@ def test_open_mfdataset(self): with create_tmp_file() as tmp2: original.isel(x=slice(5)).to_netcdf(tmp1) original.isel(x=slice(5, 10)).to_netcdf(tmp2) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert 
isinstance(actual.foo.variable.data, da.Array) assert actual.foo.variable.data.chunks == \ ((5, 5),) assert_identical(original, actual) - with open_mfdataset([tmp1, tmp2], chunks={'x': 3}, - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2], chunks={'x': 3}) as actual: assert actual.foo.variable.data.chunks == \ ((3, 2, 3, 2),) with raises_regex(IOError, 'no files to open'): - open_mfdataset('foo-bar-baz-*.nc', autoclose=self.autoclose) + open_mfdataset('foo-bar-baz-*.nc') with raises_regex(ValueError, 'wild-card'): - open_mfdataset('http://some/remote/uri', autoclose=self.autoclose) + open_mfdataset('http://some/remote/uri') @requires_pathlib def test_open_mfdataset_pathlib(self): @@ -2198,8 +2169,7 @@ def test_open_mfdataset_pathlib(self): tmp2 = Path(tmp2) original.isel(x=slice(5)).to_netcdf(tmp1) original.isel(x=slice(5, 10)).to_netcdf(tmp2) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(original, actual) def test_attrs_mfdataset(self): @@ -2230,8 +2200,7 @@ def preprocess(ds): return ds.assign_coords(z=0) expected = preprocess(original) - with open_mfdataset(tmp, preprocess=preprocess, - autoclose=self.autoclose) as actual: + with open_mfdataset(tmp, preprocess=preprocess) as actual: assert_identical(expected, actual) def test_save_mfdataset_roundtrip(self): @@ -2241,8 +2210,7 @@ def test_save_mfdataset_roundtrip(self): with create_tmp_file() as tmp1: with create_tmp_file() as tmp2: save_mfdataset(datasets, [tmp1, tmp2]) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(actual, original) def test_save_mfdataset_invalid(self): @@ -2268,15 +2236,14 @@ def test_save_mfdataset_pathlib_roundtrip(self): tmp1 = Path(tmp1) tmp2 = Path(tmp2) save_mfdataset(datasets, [tmp1, tmp2]) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with 
open_mfdataset([tmp1, tmp2]) as actual: assert_identical(actual, original) def test_open_and_do_math(self): original = Dataset({'foo': ('x', np.random.randn(10))}) with create_tmp_file() as tmp: original.to_netcdf(tmp) - with open_mfdataset(tmp, autoclose=self.autoclose) as ds: + with open_mfdataset(tmp) as ds: actual = 1.0 * ds assert_allclose(original, actual, decode_bytes=False) @@ -2286,8 +2253,7 @@ def test_open_mfdataset_concat_dim_none(self): data = Dataset({'x': 0}) data.to_netcdf(tmp1) Dataset({'x': np.nan}).to_netcdf(tmp2) - with open_mfdataset([tmp1, tmp2], concat_dim=None, - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2], concat_dim=None) as actual: assert_identical(data, actual) def test_open_dataset(self): @@ -2314,8 +2280,7 @@ def test_open_single_dataset(self): {'baz': [100]}) with create_tmp_file() as tmp: original.to_netcdf(tmp) - with open_mfdataset([tmp], concat_dim=dim, - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp], concat_dim=dim) as actual: assert_identical(expected, actual) def test_dask_roundtrip(self): @@ -2334,10 +2299,10 @@ def test_deterministic_names(self): with create_tmp_file() as tmp: data = create_test_data() data.to_netcdf(tmp) - with open_mfdataset(tmp, autoclose=self.autoclose) as ds: + with open_mfdataset(tmp) as ds: original_names = dict((k, v.data.name) for k, v in ds.data_vars.items()) - with open_mfdataset(tmp, autoclose=self.autoclose) as ds: + with open_mfdataset(tmp) as ds: repeat_names = dict((k, v.data.name) for k, v in ds.data_vars.items()) for var_name, dask_name in original_names.items(): @@ -2355,41 +2320,22 @@ def test_dataarray_compute(self): assert computed._in_memory assert_allclose(actual, computed, decode_bytes=False) - def test_to_netcdf_compute_false_roundtrip(self): - from dask.delayed import Delayed - - original = create_test_data().chunk() - - with create_tmp_file() as tmp_file: - # dataset, path, **kwargs): - delayed_obj = self.save(original, tmp_file, 
compute=False) - assert isinstance(delayed_obj, Delayed) - delayed_obj.compute() - - with self.open(tmp_file) as actual: - assert_identical(original, actual) - def test_save_mfdataset_compute_false_roundtrip(self): from dask.delayed import Delayed original = Dataset({'foo': ('x', np.random.randn(10))}).chunk() datasets = [original.isel(x=slice(5)), original.isel(x=slice(5, 10))] - with create_tmp_file() as tmp1: - with create_tmp_file() as tmp2: + with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp1: + with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp2: delayed_obj = save_mfdataset(datasets, [tmp1, tmp2], engine=self.engine, compute=False) assert isinstance(delayed_obj, Delayed) delayed_obj.compute() - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(actual, original) -class DaskTestAutocloseTrue(DaskTest): - autoclose = True - - @requires_scipy_or_netCDF4 @requires_pydap class PydapTest(object): @@ -2500,8 +2446,7 @@ def test_write_store(self): @contextlib.contextmanager def open(self, path, **kwargs): - with open_dataset(path, engine='pynio', autoclose=self.autoclose, - **kwargs) as ds: + with open_dataset(path, engine='pynio', **kwargs) as ds: yield ds def save(self, dataset, path, **kwargs): @@ -2519,19 +2464,12 @@ def test_weakrefs(self): assert_identical(actual, expected) -class PyNioTestAutocloseTrue(PyNioTest): - autoclose = True - - @requires_pseudonetcdf @pytest.mark.filterwarnings('ignore:IOAPI_ISPH is assumed to be 6370000') class PseudoNetCDFFormatTest(object): - autoclose = True def open(self, path, **kwargs): - return open_dataset(path, engine='pseudonetcdf', - autoclose=self.autoclose, - **kwargs) + return open_dataset(path, engine='pseudonetcdf', **kwargs) @contextlib.contextmanager def roundtrip(self, data, save_kwargs={}, open_kwargs={}, @@ -2548,7 +2486,6 @@ def test_ict_format(self): """ ictfile = open_example_dataset('example.ict', 
engine='pseudonetcdf', - autoclose=False, backend_kwargs={'format': 'ffi1001'}) stdattr = { 'fill_value': -9999.0, @@ -2646,7 +2583,6 @@ def test_ict_format_write(self): fmtkw = {'format': 'ffi1001'} expected = open_example_dataset('example.ict', engine='pseudonetcdf', - autoclose=False, backend_kwargs=fmtkw) with self.roundtrip(expected, save_kwargs=fmtkw, open_kwargs={'backend_kwargs': fmtkw}) as actual: @@ -2659,7 +2595,6 @@ def test_uamiv_format_read(self): camxfile = open_example_dataset('example.uamiv', engine='pseudonetcdf', - autoclose=True, backend_kwargs={'format': 'uamiv'}) data = np.arange(20, dtype='f').reshape(1, 1, 4, 5) expected = xr.Variable(('TSTEP', 'LAY', 'ROW', 'COL'), data, @@ -2687,7 +2622,6 @@ def test_uamiv_format_mfread(self): ['example.uamiv', 'example.uamiv'], engine='pseudonetcdf', - autoclose=True, concat_dim='TSTEP', backend_kwargs={'format': 'uamiv'}) @@ -2701,11 +2635,11 @@ def test_uamiv_format_mfread(self): data1 = np.array(['2002-06-03'], 'datetime64[ns]') data = np.concatenate([data1] * 2, axis=0) - expected = xr.Variable(('TSTEP',), data, - dict(bounds='time_bounds', - long_name=('synthesized time coordinate ' + - 'from SDATE, STIME, STEP ' + - 'global attributes'))) + attrs = dict(bounds='time_bounds', + long_name=('synthesized time coordinate ' + + 'from SDATE, STIME, STEP ' + + 'global attributes')) + expected = xr.Variable(('TSTEP',), data, attrs) actual = camxfile.variables['time'] assert_allclose(expected, actual) camxfile.close() @@ -2715,7 +2649,6 @@ def test_uamiv_format_write(self): expected = open_example_dataset('example.uamiv', engine='pseudonetcdf', - autoclose=False, backend_kwargs=fmtkw) with self.roundtrip(expected, save_kwargs=fmtkw, @@ -3312,32 +3245,6 @@ def test_dataarray_to_netcdf_no_name_pathlib(self): assert_identical(original_da, loaded_da) -def test_pickle_reconstructor(): - - lines = ['foo bar spam eggs'] - - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp: - with open(tmp, 'w') as f: - 
f.writelines(lines) - - obj = PickleByReconstructionWrapper(open, tmp) - - assert obj.value.readlines() == lines - - p_obj = pickle.dumps(obj) - obj.value.close() # for windows - obj2 = pickle.loads(p_obj) - - assert obj2.value.readlines() == lines - - # roundtrip again to make sure we can fully restore the state - p_obj2 = pickle.dumps(obj2) - obj2.value.close() # for windows - obj3 = pickle.loads(p_obj2) - - assert obj3.value.readlines() == lines - - @requires_scipy_or_netCDF4 def test_no_warning_from_dask_effective_get(): with create_tmp_file() as tmpfile: diff --git a/xarray/tests/test_backends_file_manager.py b/xarray/tests/test_backends_file_manager.py new file mode 100644 index 00000000000..591c981cd45 --- /dev/null +++ b/xarray/tests/test_backends_file_manager.py @@ -0,0 +1,114 @@ +import pickle +import threading +try: + from unittest import mock +except ImportError: + import mock # noqa: F401 + +import pytest + +from xarray.backends.file_manager import CachingFileManager +from xarray.backends.lru_cache import LRUCache + + +@pytest.fixture(params=[1, 2, 3, None]) +def file_cache(request): + maxsize = request.param + if maxsize is None: + yield {} + else: + yield LRUCache(maxsize) + + +def test_file_manager_mock_write(file_cache): + mock_file = mock.Mock() + opener = mock.Mock(spec=open, return_value=mock_file) + lock = mock.MagicMock(spec=threading.Lock()) + + manager = CachingFileManager( + opener, 'filename', lock=lock, cache=file_cache) + f = manager.acquire() + f.write('contents') + manager.close() + + assert not file_cache + opener.assert_called_once_with('filename') + mock_file.write.assert_called_once_with('contents') + mock_file.close.assert_called_once_with() + lock.__enter__.assert_has_calls([mock.call(), mock.call()]) + + +def test_file_manager_write_consecutive(tmpdir, file_cache): + path1 = str(tmpdir.join('testing1.txt')) + path2 = str(tmpdir.join('testing2.txt')) + manager1 = CachingFileManager(open, path1, mode='w', cache=file_cache) + 
manager2 = CachingFileManager(open, path2, mode='w', cache=file_cache) + f1a = manager1.acquire() + f1a.write('foo') + f1a.flush() + f2 = manager2.acquire() + f2.write('bar') + f2.flush() + f1b = manager1.acquire() + f1b.write('baz') + assert (getattr(file_cache, 'maxsize', float('inf')) > 1) == (f1a is f1b) + manager1.close() + manager2.close() + + with open(path1, 'r') as f: + assert f.read() == 'foobaz' + with open(path2, 'r') as f: + assert f.read() == 'bar' + + +def test_file_manager_write_concurrent(tmpdir, file_cache): + path = str(tmpdir.join('testing.txt')) + manager = CachingFileManager(open, path, mode='w', cache=file_cache) + f1 = manager.acquire() + f2 = manager.acquire() + f3 = manager.acquire() + assert f1 is f2 + assert f2 is f3 + f1.write('foo') + f1.flush() + f2.write('bar') + f2.flush() + f3.write('baz') + f3.flush() + manager.close() + + with open(path, 'r') as f: + assert f.read() == 'foobarbaz' + + +def test_file_manager_write_pickle(tmpdir, file_cache): + path = str(tmpdir.join('testing.txt')) + manager = CachingFileManager(open, path, mode='w', cache=file_cache) + f = manager.acquire() + f.write('foo') + f.flush() + manager2 = pickle.loads(pickle.dumps(manager)) + f2 = manager2.acquire() + f2.write('bar') + manager2.close() + manager.close() + + with open(path, 'r') as f: + assert f.read() == 'foobar' + + +def test_file_manager_read(tmpdir, file_cache): + path = str(tmpdir.join('testing.txt')) + + with open(path, 'w') as f: + f.write('foobar') + + manager = CachingFileManager(open, path, cache=file_cache) + f = manager.acquire() + assert f.read() == 'foobar' + manager.close() + + +def test_file_manager_invalid_kwargs(): + with pytest.raises(TypeError): + CachingFileManager(open, 'dummy', mode='w', invalid=True) diff --git a/xarray/tests/test_backends_locks.py b/xarray/tests/test_backends_locks.py new file mode 100644 index 00000000000..5f83321802e --- /dev/null +++ b/xarray/tests/test_backends_locks.py @@ -0,0 +1,13 @@ +import threading + 
+from xarray.backends import locks + + +def test_threaded_lock(): + lock1 = locks._get_threaded_lock('foo') + assert isinstance(lock1, type(threading.Lock())) + lock2 = locks._get_threaded_lock('foo') + assert lock1 is lock2 + + lock3 = locks._get_threaded_lock('bar') + assert lock1 is not lock3 diff --git a/xarray/tests/test_backends_lru_cache.py b/xarray/tests/test_backends_lru_cache.py new file mode 100644 index 00000000000..03eb6dcf208 --- /dev/null +++ b/xarray/tests/test_backends_lru_cache.py @@ -0,0 +1,91 @@ +try: + from unittest import mock +except ImportError: + import mock # noqa: F401 + +import pytest + +from xarray.backends.lru_cache import LRUCache + + +def test_simple(): + cache = LRUCache(maxsize=2) + cache['x'] = 1 + cache['y'] = 2 + + assert cache['x'] == 1 + assert cache['y'] == 2 + assert len(cache) == 2 + assert dict(cache) == {'x': 1, 'y': 2} + assert list(cache.keys()) == ['x', 'y'] + assert list(cache.items()) == [('x', 1), ('y', 2)] + + cache['z'] = 3 + assert len(cache) == 2 + assert list(cache.items()) == [('y', 2), ('z', 3)] + + +def test_trivial(): + cache = LRUCache(maxsize=0) + cache['x'] = 1 + assert len(cache) == 0 + + +def test_invalid(): + with pytest.raises(TypeError): + LRUCache(maxsize=None) + with pytest.raises(ValueError): + LRUCache(maxsize=-1) + + +def test_update_priority(): + cache = LRUCache(maxsize=2) + cache['x'] = 1 + cache['y'] = 2 + assert list(cache) == ['x', 'y'] + assert 'x' in cache # contains + assert list(cache) == ['y', 'x'] + assert cache['y'] == 2 # getitem + assert list(cache) == ['x', 'y'] + cache['x'] = 3 # setitem + assert list(cache.items()) == [('y', 2), ('x', 3)] + + +def test_del(): + cache = LRUCache(maxsize=2) + cache['x'] = 1 + cache['y'] = 2 + del cache['x'] + assert dict(cache) == {'y': 2} + + +def test_on_evict(): + on_evict = mock.Mock() + cache = LRUCache(maxsize=1, on_evict=on_evict) + cache['x'] = 1 + cache['y'] = 2 + on_evict.assert_called_once_with('x', 1) + + +def 
test_on_evict_trivial(): + on_evict = mock.Mock() + cache = LRUCache(maxsize=0, on_evict=on_evict) + cache['x'] = 1 + on_evict.assert_called_once_with('x', 1) + + +def test_resize(): + cache = LRUCache(maxsize=2) + assert cache.maxsize == 2 + cache['w'] = 0 + cache['x'] = 1 + cache['y'] = 2 + assert list(cache.items()) == [('x', 1), ('y', 2)] + cache.maxsize = 10 + cache['z'] = 3 + assert list(cache.items()) == [('x', 1), ('y', 2), ('z', 3)] + cache.maxsize = 1 + assert list(cache.items()) == [('z', 3)] + + with pytest.raises(ValueError): + cache.maxsize = -1 diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 9bee965392b..89704653e92 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -64,8 +64,8 @@ def create_test_multiindex(): class InaccessibleVariableDataStore(backends.InMemoryDataStore): - def __init__(self, writer=None): - super(InaccessibleVariableDataStore, self).__init__(writer) + def __init__(self): + super(InaccessibleVariableDataStore, self).__init__() self._indexvars = set() def store(self, variables, *args, **kwargs): diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 32035afdc57..7c77a62d3c9 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -15,12 +15,13 @@ from distributed.utils_test import cluster, gen_cluster from distributed.utils_test import loop # flake8: noqa from distributed.client import futures_of +import numpy as np import xarray as xr +from xarray.backends.locks import HDF5_LOCK, CombinedLock from xarray.tests.test_backends import (ON_WINDOWS, create_tmp_file, create_tmp_geotiff) from xarray.tests.test_dataset import create_test_data -from xarray.backends.common import HDF5_LOCK, CombinedLock from . 
import ( assert_allclose, has_h5netcdf, has_netCDF4, requires_rasterio, has_scipy, @@ -33,6 +34,11 @@ da = pytest.importorskip('dask.array') +@pytest.fixture +def tmp_netcdf_filename(tmpdir): + return str(tmpdir.join('testfile.nc')) + + ENGINES = [] if has_scipy: ENGINES.append('scipy') @@ -45,81 +51,69 @@ 'NETCDF3_64BIT_DATA', 'NETCDF4_CLASSIC', 'NETCDF4'], 'scipy': ['NETCDF3_CLASSIC', 'NETCDF3_64BIT'], 'h5netcdf': ['NETCDF4']} -TEST_FORMATS = ['NETCDF3_CLASSIC', 'NETCDF4_CLASSIC', 'NETCDF4'] +ENGINES_AND_FORMATS = [ + ('netcdf4', 'NETCDF3_CLASSIC'), + ('netcdf4', 'NETCDF4_CLASSIC'), + ('netcdf4', 'NETCDF4'), + ('h5netcdf', 'NETCDF4'), + ('scipy', 'NETCDF3_64BIT'), +] -@pytest.mark.xfail(sys.platform == 'win32', - reason='https://github.com/pydata/xarray/issues/1738') -@pytest.mark.parametrize('engine', ['netcdf4']) -@pytest.mark.parametrize('autoclose', [True, False]) -@pytest.mark.parametrize('nc_format', TEST_FORMATS) -def test_dask_distributed_netcdf_roundtrip(monkeypatch, loop, - engine, autoclose, nc_format): - monkeypatch.setenv('HDF5_USE_FILE_LOCKING', 'FALSE') - - chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} - - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: - with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: +@pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS) +def test_dask_distributed_netcdf_roundtrip( + loop, tmp_netcdf_filename, engine, nc_format): - original = create_test_data().chunk(chunks) - original.to_netcdf(filename, engine=engine, format=nc_format) - - with xr.open_dataset(filename, - chunks=chunks, - engine=engine, - autoclose=autoclose) as restored: - assert isinstance(restored.var1.data, da.Array) - computed = restored.compute() - assert_allclose(original, computed) + if engine not in ENGINES: + pytest.skip('engine not available') + chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} -@pytest.mark.xfail(sys.platform == 'win32', - reason='https://github.com/pydata/xarray/issues/1738') 
-@pytest.mark.parametrize('engine', ENGINES) -@pytest.mark.parametrize('autoclose', [True, False]) -@pytest.mark.parametrize('nc_format', TEST_FORMATS) -def test_dask_distributed_read_netcdf_integration_test(loop, engine, autoclose, - nc_format): + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: - if engine == 'h5netcdf' and autoclose: - pytest.skip('h5netcdf does not support autoclose') + original = create_test_data().chunk(chunks) - if nc_format not in NC_FORMATS[engine]: - pytest.skip('invalid format for engine') + if engine == 'scipy': + with pytest.raises(NotImplementedError): + original.to_netcdf(tmp_netcdf_filename, + engine=engine, format=nc_format) + return - chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} + original.to_netcdf(tmp_netcdf_filename, + engine=engine, format=nc_format) - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: - with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + with xr.open_dataset(tmp_netcdf_filename, + chunks=chunks, engine=engine) as restored: + assert isinstance(restored.var1.data, da.Array) + computed = restored.compute() + assert_allclose(original, computed) - original = create_test_data() - original.to_netcdf(filename, engine=engine, format=nc_format) - with xr.open_dataset(filename, - chunks=chunks, - engine=engine, - autoclose=autoclose) as restored: - assert isinstance(restored.var1.data, da.Array) - computed = restored.compute() - assert_allclose(original, computed) +@pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS) +def test_dask_distributed_read_netcdf_integration_test( + loop, tmp_netcdf_filename, engine, nc_format): + if engine not in ENGINES: + pytest.skip('engine not available') -@pytest.mark.parametrize('engine', ['h5netcdf', 'scipy']) -def test_dask_distributed_netcdf_integration_test_not_implemented(loop, engine): chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: - 
with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + + original = create_test_data() + original.to_netcdf(tmp_netcdf_filename, + engine=engine, format=nc_format) - original = create_test_data().chunk(chunks) + with xr.open_dataset(tmp_netcdf_filename, + chunks=chunks, + engine=engine) as restored: + assert isinstance(restored.var1.data, da.Array) + computed = restored.compute() + assert_allclose(original, computed) - with raises_regex(NotImplementedError, 'distributed'): - original.to_netcdf(filename, engine=engine) @requires_zarr diff --git a/xarray/tests/test_options.py b/xarray/tests/test_options.py index aed96f1acb6..4441375a1b1 100644 --- a/xarray/tests/test_options.py +++ b/xarray/tests/test_options.py @@ -4,6 +4,7 @@ import xarray from xarray.core.options import OPTIONS +from xarray.backends.file_manager import FILE_CACHE def test_invalid_option_raises(): @@ -11,6 +12,38 @@ def test_invalid_option_raises(): xarray.set_options(not_a_valid_options=True) +def test_display_width(): + with pytest.raises(ValueError): + xarray.set_options(display_width=0) + with pytest.raises(ValueError): + xarray.set_options(display_width=-10) + with pytest.raises(ValueError): + xarray.set_options(display_width=3.5) + + +def test_arithmetic_join(): + with pytest.raises(ValueError): + xarray.set_options(arithmetic_join='invalid') + with xarray.set_options(arithmetic_join='exact'): + assert OPTIONS['arithmetic_join'] == 'exact' + + +def test_enable_cftimeindex(): + with pytest.raises(ValueError): + xarray.set_options(enable_cftimeindex=None) + with xarray.set_options(enable_cftimeindex=True): + assert OPTIONS['enable_cftimeindex'] + + +def test_file_cache_maxsize(): + with pytest.raises(ValueError): + xarray.set_options(file_cache_maxsize=0) + original_size = FILE_CACHE.maxsize + with xarray.set_options(file_cache_maxsize=123): + assert FILE_CACHE.maxsize == 123 + assert 
FILE_CACHE.maxsize == original_size + + def test_nested_options(): original = OPTIONS['display_width'] with xarray.set_options(display_width=1):