From 81f9b94a9c4747640b3c82132a670dd4f73741ad Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 15 Nov 2016 15:17:54 -0800 Subject: [PATCH] Disable all caching on xarray.Variable This is a follow-up to generalize the changes from #1024: - Caching and copy-on-write behavior has been moved to separate array classes that are explicitly used in `open_dataset` to wrap arrays loaded from disk (if `cache=True`). - Dask-specific logic has been removed from the caching/loading logic on `xarray.Variable`. - Pickle no longer caches automatically under any circumstances. Still needs tests for the `cache` argument to `open_dataset`, but everything else seems to be working. ---
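A rough sketch of the intended usage of the new `cache` keyword (the file name is hypothetical; the exact semantics are described in the updated `open_dataset` docstring below):

    import xarray

    # default: cache=True, so arrays read from disk are wrapped in
    # CopyOnWriteArray + MemoryCachedArray and kept in memory after first access
    ds = xarray.open_dataset('example.nc')

    # opt out of caching: each .values access goes back to the datastore
    ds_nocache = xarray.open_dataset('example.nc', cache=False)

    # with chunks, data is loaded into dask arrays and cache defaults to False
    ds_dask = xarray.open_dataset('example.nc', chunks={'time': 10})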
doc/whats-new.rst | 14 ++++--- xarray/backends/api.py | 32 +++++++++++++-- xarray/backends/h5netcdf_.py | 5 ++- xarray/backends/netCDF4_.py | 6 +-- xarray/backends/pynio_.py | 6 +-- xarray/backends/scipy_.py | 6 +-- xarray/core/common.py | 8 ++++ xarray/core/dataset.py | 21 +++++----- xarray/core/indexing.py | 40 +++++++++++++++++++ xarray/core/utils.py | 7 ++++ xarray/core/variable.py | 63 ++++++++++-------------------- xarray/test/test_backends.py | 69 +++++++++++++++++++++++---------- xarray/test/test_conventions.py | 8 ++-- xarray/test/test_indexing.py | 42 ++++++++++++++++++++ xarray/test/test_utils.py | 9 +++++ 15 files changed, 239 insertions(+), 97 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 31048f333ab..9067235abc1 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,12 +25,14 @@ Breaking changes merges will now succeed in cases that previously raised ``xarray.MergeError``. Set ``compat='broadcast_equals'`` to restore the previous default. -- Pickling an xarray object based on the dask backend, or reading its - :py:meth:`values` property, won't automatically convert the array from dask - to numpy in the original object anymore. - If a dask object is used as a coord of a :py:class:`~xarray.DataArray` or - :py:class:`~xarray.Dataset`, its values are eagerly computed and cached, - but only if it's used to index a dim (e.g. it's used for alignment). +- Pickling an xarray object or reading its :py:attr:`~DataArray.values` + property no longer always caches values in a NumPy array. Caching + of ``.values`` read from netCDF files on disk is still the default when + :py:func:`open_dataset` is called with ``cache=True``. + By `Guido Imperiale `_ and + `Stephan Hoyer `_. +- Coordinates used to index a dimension are now loaded eagerly into + :py:class:`pandas.Index` objects, instead of loading the values lazily. By `Guido Imperiale `_. Deprecations diff --git a/xarray/backends/api.py b/xarray/backends/api.py index ae12b0cc74a..8ae541ede63 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -13,13 +13,16 @@ from .. import backends, conventions from .common import ArrayWriter +from ..core import indexing from ..core.combine import auto_combine from ..core.utils import close_on_error, is_remote_uri +from ..core.variable import Variable, IndexVariable from ..core.pycompat import basestring DATAARRAY_NAME = '__xarray_dataarray_name__' DATAARRAY_VARIABLE = '__xarray_dataarray_variable__' + def _get_default_engine(path, allow_remote=False): if allow_remote and is_remote_uri(path): # pragma: no cover try: @@ -117,10 +120,20 @@ def check_attr(name, value): check_attr(k, v) +def _protect_dataset_variables_inplace(dataset, cache): + for name, variable in dataset.variables.items(): + if name not in variable.dims: + # no need to protect IndexVariable objects + data = indexing.CopyOnWriteArray(variable._data) + if cache: + data = indexing.MemoryCachedArray(data) + variable._data = data + + def open_dataset(filename_or_obj, group=None, decode_cf=True, mask_and_scale=True, decode_times=True, concat_characters=True, decode_coords=True, engine=None, - chunks=None, lock=None, drop_variables=None): + chunks=None, lock=None, cache=None, drop_variables=None): """Load and decode a dataset from a file or file-like object. Parameters @@ -162,14 +175,22 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, 'netcdf4'. chunks : int or dict, optional If chunks is provided, it used to load the new dataset into dask - arrays. This is an experimental feature; see the documentation for more - details. + arrays. ``chunks={}`` loads the dataset with dask using a single + chunk for all arrays. This is an experimental feature; see the + documentation for more details. lock : False, True or threading.Lock, optional If chunks is provided, this argument is passed on to :py:func:`dask.array.from_array`. By default, a per-variable lock is used when reading data from netCDF files with the netcdf4 and h5netcdf engines to avoid issues with concurrent access when using dask's multithreaded backend. + cache : bool, optional + If True, cache data loaded from the underlying datastore in memory as + NumPy arrays when accessed to avoid reading from the underlying data- + store multiple times. Defaults to True unless you specify the `chunks` + argument to use dask, in which case it defaults to False. Does not + change the behavior of coordinates corresponding to dimensions, which + always load their data from disk into a ``pandas.Index``. drop_variables: string or iterable, optional A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or @@ -190,12 +211,17 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, concat_characters = False decode_coords = False + if cache is None: + cache = chunks is None + def maybe_decode_store(store, lock=False): ds = conventions.decode_cf( store, mask_and_scale=mask_and_scale, decode_times=decode_times, concat_characters=concat_characters, decode_coords=decode_coords, drop_variables=drop_variables) + _protect_dataset_variables_inplace(ds, cache) + if chunks is not None: try: from dask.base import tokenize diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 8796e276994..38eee9c23d6 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -5,7 +5,8 @@ from ..
import Variable from ..core import indexing -from ..core.utils import FrozenOrderedDict, close_on_error, Frozen +from ..core.utils import (FrozenOrderedDict, close_on_error, Frozen, + NoPickleMixin) from ..core.pycompat import iteritems, bytes_type, unicode_type, OrderedDict from .common import WritableCFDataStore @@ -37,7 +38,7 @@ def _read_attributes(h5netcdf_var): lsd_okay=False, backend='h5netcdf') -class H5NetCDFStore(WritableCFDataStore): +class H5NetCDFStore(WritableCFDataStore, NoPickleMixin): """Store for reading and writing data via h5netcdf """ def __init__(self, filename, mode='r', format=None, group=None, diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index b0acb31ed45..70760030298 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -9,7 +9,7 @@ from .. import Variable from ..conventions import pop_to, cf_encoder from ..core import indexing -from ..core.utils import (FrozenOrderedDict, NDArrayMixin, +from ..core.utils import (FrozenOrderedDict, NDArrayMixin, NoPickleMixin, close_on_error, is_remote_uri) from ..core.pycompat import iteritems, basestring, OrderedDict, PY3 @@ -25,7 +25,7 @@ '|': 'native'} -class BaseNetCDF4Array(NDArrayMixin): +class BaseNetCDF4Array(NDArrayMixin, NoPickleMixin): def __init__(self, array, is_remote=False): self.array = array self.is_remote = is_remote @@ -176,7 +176,7 @@ def _extract_nc4_encoding(variable, raise_on_invalid=False, lsd_okay=True, return encoding -class NetCDF4DataStore(WritableCFDataStore): +class NetCDF4DataStore(WritableCFDataStore, NoPickleMixin): """Store for reading and writing data via the Python-NetCDF4 library. This store supports NetCDF3, NetCDF4 and OpenDAP datasets. diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index 7ea7f21b651..c545065df93 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -4,13 +4,13 @@ import numpy as np from .. import Variable -from ..core.utils import FrozenOrderedDict, Frozen, NDArrayMixin +from ..core.utils import FrozenOrderedDict, Frozen, NDArrayMixin, NoPickleMixin from ..core import indexing from .common import AbstractDataStore -class NioArrayWrapper(NDArrayMixin): +class NioArrayWrapper(NDArrayMixin, NoPickleMixin): def __init__(self, array, ds): self.array = array self._ds = ds # make an explicit reference because pynio uses weakrefs @@ -25,7 +25,7 @@ def __getitem__(self, key): return self.array[key] -class NioDataStore(AbstractDataStore): +class NioDataStore(AbstractDataStore, NoPickleMixin): """Store for accessing datasets via PyNIO """ def __init__(self, filename, mode='r'): diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 200834d2f2c..4aba4366a5c 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -8,7 +8,7 @@ from .. 
import Variable from ..core.pycompat import iteritems, basestring, OrderedDict -from ..core.utils import Frozen, FrozenOrderedDict +from ..core.utils import Frozen, FrozenOrderedDict, NoPickleMixin from ..core.indexing import NumpyIndexingAdapter from .common import WritableCFDataStore @@ -29,7 +29,7 @@ def _decode_attrs(d): for (k, v) in iteritems(d)) -class ScipyArrayWrapper(NumpyIndexingAdapter): +class ScipyArrayWrapper(NumpyIndexingAdapter, NoPickleMixin): def __init__(self, netcdf_file, variable_name): self.netcdf_file = netcdf_file self.variable_name = variable_name @@ -57,7 +57,7 @@ def __getitem__(self, key): return data -class ScipyDataStore(WritableCFDataStore): +class ScipyDataStore(WritableCFDataStore, NoPickleMixin): """Store for reading and writing data via scipy.io.netcdf. This store has the advantage of being able to be initialized with a diff --git a/xarray/core/common.py b/xarray/core/common.py index 5ac9994ee8c..7afa20a6864 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -248,6 +248,14 @@ def __dir__(self): if isinstance(item, basestring)] return sorted(set(dir(type(self)) + extra_attrs)) + def __getstate__(self): + """Get this object's state for pickling""" + # we need a custom method to avoid + + # self.__dict__ is the default pickle object, we don't need to + # implement our own __setstate__ method to make pickle work + return self.__dict__ + class SharedMethodsMixin(object): """Shared methods for Dataset, DataArray and Variable.""" diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 282409cf4ca..3ea04495829 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -260,17 +260,12 @@ def load_store(cls, store, decoder=None): return obj def __getstate__(self): - """Load data in-memory before pickling (except for Dask data)""" - for v in self.variables.values(): - if not isinstance(v.data, dask_array_type): - v.load() + """Get this object's state for pickling""" + # we need a custom method to avoid # self.__dict__ is the default pickle object, we don't need to # implement our own __setstate__ method to make pickle work - state = self.__dict__.copy() - # throw away any references to datastores in the pickle - state['_file_obj'] = None - return state + return self.__dict__ @property def variables(self): @@ -331,9 +326,8 @@ def load(self): working with many file objects on disk. 
""" # access .data to coerce everything to numpy or dask arrays - all_data = dict((k, v.data) for k, v in self.variables.items()) - lazy_data = dict((k, v) for k, v in all_data.items() - if isinstance(v, dask_array_type)) + lazy_data = {k: v._data for k, v in self.variables.items() + if isinstance(v._data, dask_array_type)} if lazy_data: import dask.array as da @@ -343,6 +337,11 @@ def load(self): for k, data in zip(lazy_data, evaluated_data): self.variables[k].data = data + # load everything else sequentially + for k, v in self.variables.items(): + if k not in lazy_data: + v.load() + return self def compute(self): diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 8cbb91ebd4f..6c4d47ade10 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -403,6 +403,46 @@ def __repr__(self): (type(self).__name__, self.array, self.key)) +class CopyOnWriteArray(utils.NDArrayMixin): + def __init__(self, array): + self.array = array + self._copied = False + + def _ensure_copied(self): + if not self._copied: + self.array = np.array(self.array) + self._copied = True + + def __array__(self, dtype=None): + return np.asarray(self.array, dtype=dtype) + + def __getitem__(self, key): + return type(self)(self.array[key]) + + def __setitem__(self, key, value): + self._ensure_copied() + self.array[key] = value + + +class MemoryCachedArray(utils.NDArrayMixin): + def __init__(self, array): + self.array = array + + def _ensure_cached(self): + if not isinstance(self.array, np.ndarray): + self.array = np.asarray(self.array) + + def __array__(self, dtype=None): + self._ensure_cached() + return np.asarray(self.array, dtype=dtype) + + def __getitem__(self, key): + return type(self)(self.array[key]) + + def __setitem__(self, key, value): + self.array[key] = value + + def orthogonally_indexable(array): if isinstance(array, np.ndarray): return NumpyIndexingAdapter(array) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index ff7ccbc2670..5a701680f4d 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -420,6 +420,13 @@ def __repr__(self): return '%s(array=%r)' % (type(self).__name__, self.array) +class NoPickleMixin(object): + def __getstate__(self): + raise TypeError( + 'cannot pickle objects of type %r: call .compute() or .load() ' + 'to load data into memory first.' 
% type(self)) + + @contextlib.contextmanager def close_on_error(f): """Context manager to ensure that a file opened by xarray is closed if an diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 1b6f5b55dda..47746e07cad 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -142,10 +142,12 @@ def as_compatible_data(data, fastpath=False): if isinstance(data, timedelta): data = np.timedelta64(getattr(data, 'value', data), 'ns') - if (not hasattr(data, 'dtype') or not hasattr(data, 'shape') or + if (not hasattr(data, 'dtype') or not isinstance(data.dtype, np.dtype) or + not hasattr(data, 'shape') or isinstance(data, (np.string_, np.unicode_, np.datetime64, np.timedelta64))): # data must be ndarray-like + # don't allow non-numpy dtypes (e.g., categories) data = np.asarray(data) # we don't want nested self-described arrays @@ -260,7 +262,9 @@ def nbytes(self): @property def _in_memory(self): - return isinstance(self._data, (np.ndarray, PandasIndexAdapter)) + return (isinstance(self._data, (np.ndarray, PandasIndexAdapter)) or + (isinstance(self._data, indexing.MemoryCachedArray) and + isinstance(self._data.array, np.ndarray))) @property def data(self): @@ -277,22 +281,6 @@ def data(self, data): "replacement data must match the Variable's shape") self._data = data - def _data_cast(self): - if isinstance(self._data, (np.ndarray, PandasIndexAdapter)): - return self._data - else: - return np.asarray(self._data) - - def _data_cached(self): - """Load data into memory and return it. - Do not cache dask arrays automatically; that should - require an explicit load() call. - """ - new_data = self._data_cast() - if not isinstance(self._data, dask_array_type): - self._data = new_data - return new_data - @property def _indexable_data(self): return orthogonally_indexable(self._data) @@ -305,7 +293,8 @@ def load(self): because all xarray functions should either work on deferred data or load data automatically. """ - self._data = self._data_cast() + if not isinstance(self._data, np.ndarray): + self._data = np.asarray(self._data) return self def compute(self): @@ -320,19 +309,10 @@ def compute(self): new = self.copy(deep=False) return new.load() - def __getstate__(self): - """Always cache data as an in-memory array before pickling - (with the exception of dask backend)""" - if not isinstance(self._data, dask_array_type): - self._data_cached() - # self.__dict__ is the default pickle object, we don't need to - # implement our own __setstate__ method to make pickle work - return self.__dict__ - @property def values(self): """The variable's data as a numpy.ndarray""" - return _as_array_or_item(self._data_cached()) + return _as_array_or_item(self._data) @values.setter def values(self, values): @@ -425,7 +405,7 @@ def __setitem__(self, key, value): 'assign to this variable, you must first load it ' 'into memory explicitly using the .load_data() ' 'method or accessing its .values attribute.') - data = orthogonally_indexable(self._data_cached()) + data = orthogonally_indexable(self._data) data[key] = value @property @@ -462,7 +442,7 @@ def copy(self, deep=True): the new object. Dimensions, attributes and encodings are always copied. 
""" if (deep and not isinstance(self.data, dask_array_type) - and not isinstance(self._data, PandasIndexAdapter)): + and not isinstance(self._data, PandasIndexAdapter)): # pandas.Index objects are immutable # dask arrays don't have a copy method # https://github.com/blaze/dask/issues/911 @@ -1144,15 +1124,16 @@ def concat(cls, variables, dim='concat_dim', positions=None, raise TypeError('IndexVariable.concat requires that all input ' 'variables be IndexVariable objects') - arrays = [v._data_cached().array for v in variables] + indexes = [v._data.array for v in variables] - if not arrays: + if not indexes: data = [] else: - data = arrays[0].append(arrays[1:]) + data = indexes[0].append(indexes[1:]) if positions is not None: - indices = nputils.inverse_permutation(np.concatenate(positions)) + indices = nputils.inverse_permutation( + np.concatenate(positions)) data = data.take(indices) attrs = OrderedDict(first_var.attrs) @@ -1167,13 +1148,11 @@ def concat(cls, variables, dim='concat_dim', positions=None, def copy(self, deep=True): """Returns a copy of this object. - If `deep=True`, the values array is loaded into memory and copied onto - the new object. Dimensions, attributes and encodings are always copied. + `deep` is ignored since data is stored in the form of pandas.Index, + which is already immutable. Dimensions, attributes and encodings are + always copied. """ - # there is no need to copy the index values here even if deep=True - # since pandas.Index objects are immutable - data = PandasIndexAdapter(self) if deep else self._data - return type(self)(self.dims, data, self._attrs, + return type(self)(self.dims, self._data, self._attrs, self._encoding, fastpath=True) def _data_equals(self, other): @@ -1190,7 +1169,7 @@ def to_index(self): # n.b. 
creating a new pandas.Index from an old pandas.Index is # basically free as pandas.Index objects are immutable assert self.ndim == 1 - index = self._data_cached().array + index = self._data.array if isinstance(index, pd.MultiIndex): # set default names for multi-index unnamed levels so that # we can safely rename dimension / coordinate later diff --git a/xarray/test/test_backends.py b/xarray/test/test_backends.py index 65e0f3d51ac..c9200907c7e 100644 --- a/xarray/test/test_backends.py +++ b/xarray/test/test_backends.py @@ -163,18 +163,12 @@ def test_dataset_compute(self): # Test Dataset.compute() for k, v in actual.variables.items(): # IndexVariables are eagerly cached - if k in actual.dims: - self.assertTrue(v._in_memory) - else: - self.assertFalse(v._in_memory) + self.assertEqual(v._in_memory, k in actual.dims) computed = actual.compute() for k, v in actual.variables.items(): - if k in actual.dims: - self.assertTrue(v._in_memory) - else: - self.assertFalse(v._in_memory) + self.assertEqual(v._in_memory, k in actual.dims) for v in computed.variables.values(): self.assertTrue(v._in_memory) @@ -282,10 +276,40 @@ def test_orthogonal_indexing(self): actual = on_disk.isel(**indexers) self.assertDatasetAllClose(expected, actual) + +class PickleNotSupportedMixin(object): + + def test_pickle(self): + expected = Dataset({'foo': ('x', [42])}) + with self.roundtrip(expected) as on_disk: + with self.assertRaisesRegexp( + TypeError, 'load data into memory first'): + pickle.dumps(on_disk) + computed_ds = on_disk.compute() + unpickled_ds = pickle.loads(pickle.dumps(computed_ds)) + self.assertDatasetIdentical(expected, computed_ds) + self.assertDatasetIdentical(expected, unpickled_ds) + + with self.assertRaisesRegexp( + TypeError, 'load data into memory first'): + pickle.dumps(on_disk['foo']) + computed_array = on_disk['foo'].compute() + unpickled_array = pickle.loads(pickle.dumps(computed_array)) + self.assertDatasetIdentical(expected['foo'], computed_array) + self.assertDatasetIdentical(expected['foo'], unpickled_array) + + +class PickleSupportedMixin(object): + def test_pickle(self): - on_disk = open_example_dataset('bears.nc') - unpickled = pickle.loads(pickle.dumps(on_disk)) - self.assertDatasetIdentical(on_disk, unpickled) + # this should work for dask arrays, unlike most real data stores + expected = Dataset({'foo': ('x', [42])}) + with self.roundtrip(expected) as roundtripped: + unpickled_ds = pickle.loads(pickle.dumps(roundtripped)) + self.assertDatasetIdentical(expected, unpickled_ds) + + unpickled_array = pickle.loads(pickle.dumps(roundtripped['foo'])) + self.assertDatasetIdentical(expected['foo'], unpickled_array) class CFEncodedDataTest(DatasetIOTestCases): @@ -443,7 +467,7 @@ def create_tmp_file(suffix='.nc'): shutil.rmtree(temp_dir) -class BaseNetCDF4Test(CFEncodedDataTest): +class BaseNetCDF4Test(CFEncodedDataTest, PickleNotSupportedMixin): def test_open_group(self): # Create a netCDF file with a dataset stored within a group with create_tmp_file() as tmp_file: @@ -642,7 +666,7 @@ def test_variable_len_strings(self): @requires_netCDF4 -class NetCDF4DataTest(BaseNetCDF4Test, TestCase): +class NetCDF4DataTest(BaseNetCDF4Test, PickleNotSupportedMixin, TestCase): @contextlib.contextmanager def create_store(self): with create_tmp_file() as tmp_file: @@ -687,7 +711,7 @@ def test_unsorted_index_raises(self): @requires_netCDF4 @requires_dask -class NetCDF4ViaDaskDataTest(NetCDF4DataTest): +class NetCDF4ViaDaskDataTest(NetCDF4DataTest, PickleNotSupportedMixin): @contextlib.contextmanager def 
roundtrip(self, data, save_kwargs={}, open_kwargs={}): with NetCDF4DataTest.roundtrip( @@ -753,7 +777,8 @@ def test_netcdf3_endianness(self): @requires_netCDF4 -class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes, TestCase): +class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes, + PickleNotSupportedMixin, TestCase): @contextlib.contextmanager def create_store(self): with create_tmp_file() as tmp_file: @@ -771,7 +796,8 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}): @requires_netCDF4 -class NetCDF4ClassicViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes, TestCase): +class NetCDF4ClassicViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes, + PickleNotSupportedMixin, TestCase): @contextlib.contextmanager def create_store(self): with create_tmp_file() as tmp_file: @@ -789,7 +815,8 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}): @requires_scipy_or_netCDF4 -class GenericNetCDFDataTest(CFEncodedDataTest, Only32BitTypes, TestCase): +class GenericNetCDFDataTest(CFEncodedDataTest, Only32BitTypes, + PickleNotSupportedMixin, TestCase): # verify that we can read and write netCDF3 files as long as we have scipy # or netCDF4-python installed @@ -841,7 +868,7 @@ def test_cross_engine_read_write_netcdf3(self): @requires_h5netcdf @requires_netCDF4 -class H5NetCDFDataTest(BaseNetCDF4Test, TestCase): +class H5NetCDFDataTest(BaseNetCDF4Test, PickleNotSupportedMixin, TestCase): @contextlib.contextmanager def create_store(self): with create_tmp_file() as tmp_file: @@ -892,7 +919,7 @@ def test_read_byte_attrs_as_unicode(self): @requires_dask @requires_scipy @requires_netCDF4 -class DaskTest(TestCase, DatasetIOTestCases): +class DaskTest(TestCase, PickleSupportedMixin, DatasetIOTestCases): @contextlib.contextmanager def create_store(self): yield Dataset() @@ -1041,6 +1068,7 @@ def test_dataarray_compute(self): self.assertTrue(computed._in_memory) self.assertDataArrayAllClose(actual, computed) + @requires_scipy_or_netCDF4 @requires_pydap class PydapTest(TestCase): @@ -1083,7 +1111,8 @@ def test_dask(self): @requires_scipy @requires_pynio -class TestPyNio(CFEncodedDataTest, Only32BitTypes, TestCase): +class TestPyNio(CFEncodedDataTest, PickleNotSupportedMixin, Only32BitTypes, + TestCase): def test_write_store(self): # pynio is read-only for now pass diff --git a/xarray/test/test_conventions.py b/xarray/test/test_conventions.py index 265b2a93b0f..e3c2c2bd0ff 100644 --- a/xarray/test/test_conventions.py +++ b/xarray/test/test_conventions.py @@ -9,7 +9,7 @@ from xarray import conventions, Variable, Dataset, open_dataset from xarray.core import utils, indexing from . 
import TestCase, requires_netCDF4, unittest -from .test_backends import CFEncodedDataTest +from .test_backends import CFEncodedDataTest, PickleSupportedMixin from xarray.core.pycompat import iteritems from xarray.backends.memory import InMemoryDataStore from xarray.backends.common import WritableCFDataStore @@ -191,11 +191,11 @@ def test_cf_datetime(self): @requires_netCDF4 def test_decode_cf_datetime_overflow(self): - # checks for + # checks for # https://github.com/pydata/pandas/issues/14068 # https://github.com/pydata/xarray/issues/975 - from datetime import datetime + from datetime import datetime units = 'days since 2000-01-01 00:00:00' # date after 2262 and before 1678 @@ -620,7 +620,7 @@ def null_wrap(ds): @requires_netCDF4 -class TestCFEncodedDataStore(CFEncodedDataTest, TestCase): +class TestCFEncodedDataStore(CFEncodedDataTest, PickleSupportedMixin, TestCase): @contextlib.contextmanager def create_store(self): yield CFEncodedInMemoryStore() diff --git a/xarray/test/test_indexing.py b/xarray/test/test_indexing.py index ed61c9eb463..9d22b4f2c87 100644 --- a/xarray/test/test_indexing.py +++ b/xarray/test/test_indexing.py @@ -200,3 +200,45 @@ def test_lazily_indexed_array(self): actual = lazy[i][j] self.assertEqual(expected.shape, actual.shape) self.assertArrayEqual(expected, actual) + + +class TestCopyOnWriteArray(TestCase): + def test_setitem(self): + original = np.arange(10) + wrapped = indexing.CopyOnWriteArray(original) + wrapped[:] = 0 + self.assertArrayEqual(original, np.arange(10)) + self.assertArrayEqual(wrapped, np.zeros(10)) + + def test_sub_array(self): + original = np.arange(10) + wrapped = indexing.CopyOnWriteArray(original) + child = wrapped[:5] + self.assertIsInstance(child, indexing.CopyOnWriteArray) + child[:] = 0 + self.assertArrayEqual(original, np.arange(10)) + self.assertArrayEqual(wrapped, np.arange(10)) + self.assertArrayEqual(child, np.zeros(5)) + + +class TestMemoryCachedArray(TestCase): + def test_wrapper(self): + original = indexing.LazilyIndexedArray(np.arange(10)) + wrapped = indexing.MemoryCachedArray(original) + self.assertArrayEqual(wrapped, np.arange(10)) + self.assertIsInstance(wrapped.array, np.ndarray) + + def test_sub_array(self): + original = indexing.LazilyIndexedArray(np.arange(10)) + wrapped = indexing.MemoryCachedArray(original) + child = wrapped[:5] + self.assertIsInstance(child, indexing.MemoryCachedArray) + self.assertArrayEqual(child, np.arange(5)) + self.assertIsInstance(child.array, np.ndarray) + self.assertIsInstance(wrapped.array, indexing.LazilyIndexedArray) + + def test_setitem(self): + original = np.arange(10) + wrapped = indexing.MemoryCachedArray(original) + wrapped[:] = 0 + self.assertArrayEqual(original, np.zeros(10)) diff --git a/xarray/test/test_utils.py b/xarray/test/test_utils.py index 611b45f80d1..db0bf2d202f 100644 --- a/xarray/test/test_utils.py +++ b/xarray/test/test_utils.py @@ -1,6 +1,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import pickle +import pytest + import numpy as np import pandas as pd @@ -171,3 +174,9 @@ def test_hashable(self): self.assertTrue(utils.hashable(v)) for v in [[5, 6], ['seven', '8'], {9: 'ten'}]: self.assertFalse(utils.hashable(v)) + + +def test_no_pickle_mixin(): + obj = utils.NoPickleMixin() + with pytest.raises(TypeError): + pickle.dumps(obj)
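For reference, a minimal standalone sketch (assuming this patch is applied) of how the two new array wrappers defined in xarray/core/indexing.py behave; it mirrors the new TestCopyOnWriteArray and TestMemoryCachedArray tests above:

    import numpy as np
    from xarray.core import indexing

    original = np.arange(5)

    # CopyOnWriteArray copies the wrapped array before the first write, so the
    # backend array is never modified in place
    protected = indexing.CopyOnWriteArray(original)
    protected[0] = -1
    print(original)               # [0 1 2 3 4] -- unchanged
    print(np.asarray(protected))  # [-1  1  2  3  4]

    # MemoryCachedArray keeps a NumPy copy around after the first access
    cached = indexing.MemoryCachedArray(indexing.LazilyIndexedArray(original))
    np.asarray(cached)                           # loads and caches the data
    print(isinstance(cached.array, np.ndarray))  # True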