diff --git a/ci/requirements-py27-cdat+pynio.yml b/ci/requirements-py27-cdat+pynio.yml index 53aafb058e1..5f98b9e1f6f 100644 --- a/ci/requirements-py27-cdat+pynio.yml +++ b/ci/requirements-py27-cdat+pynio.yml @@ -5,6 +5,7 @@ dependencies: - python=2.7 - cdat-lite - dask + - distributed - pytest - numpy - pandas>=0.15.0 diff --git a/ci/requirements-py27-netcdf4-dev.yml b/ci/requirements-py27-netcdf4-dev.yml index a64782de235..4ce193a2a82 100644 --- a/ci/requirements-py27-netcdf4-dev.yml +++ b/ci/requirements-py27-netcdf4-dev.yml @@ -3,6 +3,7 @@ dependencies: - python=2.7 - cython - dask + - distributed - h5py - pytest - numpy diff --git a/ci/requirements-py27-pydap.yml b/ci/requirements-py27-pydap.yml index 459f049c76a..e391eee514f 100644 --- a/ci/requirements-py27-pydap.yml +++ b/ci/requirements-py27-pydap.yml @@ -2,6 +2,7 @@ name: test_env dependencies: - python=2.7 - dask + - distributed - h5py - netcdf4 - pytest diff --git a/ci/requirements-py35.yml b/ci/requirements-py35.yml index 0f3b005ea6a..c6641598fca 100644 --- a/ci/requirements-py35.yml +++ b/ci/requirements-py35.yml @@ -3,6 +3,7 @@ dependencies: - python=3.5 - cython - dask + - distributed - h5py - matplotlib - netcdf4 diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 39ddc0284e8..2f3a3d4eae4 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,12 +25,20 @@ Breaking changes merges will now succeed in cases that previously raised ``xarray.MergeError``. Set ``compat='broadcast_equals'`` to restore the previous default. -- Pickling an xarray object based on the dask backend, or reading its - :py:meth:`values` property, won't automatically convert the array from dask - to numpy in the original object anymore. - If a dask object is used as a coord of a :py:class:`~xarray.DataArray` or - :py:class:`~xarray.Dataset`, its values are eagerly computed and cached, - but only if it's used to index a dim (e.g. it's used for alignment). 
+- Reading :py:attr:`~DataArray.values` no longer always caches values in a NumPy + array :issue:`1128`. Caching of ``.values`` on variables read from netCDF + files on disk is still the default when :py:func:`open_dataset` is called with + ``cache=True``. + By `Guido Imperiale <https://github.com/crusaderky>`_ and + `Stephan Hoyer <https://github.com/shoyer>`_. +- Pickling a ``Dataset`` or ``DataArray`` linked to a file on disk no longer + caches its values into memory before pickling :issue:`1128`. Instead, pickle + stores file paths and restores objects by reopening file references. This + enables preliminary, experimental use of xarray for opening files with + `dask.distributed <https://distributed.readthedocs.io>`_. + By `Stephan Hoyer <https://github.com/shoyer>`_. +- Coordinates used to index a dimension are now loaded eagerly into + :py:class:`pandas.Index` objects, instead of loading the values lazily. By `Guido Imperiale <https://github.com/crusaderky>`_. Deprecations diff --git a/xarray/backends/api.py b/xarray/backends/api.py index ae12b0cc74a..bc2afa4b373 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -13,6 +13,7 @@ from .. 
import backends, conventions from .common import ArrayWriter +from ..core import indexing from ..core.combine import auto_combine from ..core.utils import close_on_error, is_remote_uri from ..core.pycompat import basestring @@ -20,6 +21,7 @@ DATAARRAY_NAME = '__xarray_dataarray_name__' DATAARRAY_VARIABLE = '__xarray_dataarray_variable__' + def _get_default_engine(path, allow_remote=False): if allow_remote and is_remote_uri(path): # pragma: no cover try: @@ -46,6 +48,13 @@ def _get_default_engine(path, allow_remote=False): return engine +def _normalize_path(path): + if is_remote_uri(path): + return path + else: + return os.path.abspath(os.path.expanduser(path)) + + _global_lock = threading.Lock() @@ -117,10 +126,20 @@ def check_attr(name, value): check_attr(k, v) +def _protect_dataset_variables_inplace(dataset, cache): + for name, variable in dataset.variables.items(): + if name not in variable.dims: + # no need to protect IndexVariable objects + data = indexing.CopyOnWriteArray(variable._data) + if cache: + data = indexing.MemoryCachedArray(data) + variable.data = data + + def open_dataset(filename_or_obj, group=None, decode_cf=True, mask_and_scale=True, decode_times=True, concat_characters=True, decode_coords=True, engine=None, - chunks=None, lock=None, drop_variables=None): + chunks=None, lock=None, cache=None, drop_variables=None): """Load and decode a dataset from a file or file-like object. Parameters @@ -162,14 +181,22 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, 'netcdf4'. chunks : int or dict, optional If chunks is provided, it used to load the new dataset into dask - arrays. This is an experimental feature; see the documentation for more - details. + arrays. ``chunks={}`` loads the dataset with dask using a single + chunk for all arrays. This is an experimental feature; see the + documentation for more details. 
lock : False, True or threading.Lock, optional If chunks is provided, this argument is passed on to :py:func:`dask.array.from_array`. By default, a per-variable lock is used when reading data from netCDF files with the netcdf4 and h5netcdf engines to avoid issues with concurrent access when using dask's multithreaded backend. + cache : bool, optional + If True, cache data loaded from the underlying datastore in memory as + NumPy arrays when accessed to avoid reading from the underlying data- + store multiple times. Defaults to True unless you specify the `chunks` + argument to use dask, in which case it defaults to False. Does not + change the behavior of coordinates corresponding to dimensions, which + always load their data from disk into a ``pandas.Index``. drop_variables: string or iterable, optional A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or @@ -190,12 +217,17 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, concat_characters = False decode_coords = False + if cache is None: + cache = chunks is None + def maybe_decode_store(store, lock=False): ds = conventions.decode_cf( store, mask_and_scale=mask_and_scale, decode_times=decode_times, concat_characters=concat_characters, decode_coords=decode_coords, drop_variables=drop_variables) + _protect_dataset_variables_inplace(ds, cache) + if chunks is not None: try: from dask.base import tokenize @@ -226,6 +258,17 @@ def maybe_decode_store(store, lock=False): if isinstance(filename_or_obj, backends.AbstractDataStore): store = filename_or_obj elif isinstance(filename_or_obj, basestring): + + if (isinstance(filename_or_obj, bytes) and + filename_or_obj.startswith(b'\x89HDF')): + raise ValueError('cannot read netCDF4/HDF5 file images') + elif (isinstance(filename_or_obj, bytes) and + filename_or_obj.startswith(b'CDF')): + # netCDF3 file images are handled by scipy + pass + elif isinstance(filename_or_obj, basestring): 
+ filename_or_obj = _normalize_path(filename_or_obj) + if filename_or_obj.endswith('.gz'): if engine is not None and engine != 'scipy': raise ValueError('can only read gzipped netCDF files with ' @@ -274,7 +317,7 @@ def maybe_decode_store(store, lock=False): def open_dataarray(filename_or_obj, group=None, decode_cf=True, mask_and_scale=True, decode_times=True, concat_characters=True, decode_coords=True, engine=None, - chunks=None, lock=None, drop_variables=None): + chunks=None, lock=None, cache=None, drop_variables=None): """ Opens an DataArray from a netCDF file containing a single data variable. @@ -328,6 +371,13 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, used when reading data from netCDF files with the netcdf4 and h5netcdf engines to avoid issues with concurrent access when using dask's multithreaded backend. + cache : bool, optional + If True, cache data loaded from the underlying datastore in memory as + NumPy arrays when accessed to avoid reading from the underlying data- + store multiple times. Defaults to True unless you specify the `chunks` + argument to use dask, in which case it defaults to False. Does not + change the behavior of coordinates corresponding to dimensions, which + always load their data from disk into a ``pandas.Index``. drop_variables: string or iterable, optional A variable or list of variables to exclude from being parsed from the dataset. 
This may be useful to drop variables with problems or @@ -349,7 +399,7 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, dataset = open_dataset(filename_or_obj, group, decode_cf, mask_and_scale, decode_times, concat_characters, decode_coords, engine, - chunks, lock, drop_variables) + chunks, lock, cache, drop_variables) if len(dataset.data_vars) != 1: raise ValueError('Given file dataset contains more than one data ' @@ -494,8 +544,10 @@ def to_netcdf(dataset, path=None, mode='w', format=None, group=None, raise ValueError('invalid engine for creating bytes with ' 'to_netcdf: %r. Only the default engine ' "or engine='scipy' is supported" % engine) - elif engine is None: - engine = _get_default_engine(path) + else: + if engine is None: + engine = _get_default_engine(path) + path = _normalize_path(path) # validate Dataset keys, DataArray names, and attr keys/values _validate_dataset_names(dataset) diff --git a/xarray/backends/common.py b/xarray/backends/common.py index bf85930c8df..208611829ee 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -235,3 +235,22 @@ def store(self, variables, attributes, check_encoding_set=frozenset()): cf_variables, cf_attrs = cf_encoder(variables, attributes) AbstractWritableDataStore.store(self, cf_variables, cf_attrs, check_encoding_set) + + +class DataStorePickleMixin(object): + """Subclasses must define `ds`, `_opener` and `_mode` attributes. + + Do not subclass this class: it is not part of xarray's external API. 
+ """ + + def __getstate__(self): + state = self.__dict__.copy() + del state['ds'] + if self._mode == 'w': + # file has already been created, don't override when restoring + state['_mode'] = 'a' + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.ds = self._opener(mode=self._mode) diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 8796e276994..76582cfd72e 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -8,7 +8,7 @@ from ..core.utils import FrozenOrderedDict, close_on_error, Frozen from ..core.pycompat import iteritems, bytes_type, unicode_type, OrderedDict -from .common import WritableCFDataStore +from .common import WritableCFDataStore, DataStorePickleMixin from .netCDF4_ import (_nc4_group, _nc4_values_and_dtype, _extract_nc4_encoding, BaseNetCDF4Array) @@ -37,24 +37,32 @@ def _read_attributes(h5netcdf_var): lsd_okay=False, backend='h5netcdf') -class H5NetCDFStore(WritableCFDataStore): +def _open_h5netcdf_group(filename, mode, group): + import h5netcdf.legacyapi + ds = h5netcdf.legacyapi.Dataset(filename, mode=mode) + with close_on_error(ds): + return _nc4_group(ds, group, mode) + + +class H5NetCDFStore(WritableCFDataStore, DataStorePickleMixin): """Store for reading and writing data via h5netcdf """ def __init__(self, filename, mode='r', format=None, group=None, writer=None): - import h5netcdf.legacyapi if format not in [None, 'NETCDF4']: raise ValueError('invalid format for h5netcdf backend') - ds = h5netcdf.legacyapi.Dataset(filename, mode=mode) - with close_on_error(ds): - self.ds = _nc4_group(ds, group, mode) + opener = functools.partial(_open_h5netcdf_group, filename, mode=mode, + group=group) + self.ds = opener() self.format = format + self._opener = opener self._filename = filename + self._mode = mode super(H5NetCDFStore, self).__init__(writer) - def open_store_variable(self, var): + def open_store_variable(self, name, var): dimensions = var.dimensions - data = 
indexing.LazilyIndexedArray(BaseNetCDF4Array(var)) + data = indexing.LazilyIndexedArray(BaseNetCDF4Array(name, self)) attrs = _read_attributes(var) # netCDF4 specific encoding @@ -69,7 +77,7 @@ def open_store_variable(self, var): return Variable(dimensions, data, attrs, encoding) def get_variables(self): - return FrozenOrderedDict((k, self.open_store_variable(v)) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) for k, v in iteritems(self.ds.variables)) def get_attrs(self): diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index b0acb31ed45..9b05a41925d 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -1,8 +1,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import operator -from functools import partial import numpy as np @@ -13,7 +13,7 @@ close_on_error, is_remote_uri) from ..core.pycompat import iteritems, basestring, OrderedDict, PY3 -from .common import WritableCFDataStore, robust_getitem +from .common import WritableCFDataStore, robust_getitem, DataStorePickleMixin from .netcdf3 import (encode_nc3_attr_value, encode_nc3_variable, maybe_convert_to_char_array) @@ -26,9 +26,13 @@ class BaseNetCDF4Array(NDArrayMixin): - def __init__(self, array, is_remote=False): - self.array = array - self.is_remote = is_remote + def __init__(self, variable_name, datastore): + self.datastore = datastore + self.variable_name = variable_name + + @property + def array(self): + return self.datastore.ds.variables[self.variable_name] @property def dtype(self): @@ -42,13 +46,9 @@ def dtype(self): class NetCDF4ArrayWrapper(BaseNetCDF4Array): - def __init__(self, array, is_remote=False): - self.array = array - self.is_remote = is_remote - def __getitem__(self, key): - if self.is_remote: # pragma: no cover - getitem = partial(robust_getitem, catch=RuntimeError) + if self.datastore.is_remote: # pragma: no cover + getitem = 
functools.partial(robust_getitem, catch=RuntimeError) else: getitem = operator.getitem @@ -176,31 +176,44 @@ def _extract_nc4_encoding(variable, raise_on_invalid=False, lsd_okay=True, return encoding -class NetCDF4DataStore(WritableCFDataStore): +def _open_netcdf4_group(filename, mode, group=None, **kwargs): + import netCDF4 as nc4 + + ds = nc4.Dataset(filename, mode=mode, **kwargs) + + with close_on_error(ds): + ds = _nc4_group(ds, group, mode) + + for var in ds.variables.values(): + # we handle masking and scaling ourselves + var.set_auto_maskandscale(False) + return ds + + +class NetCDF4DataStore(WritableCFDataStore, DataStorePickleMixin): """Store for reading and writing data via the Python-NetCDF4 library. This store supports NetCDF3, NetCDF4 and OpenDAP datasets. """ def __init__(self, filename, mode='r', format='NETCDF4', group=None, writer=None, clobber=True, diskless=False, persist=False): - import netCDF4 as nc4 if format is None: format = 'NETCDF4' - ds = nc4.Dataset(filename, mode=mode, clobber=clobber, - diskless=diskless, persist=persist, - format=format) - with close_on_error(ds): - self.ds = _nc4_group(ds, group, mode) + opener = functools.partial(_open_netcdf4_group, filename, mode=mode, + group=group, clobber=clobber, + diskless=diskless, persist=persist, + format=format) + self.ds = opener() self.format = format self.is_remote = is_remote_uri(filename) + self._opener = opener self._filename = filename + self._mode = 'a' if mode == 'w' else mode super(NetCDF4DataStore, self).__init__(writer) - def open_store_variable(self, var): - var.set_auto_maskandscale(False) + def open_store_variable(self, name, var): dimensions = var.dimensions - data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper( - var, self.is_remote)) + data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self)) attributes = OrderedDict((k, var.getncattr(k)) for k in var.ncattrs()) _ensure_fill_value_valid(data, attributes) @@ -227,7 +240,7 @@ def open_store_variable(self, 
var): return Variable(dimensions, data, attributes, encoding) def get_variables(self): - return FrozenOrderedDict((k, self.open_store_variable(v)) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) for k, v in iteritems(self.ds.variables)) def get_attrs(self): diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index 7ea7f21b651..075db5d4ccb 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -1,19 +1,27 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + +import functools + import numpy as np from .. import Variable from ..core.utils import FrozenOrderedDict, Frozen, NDArrayMixin from ..core import indexing -from .common import AbstractDataStore +from .common import AbstractDataStore, DataStorePickleMixin class NioArrayWrapper(NDArrayMixin): - def __init__(self, array, ds): - self.array = array - self._ds = ds # make an explicit reference because pynio uses weakrefs + + def __init__(self, variable_name, datastore): + self.datastore = datastore + self.variable_name = variable_name + + @property + def array(self): + return self.datastore.ds.variables[self.variable_name] @property def dtype(self): @@ -25,19 +33,22 @@ def __getitem__(self, key): return self.array[key] -class NioDataStore(AbstractDataStore): +class NioDataStore(AbstractDataStore, DataStorePickleMixin): """Store for accessing datasets via PyNIO """ def __init__(self, filename, mode='r'): import Nio - self.ds = Nio.open_file(filename, mode=mode) + opener = functools.partial(Nio.open_file, filename, mode=mode) + self.ds = opener() + self._opener = opener + self._mode = mode - def open_store_variable(self, var): - data = indexing.LazilyIndexedArray(NioArrayWrapper(var, self.ds)) + def open_store_variable(self, name, var): + data = indexing.LazilyIndexedArray(NioArrayWrapper(name, self)) return Variable(var.dimensions, data, var.attributes) def get_variables(self): - return FrozenOrderedDict((k, 
self.open_store_variable(v)) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) for k, v in self.ds.variables.iteritems()) def get_attrs(self): diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 200834d2f2c..0113728f81c 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -1,6 +1,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools from io import BytesIO import numpy as np @@ -11,7 +12,7 @@ from ..core.utils import Frozen, FrozenOrderedDict from ..core.indexing import NumpyIndexingAdapter -from .common import WritableCFDataStore +from .common import WritableCFDataStore, DataStorePickleMixin from .netcdf3 import (is_valid_nc3_name, encode_nc3_attr_value, encode_nc3_variable) @@ -30,17 +31,13 @@ def _decode_attrs(d): class ScipyArrayWrapper(NumpyIndexingAdapter): - def __init__(self, netcdf_file, variable_name): - self.netcdf_file = netcdf_file + def __init__(self, variable_name, datastore): + self.datastore = datastore self.variable_name = variable_name @property def array(self): - # We can't store the actual netcdf_variable object or its data array, - # because otherwise scipy complains about variables or files still - # referencing mmapped arrays when we try to close datasets without - # having read all data in the file. - return self.netcdf_file.variables[self.variable_name].data + return self.datastore.ds.variables[self.variable_name].data @property def dtype(self): @@ -52,12 +49,23 @@ def __getitem__(self, key): # Copy data if the source file is mmapped. This makes things consistent # with the netCDF4 library by ensuring we can safely read arrays even # after closing associated files. 
- copy = self.netcdf_file.use_mmap + copy = self.datastore.ds.use_mmap data = np.array(data, dtype=self.dtype, copy=copy) return data -class ScipyDataStore(WritableCFDataStore): +def _open_scipy_netcdf(filename, mode, mmap, version): + import scipy.io + + if isinstance(filename, bytes) and filename.startswith(b'CDF'): + # it's a NetCDF3 bytestring + filename = BytesIO(filename) + + return scipy.io.netcdf_file(filename, mode=mode, mmap=mmap, + version=version) + + +class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin): """Store for reading and writing data via scipy.io.netcdf. This store has the advantage of being able to be initialized with a @@ -88,18 +96,17 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, raise ValueError('invalid format for scipy.io.netcdf backend: %r' % format) - # if filename is a NetCDF3 bytestring we store it in a StringIO - if (isinstance(filename_or_obj, basestring) and - filename_or_obj.startswith('CDF')): - # TODO: this check has the unfortunate side-effect that - # paths to files cannot start with 'CDF'. 
- filename_or_obj = BytesIO(filename_or_obj) - self.ds = scipy.io.netcdf_file( - filename_or_obj, mode=mode, mmap=mmap, version=version) + opener = functools.partial(_open_scipy_netcdf, + filename=filename_or_obj, + mode=mode, mmap=mmap, version=version) + self.ds = opener() + self._opener = opener + self._mode = mode + super(ScipyDataStore, self).__init__(writer) def open_store_variable(self, name, var): - return Variable(var.dimensions, ScipyArrayWrapper(self.ds, name), + return Variable(var.dimensions, ScipyArrayWrapper(name, self), _decode_attrs(var._attributes)) def get_variables(self): @@ -154,3 +161,11 @@ def close(self): def __exit__(self, type, value, tb): self.close() + + def __setstate__(self, state): + filename = state['_opener'].keywords['filename'] + if hasattr(filename, 'seek'): + # it's a file-like object + # seek to the start of the file so scipy can read it + filename.seek(0) + super(ScipyDataStore, self).__setstate__(state) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 282409cf4ca..b3bde8f7377 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -259,19 +259,6 @@ def load_store(cls, store, decoder=None): obj._file_obj = store return obj - def __getstate__(self): - """Load data in-memory before pickling (except for Dask data)""" - for v in self.variables.values(): - if not isinstance(v.data, dask_array_type): - v.load() - - # self.__dict__ is the default pickle object, we don't need to - # implement our own __setstate__ method to make pickle work - state = self.__dict__.copy() - # throw away any references to datastores in the pickle - state['_file_obj'] = None - return state - @property def variables(self): """Frozen dictionary of xarray.Variable objects constituting this @@ -331,9 +318,8 @@ def load(self): working with many file objects on disk. 
""" # access .data to coerce everything to numpy or dask arrays - all_data = dict((k, v.data) for k, v in self.variables.items()) - lazy_data = dict((k, v) for k, v in all_data.items() - if isinstance(v, dask_array_type)) + lazy_data = {k: v._data for k, v in self.variables.items() + if isinstance(v._data, dask_array_type)} if lazy_data: import dask.array as da @@ -343,6 +329,11 @@ def load(self): for k, data in zip(lazy_data, evaluated_data): self.variables[k].data = data + # load everything else sequentially + for k, v in self.variables.items(): + if k not in lazy_data: + v.load() + return self def compute(self): diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 8cbb91ebd4f..6c4d47ade10 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -403,6 +403,46 @@ def __repr__(self): (type(self).__name__, self.array, self.key)) +class CopyOnWriteArray(utils.NDArrayMixin): + def __init__(self, array): + self.array = array + self._copied = False + + def _ensure_copied(self): + if not self._copied: + self.array = np.array(self.array) + self._copied = True + + def __array__(self, dtype=None): + return np.asarray(self.array, dtype=dtype) + + def __getitem__(self, key): + return type(self)(self.array[key]) + + def __setitem__(self, key, value): + self._ensure_copied() + self.array[key] = value + + +class MemoryCachedArray(utils.NDArrayMixin): + def __init__(self, array): + self.array = array + + def _ensure_cached(self): + if not isinstance(self.array, np.ndarray): + self.array = np.asarray(self.array) + + def __array__(self, dtype=None): + self._ensure_cached() + return np.asarray(self.array, dtype=dtype) + + def __getitem__(self, key): + return type(self)(self.array[key]) + + def __setitem__(self, key, value): + self.array[key] = value + + def orthogonally_indexable(array): if isinstance(array, np.ndarray): return NumpyIndexingAdapter(array) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index e4960b6d72a..32c26bd02c1 100644 --- 
a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -6,6 +6,7 @@ import contextlib import functools import itertools +import os.path import re import warnings from collections import Mapping, MutableMapping, Iterable diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 1b6f5b55dda..97a75019c25 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -142,10 +142,12 @@ def as_compatible_data(data, fastpath=False): if isinstance(data, timedelta): data = np.timedelta64(getattr(data, 'value', data), 'ns') - if (not hasattr(data, 'dtype') or not hasattr(data, 'shape') or + if (not hasattr(data, 'dtype') or not isinstance(data.dtype, np.dtype) or + not hasattr(data, 'shape') or isinstance(data, (np.string_, np.unicode_, np.datetime64, np.timedelta64))): # data must be ndarray-like + # don't allow non-numpy dtypes (e.g., categories) data = np.asarray(data) # we don't want nested self-described arrays @@ -260,7 +262,9 @@ def nbytes(self): @property def _in_memory(self): - return isinstance(self._data, (np.ndarray, PandasIndexAdapter)) + return (isinstance(self._data, (np.ndarray, PandasIndexAdapter)) or + (isinstance(self._data, indexing.MemoryCachedArray) and + isinstance(self._data.array, np.ndarray))) @property def data(self): @@ -277,22 +281,6 @@ def data(self, data): "replacement data must match the Variable's shape") self._data = data - def _data_cast(self): - if isinstance(self._data, (np.ndarray, PandasIndexAdapter)): - return self._data - else: - return np.asarray(self._data) - - def _data_cached(self): - """Load data into memory and return it. - Do not cache dask arrays automatically; that should - require an explicit load() call. 
- """ - new_data = self._data_cast() - if not isinstance(self._data, dask_array_type): - self._data = new_data - return new_data - @property def _indexable_data(self): return orthogonally_indexable(self._data) @@ -305,7 +293,8 @@ def load(self): because all xarray functions should either work on deferred data or load data automatically. """ - self._data = self._data_cast() + if not isinstance(self._data, np.ndarray): + self._data = np.asarray(self._data) return self def compute(self): @@ -320,19 +309,10 @@ def compute(self): new = self.copy(deep=False) return new.load() - def __getstate__(self): - """Always cache data as an in-memory array before pickling - (with the exception of dask backend)""" - if not isinstance(self._data, dask_array_type): - self._data_cached() - # self.__dict__ is the default pickle object, we don't need to - # implement our own __setstate__ method to make pickle work - return self.__dict__ - @property def values(self): """The variable's data as a numpy.ndarray""" - return _as_array_or_item(self._data_cached()) + return _as_array_or_item(self._data) @values.setter def values(self, values): @@ -425,7 +405,7 @@ def __setitem__(self, key, value): 'assign to this variable, you must first load it ' 'into memory explicitly using the .load_data() ' 'method or accessing its .values attribute.') - data = orthogonally_indexable(self._data_cached()) + data = orthogonally_indexable(self._data) data[key] = value @property @@ -461,14 +441,17 @@ def copy(self, deep=True): If `deep=True`, the data array is loaded into memory and copied onto the new object. Dimensions, attributes and encodings are always copied. 
""" - if (deep and not isinstance(self.data, dask_array_type) - and not isinstance(self._data, PandasIndexAdapter)): - # pandas.Index objects are immutable - # dask arrays don't have a copy method - # https://github.com/blaze/dask/issues/911 - data = self.data.copy() - else: - data = self._data + data = self._data + + if isinstance(data, indexing.MemoryCachedArray): + # don't share caching between copies + data = indexing.MemoryCachedArray(data.array) + + if deep and not isinstance( + data, (dask_array_type, PandasIndexAdapter)): + # pandas.Index and dask.array objects are immutable + data = np.array(data) + # note: # dims is already an immutable tuple # attributes and encoding will be copied when the new Array is created @@ -701,7 +684,7 @@ def transpose(self, *dims): if len(dims) == 0: dims = self.dims[::-1] axes = self.get_axis_num(dims) - if len(dims) < 2: # no need to transpose if only one dimension + if len(dims) < 2: # no need to transpose if only one dimension return self.copy(deep=False) data = ops.transpose(self.data, axes) return type(self)(dims, data, self._attrs, self._encoding, fastpath=True) @@ -1104,6 +1087,10 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): if not isinstance(self._data, PandasIndexAdapter): self._data = PandasIndexAdapter(self._data) + def load(self): + # data is already loaded into memory for IndexVariable + return self + @Variable.data.setter def data(self, data): Variable.data.fset(self, data) @@ -1144,15 +1131,16 @@ def concat(cls, variables, dim='concat_dim', positions=None, raise TypeError('IndexVariable.concat requires that all input ' 'variables be IndexVariable objects') - arrays = [v._data_cached().array for v in variables] + indexes = [v._data.array for v in variables] - if not arrays: + if not indexes: data = [] else: - data = arrays[0].append(arrays[1:]) + data = indexes[0].append(indexes[1:]) if positions is not None: - indices = nputils.inverse_permutation(np.concatenate(positions)) + 
indices = nputils.inverse_permutation( + np.concatenate(positions)) data = data.take(indices) attrs = OrderedDict(first_var.attrs) @@ -1167,13 +1155,11 @@ def concat(cls, variables, dim='concat_dim', positions=None, def copy(self, deep=True): """Returns a copy of this object. - If `deep=True`, the values array is loaded into memory and copied onto - the new object. Dimensions, attributes and encodings are always copied. + `deep` is ignored since data is stored in the form of pandas.Index, + which is already immutable. Dimensions, attributes and encodings are + always copied. """ - # there is no need to copy the index values here even if deep=True - # since pandas.Index objects are immutable - data = PandasIndexAdapter(self) if deep else self._data - return type(self)(self.dims, data, self._attrs, + return type(self)(self.dims, self._data, self._attrs, self._encoding, fastpath=True) def _data_equals(self, other): @@ -1190,7 +1176,7 @@ def to_index(self): # n.b. creating a new pandas.Index from an old pandas.Index is # basically free as pandas.Index objects are immutable assert self.ndim == 1 - index = self._data_cached().array + index = self._data.array if isinstance(index, pd.MultiIndex): # set default names for multi-index unnamed levels so that # we can safely rename dimension / coordinate later diff --git a/xarray/test/__init__.py b/xarray/test/__init__.py index e2ac8d9c2ce..bdafef7c3ad 100644 --- a/xarray/test/__init__.py +++ b/xarray/test/__init__.py @@ -126,6 +126,18 @@ def data_allclose_or_equiv(arr1, arr2, rtol=1e-05, atol=1e-08): return ops.allclose_or_equiv(arr1, arr2, rtol=rtol, atol=atol) +def assert_dataset_allclose(d1, d2, rtol=1e-05, atol=1e-08): + assert sorted(d1, key=str) == sorted(d2, key=str) + assert sorted(d1.coords, key=str) == sorted(d2.coords, key=str) + for k in d1: + v1 = d1.variables[k] + v2 = d2.variables[k] + assert v1.dims == v2.dims + allclose = data_allclose_or_equiv( + v1.values, v2.values, rtol=rtol, atol=atol) + assert allclose, 
(k, v1.values, v2.values) + + class TestCase(unittest.TestCase): if PY3: # Python 3 assertCountEqual is roughly equivalent to Python 2 diff --git a/xarray/test/test_backends.py b/xarray/test/test_backends.py index 65e0f3d51ac..9c074147849 100644 --- a/xarray/test/test_backends.py +++ b/xarray/test/test_backends.py @@ -14,13 +14,15 @@ import numpy as np import pandas as pd +import pytest import xarray as xr from xarray import (Dataset, DataArray, open_dataset, open_dataarray, open_mfdataset, backends, save_mfdataset) from xarray.backends.common import robust_getitem from xarray.backends.netCDF4_ import _extract_nc4_encoding -from xarray.core.pycompat import iteritems, PY3 +from xarray.core import indexing +from xarray.core.pycompat import iteritems, PY2, PY3 from . import (TestCase, requires_scipy, requires_netCDF4, requires_pydap, requires_scipy_or_netCDF4, requires_dask, requires_h5netcdf, @@ -38,6 +40,9 @@ pass +ON_WINDOWS = sys.platform == 'win32' + + def open_example_dataset(name, *args, **kwargs): return open_dataset(os.path.join(os.path.dirname(__file__), 'data', name), *args, **kwargs) @@ -130,10 +135,7 @@ def assert_loads(vars=None): with self.roundtrip(expected) as actual: for k, v in actual.variables.items(): # IndexVariables are eagerly loaded into memory - if k in actual.dims: - self.assertTrue(v._in_memory) - else: - self.assertFalse(v._in_memory) + self.assertEqual(v._in_memory, k in actual.dims) yield actual for k, v in actual.variables.items(): if k in vars: @@ -163,24 +165,51 @@ def test_dataset_compute(self): # Test Dataset.compute() for k, v in actual.variables.items(): # IndexVariables are eagerly cached - if k in actual.dims: - self.assertTrue(v._in_memory) - else: - self.assertFalse(v._in_memory) + self.assertEqual(v._in_memory, k in actual.dims) computed = actual.compute() for k, v in actual.variables.items(): - if k in actual.dims: - self.assertTrue(v._in_memory) - else: - self.assertFalse(v._in_memory) + self.assertEqual(v._in_memory, k in 
actual.dims) for v in computed.variables.values(): self.assertTrue(v._in_memory) self.assertDatasetAllClose(expected, actual) self.assertDatasetAllClose(expected, computed) + def test_pickle(self): + expected = Dataset({'foo': ('x', [42])}) + with self.roundtrip( + expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped: + raw_pickle = pickle.dumps(roundtripped) + # windows doesn't like opening the same file twice + roundtripped.close() + unpickled_ds = pickle.loads(raw_pickle) + self.assertDatasetIdentical(expected, unpickled_ds) + + def test_pickle_dataarray(self): + expected = Dataset({'foo': ('x', [42])}) + with self.roundtrip( + expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped: + unpickled_array = pickle.loads(pickle.dumps(roundtripped['foo'])) + self.assertDatasetIdentical(expected['foo'], unpickled_array) + + def test_dataset_caching(self): + expected = Dataset({'foo': ('x', [5, 6, 7])}) + with self.roundtrip(expected) as actual: + assert isinstance(actual.foo.variable._data, + indexing.MemoryCachedArray) + assert not actual.foo.variable._in_memory + actual.foo.values # cache + assert actual.foo.variable._in_memory + + with self.roundtrip(expected, open_kwargs={'cache': False}) as actual: + assert isinstance(actual.foo.variable._data, + indexing.CopyOnWriteArray) + assert not actual.foo.variable._in_memory + actual.foo.values # no caching + assert not actual.foo.variable._in_memory + def test_roundtrip_None_variable(self): expected = Dataset({None: (('x', 'y'), [[0, 1], [2, 3]])}) with self.roundtrip(expected) as actual: @@ -282,11 +311,6 @@ def test_orthogonal_indexing(self): actual = on_disk.isel(**indexers) self.assertDatasetAllClose(expected, actual) - def test_pickle(self): - on_disk = open_example_dataset('bears.nc') - unpickled = pickle.loads(pickle.dumps(on_disk)) - self.assertDatasetIdentical(on_disk, unpickled) - class CFEncodedDataTest(DatasetIOTestCases): @@ -434,13 +458,17 @@ def test_encoding_same_dtype(self): 
@contextlib.contextmanager -def create_tmp_file(suffix='.nc'): +def create_tmp_file(suffix='.nc', allow_cleanup_failure=False): temp_dir = tempfile.mkdtemp() path = os.path.join(temp_dir, 'temp-%s%s' % (next(_counter), suffix)) try: yield path finally: - shutil.rmtree(temp_dir) + try: + shutil.rmtree(temp_dir) + except OSError: + if not allow_cleanup_failure: + raise class BaseNetCDF4Test(CFEncodedDataTest): @@ -650,8 +678,10 @@ def create_store(self): yield store @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): - with create_tmp_file() as tmp_file: + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): + with create_tmp_file( + allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, **save_kwargs) with open_dataset(tmp_file, **open_kwargs) as ds: yield ds @@ -689,9 +719,11 @@ def test_unsorted_index_raises(self): @requires_dask class NetCDF4ViaDaskDataTest(NetCDF4DataTest): @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): with NetCDF4DataTest.roundtrip( - self, data, save_kwargs, open_kwargs) as ds: + self, data, save_kwargs, open_kwargs, + allow_cleanup_failure) as ds: yield ds.chunk() def test_unsorted_index_raises(self): @@ -699,6 +731,10 @@ def test_unsorted_index_raises(self): # dask first pulls items by block. 
pass + def test_dataset_caching(self): + # caching behavior differs for dask + pass + @requires_scipy class ScipyInMemoryDataTest(CFEncodedDataTest, Only32BitTypes, TestCase): @@ -708,11 +744,20 @@ def create_store(self): yield backends.ScipyDataStore(fobj, 'w') @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): serialized = data.to_netcdf(**save_kwargs) - with open_dataset(BytesIO(serialized), **open_kwargs) as ds: + with open_dataset(serialized, engine='scipy', **open_kwargs) as ds: yield ds + @pytest.mark.skipif(PY2, reason='cannot pickle BytesIO on Python 2') + def test_bytesio_pickle(self): + data = Dataset({'foo': ('x', [1, 2, 3])}) + fobj = BytesIO(data.to_netcdf()) + with open_dataset(fobj) as ds: + unpickled = pickle.loads(pickle.dumps(ds)) + self.assertDatasetIdentical(unpickled, data) + @requires_scipy class ScipyOnDiskDataTest(CFEncodedDataTest, Only32BitTypes, TestCase): @@ -723,8 +768,10 @@ def create_store(self): yield store @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): - with create_tmp_file() as tmp_file: + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): + with create_tmp_file( + allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, engine='scipy', **save_kwargs) with open_dataset(tmp_file, engine='scipy', **open_kwargs) as ds: yield ds @@ -762,8 +809,10 @@ def create_store(self): yield store @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): - with create_tmp_file() as tmp_file: + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): + with create_tmp_file( + allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, format='NETCDF3_CLASSIC', engine='netcdf4', **save_kwargs) with open_dataset(tmp_file, 
engine='netcdf4', **open_kwargs) as ds: @@ -771,7 +820,8 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}): @requires_netCDF4 -class NetCDF4ClassicViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes, TestCase): +class NetCDF4ClassicViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes, + TestCase): @contextlib.contextmanager def create_store(self): with create_tmp_file() as tmp_file: @@ -780,8 +830,10 @@ def create_store(self): yield store @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): - with create_tmp_file() as tmp_file: + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): + with create_tmp_file( + allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, format='NETCDF4_CLASSIC', engine='netcdf4', **save_kwargs) with open_dataset(tmp_file, engine='netcdf4', **open_kwargs) as ds: @@ -798,8 +850,10 @@ def test_write_store(self): pass @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): - with create_tmp_file() as tmp_file: + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): + with create_tmp_file( + allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, format='netcdf3_64bit', **save_kwargs) with open_dataset(tmp_file, **open_kwargs) as ds: yield ds @@ -848,8 +902,10 @@ def create_store(self): yield backends.H5NetCDFStore(tmp_file, 'w') @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): - with create_tmp_file() as tmp_file: + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): + with create_tmp_file( + allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, engine='h5netcdf', **save_kwargs) with open_dataset(tmp_file, engine='h5netcdf', **open_kwargs) as ds: yield ds @@ -898,7 +954,8 @@ def create_store(self): yield Dataset() 
@contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): yield data.chunk() def test_roundtrip_datetime_data(self): @@ -912,6 +969,13 @@ def test_write_store(self): # Override method in DatasetIOTestCases - not applicable to dask pass + def test_dataset_caching(self): + expected = Dataset({'foo': ('x', [5, 6, 7])}) + with self.roundtrip(expected) as actual: + assert not actual.foo.variable._in_memory + actual.foo.values # no caching + assert not actual.foo.variable._in_memory + def test_open_mfdataset(self): original = Dataset({'foo': ('x', np.random.randn(10))}) with create_tmp_file() as tmp1: @@ -1041,6 +1105,7 @@ def test_dataarray_compute(self): self.assertTrue(computed._in_memory) self.assertDataArrayAllClose(actual, computed) + @requires_scipy_or_netCDF4 @requires_pydap class PydapTest(TestCase): @@ -1093,8 +1158,10 @@ def test_orthogonal_indexing(self): pass @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): - with create_tmp_file() as tmp_file: + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): + with create_tmp_file( + allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, engine='scipy', **save_kwargs) with open_dataset(tmp_file, engine='pynio', **open_kwargs) as ds: yield ds diff --git a/xarray/test/test_conventions.py b/xarray/test/test_conventions.py index 265b2a93b0f..8adb20dced3 100644 --- a/xarray/test/test_conventions.py +++ b/xarray/test/test_conventions.py @@ -191,11 +191,11 @@ def test_cf_datetime(self): @requires_netCDF4 def test_decode_cf_datetime_overflow(self): - # checks for + # checks for # https://github.com/pydata/pandas/issues/14068 # https://github.com/pydata/xarray/issues/975 - from datetime import datetime + from datetime import datetime units = 'days since 2000-01-01 00:00:00' # date after 2262 and before 
1678 @@ -626,7 +626,8 @@ def create_store(self): yield CFEncodedInMemoryStore() @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): store = CFEncodedInMemoryStore() data.dump_to_store(store, **save_kwargs) yield open_dataset(store, **open_kwargs) diff --git a/xarray/test/test_distributed.py b/xarray/test/test_distributed.py new file mode 100644 index 00000000000..a807f72387a --- /dev/null +++ b/xarray/test/test_distributed.py @@ -0,0 +1,36 @@ +import pytest +import xarray as xr +from xarray.core.pycompat import suppress + +distributed = pytest.importorskip('distributed') +da = pytest.importorskip('dask.array') +from distributed.utils_test import cluster, loop + +from xarray.test.test_backends import create_tmp_file +from xarray.test.test_dataset import create_test_data + +from . import assert_dataset_allclose, has_scipy, has_netCDF4, has_h5netcdf + + +ENGINES = [] +if has_scipy: + ENGINES.append('scipy') +if has_netCDF4: + ENGINES.append('netcdf4') +if has_h5netcdf: + ENGINES.append('h5netcdf') + + +@pytest.mark.parametrize('engine', ENGINES) +def test_dask_distributed_integration_test(loop, engine): + with cluster() as (s, _): + with distributed.Client(('127.0.0.1', s['port']), loop=loop): + original = create_test_data() + with create_tmp_file() as filename: + original.to_netcdf(filename, engine=engine) + # TODO: should be able to serialize locks + restored = xr.open_dataset(filename, chunks=3, lock=False, + engine=engine) + assert isinstance(restored.var1.data, da.Array) + computed = restored.compute() + assert_dataset_allclose(original, computed) diff --git a/xarray/test/test_indexing.py b/xarray/test/test_indexing.py index ed61c9eb463..9d22b4f2c87 100644 --- a/xarray/test/test_indexing.py +++ b/xarray/test/test_indexing.py @@ -200,3 +200,45 @@ def test_lazily_indexed_array(self): actual = lazy[i][j] self.assertEqual(expected.shape, 
actual.shape) self.assertArrayEqual(expected, actual) + + +class TestCopyOnWriteArray(TestCase): + def test_setitem(self): + original = np.arange(10) + wrapped = indexing.CopyOnWriteArray(original) + wrapped[:] = 0 + self.assertArrayEqual(original, np.arange(10)) + self.assertArrayEqual(wrapped, np.zeros(10)) + + def test_sub_array(self): + original = np.arange(10) + wrapped = indexing.CopyOnWriteArray(original) + child = wrapped[:5] + self.assertIsInstance(child, indexing.CopyOnWriteArray) + child[:] = 0 + self.assertArrayEqual(original, np.arange(10)) + self.assertArrayEqual(wrapped, np.arange(10)) + self.assertArrayEqual(child, np.zeros(5)) + + +class TestMemoryCachedArray(TestCase): + def test_wrapper(self): + original = indexing.LazilyIndexedArray(np.arange(10)) + wrapped = indexing.MemoryCachedArray(original) + self.assertArrayEqual(wrapped, np.arange(10)) + self.assertIsInstance(wrapped.array, np.ndarray) + + def test_sub_array(self): + original = indexing.LazilyIndexedArray(np.arange(10)) + wrapped = indexing.MemoryCachedArray(original) + child = wrapped[:5] + self.assertIsInstance(child, indexing.MemoryCachedArray) + self.assertArrayEqual(child, np.arange(5)) + self.assertIsInstance(child.array, np.ndarray) + self.assertIsInstance(wrapped.array, indexing.LazilyIndexedArray) + + def test_setitem(self): + original = np.arange(10) + wrapped = indexing.MemoryCachedArray(original) + wrapped[:] = 0 + self.assertArrayEqual(original, np.zeros(10)) diff --git a/xarray/test/test_utils.py b/xarray/test/test_utils.py index 611b45f80d1..ded618ddcab 100644 --- a/xarray/test/test_utils.py +++ b/xarray/test/test_utils.py @@ -1,6 +1,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import pickle +import pytest + import numpy as np import pandas as pd diff --git a/xarray/test/test_variable.py b/xarray/test/test_variable.py index 5623e78d512..e3eced07716 100644 --- a/xarray/test/test_variable.py +++ 
b/xarray/test/test_variable.py @@ -450,6 +450,15 @@ def test_multiindex(self): self.assertVariableIdentical(Variable((), ('a', 0)), v[0]) self.assertVariableIdentical(v, v[:]) + def test_load(self): + array = self.cls('x', np.arange(5)) + orig_data = array._data + copied = array.copy(deep=True) + array.load() + assert type(array._data) is type(orig_data) + assert type(copied._data) is type(orig_data) + self.assertVariableIdentical(array, copied) + class TestVariable(TestCase, VariableSubclassTestCases): cls = staticmethod(Variable)