From 289b377129b18e7dc6da8336e958a85be868acbe Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 8 Oct 2018 21:13:41 -0700 Subject: [PATCH] xarray.backends refactor (#2261) * WIP: xarray.backends.file_manager for managing file objects. This is intended to replace both PickleByReconstructionWrapper and DataStorePickleMixin with something more compartmentalized. xref GH2121 * Switch rasterio to use FileManager * lint fixes * WIP: rewrite FileManager to always use an LRUCache * Test coverage * Don't use move_to_end * minor clarification * Switch FileManager.acquire() to a method * Python 2 compat * Update xarray.set_options() to add file_cache_maxsize and validation * Add assert for FILE_CACHE.maxsize * More docstring for FileManager * Add accidentally omited tests for LRUCache * Adapt scipy backend to use FileManager * Stickler fix * Fix failure on Python 2.7 * Finish adjusting backends to use FileManager * Fix bad import * WIP on distributed * More WIP * Fix distributed write tests * Fixes * Minor fixup * whats new * More refactoring: remove state from backends entirely * Cleanup * Fix failing in-memory datastore tests * Fix inaccessible datastore * fix autoclose warnings * Fix PyNIO failures * No longer disable HDF5 file locking We longer need to explicitly HDF5_USE_FILE_LOCKING='FALSE' because we properly close open files. * whats new and default file cache size * Whats new tweak * Refactor default lock logic to backend classes * Rename get_resource_lock -> get_write_lock * Don't acquire unnecessary locks in __getitem__ * Fix bad merge * Fix import * Remove unreachable code --- asv_bench/asv.conf.json | 1 + asv_bench/benchmarks/dataset_io.py | 41 ++++ doc/api.rst | 3 + doc/whats-new.rst | 19 +- xarray/backends/__init__.py | 4 + xarray/backends/api.py | 250 ++++++++++----------- xarray/backends/common.py | 215 +----------------- xarray/backends/file_manager.py | 206 +++++++++++++++++ xarray/backends/h5netcdf_.py | 169 +++++++------- xarray/backends/locks.py | 191 ++++++++++++++++ xarray/backends/lru_cache.py | 91 ++++++++ xarray/backends/memory.py | 3 +- xarray/backends/netCDF4_.py | 231 ++++++++++--------- xarray/backends/pseudonetcdf_.py | 79 ++++--- xarray/backends/pynio_.py | 53 ++--- xarray/backends/rasterio_.py | 94 ++++---- xarray/backends/scipy_.py | 129 +++++------ xarray/backends/zarr.py | 21 +- xarray/core/dataset.py | 25 +-- xarray/core/options.py | 71 ++++-- xarray/core/pycompat.py | 9 +- xarray/tests/test_backends.py | 209 +++++------------ xarray/tests/test_backends_file_manager.py | 114 ++++++++++ xarray/tests/test_backends_locks.py | 13 ++ xarray/tests/test_backends_lru_cache.py | 91 ++++++++ xarray/tests/test_dataset.py | 4 +- xarray/tests/test_distributed.py | 110 +++++---- xarray/tests/test_options.py | 33 +++ 28 files changed, 1496 insertions(+), 983 deletions(-) create mode 100644 xarray/backends/file_manager.py create mode 100644 xarray/backends/locks.py create mode 100644 xarray/backends/lru_cache.py create mode 100644 xarray/tests/test_backends_file_manager.py create mode 100644 xarray/tests/test_backends_locks.py create mode 100644 xarray/tests/test_backends_lru_cache.py diff --git a/asv_bench/asv.conf.json b/asv_bench/asv.conf.json index b5953436387..e3933b400e6 100644 --- a/asv_bench/asv.conf.json +++ b/asv_bench/asv.conf.json @@ -64,6 +64,7 @@ "scipy": [""], "bottleneck": ["", null], "dask": [""], + "distributed": [""], }, diff --git a/asv_bench/benchmarks/dataset_io.py b/asv_bench/benchmarks/dataset_io.py index 0b918e58eab..da18d541a16 100644 --- 
a/asv_bench/benchmarks/dataset_io.py +++ b/asv_bench/benchmarks/dataset_io.py @@ -1,5 +1,7 @@ from __future__ import absolute_import, division, print_function +import os + import numpy as np import pandas as pd @@ -14,6 +16,9 @@ pass +os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE' + + class IOSingleNetCDF(object): """ A few examples that benchmark reading/writing a single netCDF file with @@ -405,3 +410,39 @@ def time_open_dataset_scipy_with_time_chunks(self): with dask.set_options(get=dask.multiprocessing.get): xr.open_mfdataset(self.filenames_list, engine='scipy', chunks=self.time_chunks) + + +def create_delayed_write(): + import dask.array as da + vals = da.random.random(300, chunks=(1,)) + ds = xr.Dataset({'vals': (['a'], vals)}) + return ds.to_netcdf('file.nc', engine='netcdf4', compute=False) + + +class IOWriteNetCDFDask(object): + timeout = 60 + repeat = 1 + number = 5 + + def setup(self): + requires_dask() + self.write = create_delayed_write() + + def time_write(self): + self.write.compute() + + +class IOWriteNetCDFDaskDistributed(object): + def setup(self): + try: + import distributed + except ImportError: + raise NotImplementedError + self.client = distributed.Client() + self.write = create_delayed_write() + + def cleanup(self): + self.client.shutdown() + + def time_write(self): + self.write.compute() diff --git a/doc/api.rst b/doc/api.rst index d204fab3539..662ef567710 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -624,3 +624,6 @@ arguments for the ``from_store`` and ``dump_to_store`` Dataset methods: backends.H5NetCDFStore backends.PydapDataStore backends.ScipyDataStore + backends.FileManager + backends.CachingFileManager + backends.DummyFileManager diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 7cdb1685f5f..d0fec7b0778 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -33,14 +33,27 @@ v0.11.0 (unreleased) Breaking changes ~~~~~~~~~~~~~~~~ +- Xarray's storage backends now automatically open and close files when + necessary, rather than requiring opening a file with ``autoclose=True``. A + global least-recently-used cache is used to store open files; the default + limit of 128 open files should suffice in most cases, but can be adjusted if + necessary with + ``xarray.set_options(file_cache_maxsize=...)``. The ``autoclose`` argument + to ``open_dataset`` and related functions has been deprecated and is now a + no-op. + + This change, along with an internal refactor of xarray's storage backends, + should significantly improve performance when reading and writing + netCDF files with Dask, especially when working with many files or using + Dask Distributed. By `Stephan Hoyer `_ + +Documentation +~~~~~~~~~~~~~ - Reduction of :py:meth:`DataArray.groupby` and :py:meth:`DataArray.resample` without dimension argument will change in the next release. Now we warn a FutureWarning. By `Keisuke Fujii `_. -Documentation -~~~~~~~~~~~~~ - Enhancements ~~~~~~~~~~~~ diff --git a/xarray/backends/__init__.py b/xarray/backends/__init__.py index 47a2011a3af..a2f0d79a6d1 100644 --- a/xarray/backends/__init__.py +++ b/xarray/backends/__init__.py @@ -4,6 +4,7 @@ formats. They should not be used directly, but rather through Dataset objects. 
""" from .common import AbstractDataStore +from .file_manager import FileManager, CachingFileManager, DummyFileManager from .memory import InMemoryDataStore from .netCDF4_ import NetCDF4DataStore from .pydap_ import PydapDataStore @@ -15,6 +16,9 @@ __all__ = [ 'AbstractDataStore', + 'FileManager', + 'CachingFileManager', + 'DummyFileManager', 'InMemoryDataStore', 'NetCDF4DataStore', 'PydapDataStore', diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 2bf13011bd1..65112527045 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -4,6 +4,7 @@ from glob import glob from io import BytesIO from numbers import Number +import warnings import numpy as np @@ -12,8 +13,9 @@ from ..core.combine import auto_combine from ..core.pycompat import basestring, path_type from ..core.utils import close_on_error, is_remote_uri -from .common import ( - HDF5_LOCK, ArrayWriter, CombinedLock, _get_scheduler, _get_scheduler_lock) +from .common import ArrayWriter +from .locks import _get_scheduler + DATAARRAY_NAME = '__xarray_dataarray_name__' DATAARRAY_VARIABLE = '__xarray_dataarray_variable__' @@ -52,27 +54,6 @@ def _normalize_path(path): return os.path.abspath(os.path.expanduser(path)) -def _default_lock(filename, engine): - if filename.endswith('.gz'): - lock = False - else: - if engine is None: - engine = _get_default_engine(filename, allow_remote=True) - - if engine == 'netcdf4': - if is_remote_uri(filename): - lock = False - else: - # TODO: identify netcdf3 files and don't use the global lock - # for them - lock = HDF5_LOCK - elif engine in {'h5netcdf', 'pynio'}: - lock = HDF5_LOCK - else: - lock = False - return lock - - def _validate_dataset_names(dataset): """DataArray.name and Dataset keys must be a string or None""" def check_name(name): @@ -130,29 +111,14 @@ def _protect_dataset_variables_inplace(dataset, cache): variable.data = data -def _get_lock(engine, scheduler, format, path_or_file): - """ Get the lock(s) that apply to a particular scheduler/engine/format""" - - locks = [] - if format in ['NETCDF4', None] and engine in ['h5netcdf', 'netcdf4']: - locks.append(HDF5_LOCK) - locks.append(_get_scheduler_lock(scheduler, path_or_file)) - - # When we have more than one lock, use the CombinedLock wrapper class - lock = CombinedLock(locks) if len(locks) > 1 else locks[0] - - return lock - - def _finalize_store(write, store): """ Finalize this store by explicitly syncing and closing""" del write # ensure writing is done first - store.sync() store.close() def open_dataset(filename_or_obj, group=None, decode_cf=True, - mask_and_scale=None, decode_times=True, autoclose=False, + mask_and_scale=None, decode_times=True, autoclose=None, concat_characters=True, decode_coords=True, engine=None, chunks=None, lock=None, cache=None, drop_variables=None, backend_kwargs=None): @@ -204,12 +170,11 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, If chunks is provided, it used to load the new dataset into dask arrays. ``chunks={}`` loads the dataset with dask using a single chunk for all arrays. - lock : False, True or threading.Lock, optional - If chunks is provided, this argument is passed on to - :py:func:`dask.array.from_array`. By default, a global lock is - used when reading data from netCDF files with the netcdf4 and h5netcdf - engines to avoid issues with concurrent access when using dask's - multithreaded backend. + lock : False or duck threading.Lock, optional + Resource lock to use when reading data from disk. 
Only relevant when + using dask or another form of parallelism. By default, appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. cache : bool, optional If True, cache data loaded from the underlying datastore in memory as NumPy arrays when accessed to avoid reading from the underlying data- @@ -235,6 +200,14 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, -------- open_mfdataset """ + if autoclose is not None: + warnings.warn( + 'The autoclose argument is no longer used by ' + 'xarray.open_dataset() and is now ignored; it will be removed in ' + 'xarray v0.12. If necessary, you can control the maximum number ' + 'of simultaneous open files with ' + 'xarray.set_options(file_cache_maxsize=...).', + FutureWarning, stacklevel=2) if mask_and_scale is None: mask_and_scale = not engine == 'pseudonetcdf' @@ -272,18 +245,11 @@ def maybe_decode_store(store, lock=False): mask_and_scale, decode_times, concat_characters, decode_coords, engine, chunks, drop_variables) name_prefix = 'open_dataset-%s' % token - ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token, - lock=lock) + ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token) ds2._file_obj = ds._file_obj else: ds2 = ds - # protect so that dataset store isn't necessarily closed, e.g., - # streams like BytesIO can't be reopened - # datastore backend is responsible for determining this capability - if store._autoclose: - store.close() - return ds2 if isinstance(filename_or_obj, path_type): @@ -314,36 +280,28 @@ def maybe_decode_store(store, lock=False): engine = _get_default_engine(filename_or_obj, allow_remote=True) if engine == 'netcdf4': - store = backends.NetCDF4DataStore.open(filename_or_obj, - group=group, - autoclose=autoclose, - **backend_kwargs) + store = backends.NetCDF4DataStore.open( + filename_or_obj, group=group, lock=lock, **backend_kwargs) elif engine == 'scipy': - store = backends.ScipyDataStore(filename_or_obj, - autoclose=autoclose, - **backend_kwargs) + store = backends.ScipyDataStore(filename_or_obj, **backend_kwargs) elif engine == 'pydap': - store = backends.PydapDataStore.open(filename_or_obj, - **backend_kwargs) + store = backends.PydapDataStore.open( + filename_or_obj, **backend_kwargs) elif engine == 'h5netcdf': - store = backends.H5NetCDFStore(filename_or_obj, group=group, - autoclose=autoclose, - **backend_kwargs) + store = backends.H5NetCDFStore( + filename_or_obj, group=group, lock=lock, **backend_kwargs) elif engine == 'pynio': - store = backends.NioDataStore(filename_or_obj, - autoclose=autoclose, - **backend_kwargs) + store = backends.NioDataStore( + filename_or_obj, lock=lock, **backend_kwargs) elif engine == 'pseudonetcdf': store = backends.PseudoNetCDFDataStore.open( - filename_or_obj, autoclose=autoclose, **backend_kwargs) + filename_or_obj, lock=lock, **backend_kwargs) else: raise ValueError('unrecognized engine for open_dataset: %r' % engine) - if lock is None: - lock = _default_lock(filename_or_obj, engine) with close_on_error(store): - return maybe_decode_store(store, lock) + return maybe_decode_store(store) else: if engine is not None and engine != 'scipy': raise ValueError('can only read file-like objects with ' @@ -355,7 +313,7 @@ def maybe_decode_store(store, lock=False): def open_dataarray(filename_or_obj, group=None, decode_cf=True, - mask_and_scale=None, decode_times=True, autoclose=False, + mask_and_scale=None, decode_times=True, autoclose=None, concat_characters=True, decode_coords=True, engine=None, chunks=None, 
lock=None, cache=None, drop_variables=None, backend_kwargs=None): @@ -390,10 +348,6 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, decode_times : bool, optional If True, decode times encoded in the standard NetCDF datetime format into datetime objects. Otherwise, leave them encoded as numbers. - autoclose : bool, optional - If True, automatically close files to avoid OS Error of too many files - being open. However, this option doesn't work with streams, e.g., - BytesIO. concat_characters : bool, optional If True, concatenate along the last dimension of character arrays to form string arrays. Dimensions will only be concatenated over (and @@ -409,12 +363,11 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, chunks : int or dict, optional If chunks is provided, it used to load the new dataset into dask arrays. - lock : False, True or threading.Lock, optional - If chunks is provided, this argument is passed on to - :py:func:`dask.array.from_array`. By default, a global lock is - used when reading data from netCDF files with the netcdf4 and h5netcdf - engines to avoid issues with concurrent access when using dask's - multithreaded backend. + lock : False or duck threading.Lock, optional + Resource lock to use when reading data from disk. Only relevant when + using dask or another form of parallelism. By default, appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. cache : bool, optional If True, cache data loaded from the underlying datastore in memory as NumPy arrays when accessed to avoid reading from the underlying data- @@ -490,7 +443,7 @@ def close(self): def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, compat='no_conflicts', preprocess=None, engine=None, lock=None, data_vars='all', coords='different', - autoclose=False, parallel=False, **kwargs): + autoclose=None, parallel=False, **kwargs): """Open multiple files as a single dataset. Requires dask to be installed. See documentation for details on dask [1]. @@ -537,15 +490,11 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for 'netcdf4'. - autoclose : bool, optional - If True, automatically close files to avoid OS Error of too many files - being open. However, this option doesn't work with streams, e.g., - BytesIO. - lock : False, True or threading.Lock, optional - This argument is passed on to :py:func:`dask.array.from_array`. By - default, a per-variable lock is used when reading data from netCDF - files with the netcdf4 and h5netcdf engines to avoid issues with - concurrent access when using dask's multithreaded backend. + lock : False or duck threading.Lock, optional + Resource lock to use when reading data from disk. Only relevant when + using dask or another form of parallelism. By default, appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. 
data_vars : {'minimal', 'different', 'all' or list of str}, optional These data variables will be concatenated together: * 'minimal': Only data variables in which the dimension already @@ -604,9 +553,6 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, if not paths: raise IOError('no files to open') - if lock is None: - lock = _default_lock(paths[0], engine) - open_kwargs = dict(engine=engine, chunks=chunks or {}, lock=lock, autoclose=autoclose, **kwargs) @@ -656,19 +602,21 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, - engine=None, writer=None, encoding=None, unlimited_dims=None, - compute=True): + engine=None, encoding=None, unlimited_dims=None, compute=True, + multifile=False): """This function creates an appropriate datastore for writing a dataset to disk as a netCDF file See `Dataset.to_netcdf` for full API docs. - The ``writer`` argument is only for the private use of save_mfdataset. + The ``multifile`` argument is only for the private use of save_mfdataset. """ if isinstance(path_or_file, path_type): path_or_file = str(path_or_file) + if encoding is None: encoding = {} + if path_or_file is None: if engine is None: engine = 'scipy' @@ -676,6 +624,10 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, raise ValueError('invalid engine for creating bytes with ' 'to_netcdf: %r. Only the default engine ' "or engine='scipy' is supported" % engine) + if not compute: + raise NotImplementedError( + 'to_netcdf() with compute=False is not yet implemented when ' + 'returning bytes') elif isinstance(path_or_file, basestring): if engine is None: engine = _get_default_engine(path_or_file) @@ -695,45 +647,78 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, if format is not None: format = format.upper() - # if a writer is provided, store asynchronously - sync = writer is None - # handle scheduler specific logic scheduler = _get_scheduler() have_chunks = any(v.chunks for v in dataset.variables.values()) - if (have_chunks and scheduler in ['distributed', 'multiprocessing'] and - engine != 'netcdf4'): + + autoclose = have_chunks and scheduler in ['distributed', 'multiprocessing'] + if autoclose and engine == 'scipy': raise NotImplementedError("Writing netCDF files with the %s backend " "is not currently supported with dask's %s " "scheduler" % (engine, scheduler)) - lock = _get_lock(engine, scheduler, format, path_or_file) - autoclose = (have_chunks and - scheduler in ['distributed', 'multiprocessing']) target = path_or_file if path_or_file is not None else BytesIO() - store = store_open(target, mode, format, group, writer, - autoclose=autoclose, lock=lock) + kwargs = dict(autoclose=True) if autoclose else {} + store = store_open(target, mode, format, group, **kwargs) if unlimited_dims is None: unlimited_dims = dataset.encoding.get('unlimited_dims', None) if isinstance(unlimited_dims, basestring): unlimited_dims = [unlimited_dims] + writer = ArrayWriter() + + # TODO: figure out how to refactor this logic (here and in save_mfdataset) + # to avoid this mess of conditionals try: - dataset.dump_to_store(store, sync=sync, encoding=encoding, - unlimited_dims=unlimited_dims, compute=compute) + # TODO: allow this work (setting up the file for writing array data) + # to be parallelized with dask + dump_to_store(dataset, store, writer, encoding=encoding, + unlimited_dims=unlimited_dims) + if autoclose: + store.close() + + if 
multifile: + return writer, store + + writes = writer.sync(compute=compute) + if path_or_file is None: + store.sync() return target.getvalue() finally: - if sync and isinstance(path_or_file, basestring): + if not multifile and compute: store.close() if not compute: import dask - return dask.delayed(_finalize_store)(store.delayed_store, store) + return dask.delayed(_finalize_store)(writes, store) + + +def dump_to_store(dataset, store, writer=None, encoder=None, + encoding=None, unlimited_dims=None): + """Store dataset contents to a backends.*DataStore object.""" + if writer is None: + writer = ArrayWriter() + + if encoding is None: + encoding = {} + + variables, attrs = conventions.encode_dataset_coordinates(dataset) + + check_encoding = set() + for k, enc in encoding.items(): + # no need to shallow copy the variable again; that already happened + # in encode_dataset_coordinates + variables[k].encoding = enc + check_encoding.add(k) + + if encoder: + variables, attrs = encoder(variables, attrs) + + store.store(variables, attrs, check_encoding, writer, + unlimited_dims=unlimited_dims) - if not sync: - return store def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, engine=None, compute=True): @@ -816,22 +801,22 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, 'datasets, paths and groups arguments to ' 'save_mfdataset') - writer = ArrayWriter() if compute else None - stores = [to_netcdf(ds, path, mode, format, group, engine, writer, - compute=compute) - for ds, path, group in zip(datasets, paths, groups)] - - if not compute: - import dask - return dask.delayed(stores) + writers, stores = zip(*[ + to_netcdf(ds, path, mode, format, group, engine, compute=compute, + multifile=True) + for ds, path, group in zip(datasets, paths, groups)]) try: - delayed = writer.sync(compute=compute) - for store in stores: - store.sync() + writes = [w.sync(compute=compute) for w in writers] finally: - for store in stores: - store.close() + if compute: + for store in stores: + store.close() + + if not compute: + import dask + return dask.delayed([dask.delayed(_finalize_store)(w, s) + for w, s in zip(writes, stores)]) def to_zarr(dataset, store=None, mode='w-', synchronizer=None, group=None, @@ -852,13 +837,14 @@ def to_zarr(dataset, store=None, mode='w-', synchronizer=None, group=None, store = backends.ZarrStore.open_group(store=store, mode=mode, synchronizer=synchronizer, - group=group, writer=None) + group=group) - # I think zarr stores should always be sync'd immediately + writer = ArrayWriter() # TODO: figure out how to properly handle unlimited_dims - dataset.dump_to_store(store, sync=True, encoding=encoding, compute=compute) + dump_to_store(dataset, store, writer, encoding=encoding) + writes = writer.sync(compute=compute) if not compute: import dask - return dask.delayed(_finalize_store)(store.delayed_store, store) + return dask.delayed(_finalize_store)(writes, store) return store diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 99f7698ee92..405d989f4af 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -1,14 +1,10 @@ from __future__ import absolute_import, division, print_function -import contextlib import logging -import multiprocessing -import threading import time import traceback import warnings from collections import Mapping, OrderedDict -from functools import partial import numpy as np @@ -17,13 +13,6 @@ from ..core.pycompat import dask_array_type, iteritems from ..core.utils import FrozenOrderedDict, 
NdimSizeLenMixin -# Import default lock -try: - from dask.utils import SerializableLock - HDF5_LOCK = SerializableLock() -except ImportError: - HDF5_LOCK = threading.Lock() - # Create a logger object, but don't add any handlers. Leave that to user code. logger = logging.getLogger(__name__) @@ -31,62 +20,6 @@ NONE_VAR_NAME = '__values__' -def _get_scheduler(get=None, collection=None): - """ Determine the dask scheduler that is being used. - - None is returned if not dask scheduler is active. - - See also - -------- - dask.base.get_scheduler - """ - try: - # dask 0.18.1 and later - from dask.base import get_scheduler - actual_get = get_scheduler(get, collection) - except ImportError: - try: - from dask.utils import effective_get - actual_get = effective_get(get, collection) - except ImportError: - return None - - try: - from dask.distributed import Client - if isinstance(actual_get.__self__, Client): - return 'distributed' - except (ImportError, AttributeError): - try: - import dask.multiprocessing - if actual_get == dask.multiprocessing.get: - return 'multiprocessing' - else: - return 'threaded' - except ImportError: - return 'threaded' - - -def _get_scheduler_lock(scheduler, path_or_file=None): - """ Get the appropriate lock for a certain situation based onthe dask - scheduler used. - - See Also - -------- - dask.utils.get_scheduler_lock - """ - - if scheduler == 'distributed': - from dask.distributed import Lock - return Lock(path_or_file) - elif scheduler == 'multiprocessing': - return multiprocessing.Lock() - elif scheduler == 'threaded': - from dask.utils import SerializableLock - return SerializableLock() - else: - return threading.Lock() - - def _encode_variable_name(name): if name is None: name = NONE_VAR_NAME @@ -133,39 +66,6 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, time.sleep(1e-3 * next_delay) -class CombinedLock(object): - """A combination of multiple locks. - - Like a locked door, a CombinedLock is locked if any of its constituent - locks are locked. 
- """ - - def __init__(self, locks): - self.locks = tuple(set(locks)) # remove duplicates - - def acquire(self, *args): - return all(lock.acquire(*args) for lock in self.locks) - - def release(self, *args): - for lock in self.locks: - lock.release(*args) - - def __enter__(self): - for lock in self.locks: - lock.__enter__() - - def __exit__(self, *args): - for lock in self.locks: - lock.__exit__(*args) - - @property - def locked(self): - return any(lock.locked for lock in self.locks) - - def __repr__(self): - return "CombinedLock(%r)" % list(self.locks) - - class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed): def __array__(self, dtype=None): @@ -174,9 +74,6 @@ def __array__(self, dtype=None): class AbstractDataStore(Mapping): - _autoclose = None - _ds = None - _isopen = False def __iter__(self): return iter(self.variables) @@ -259,7 +156,7 @@ def __exit__(self, exception_type, exception_value, traceback): class ArrayWriter(object): - def __init__(self, lock=HDF5_LOCK): + def __init__(self, lock=None): self.sources = [] self.targets = [] self.lock = lock @@ -274,6 +171,9 @@ def add(self, source, target): def sync(self, compute=True): if self.sources: import dask.array as da + # TODO: consider wrapping targets with dask.delayed, if this makes + # for any discernable difference in perforance, e.g., + # targets = [dask.delayed(t) for t in self.targets] delayed_store = da.store(self.sources, self.targets, lock=self.lock, compute=compute, flush=True) @@ -283,11 +183,6 @@ def sync(self, compute=True): class AbstractWritableDataStore(AbstractDataStore): - def __init__(self, writer=None, lock=HDF5_LOCK): - if writer is None: - writer = ArrayWriter(lock=lock) - self.writer = writer - self.delayed_store = None def encode(self, variables, attributes): """ @@ -329,12 +224,6 @@ def set_attribute(self, k, v): # pragma: no cover def set_variable(self, k, v): # pragma: no cover raise NotImplementedError - def sync(self, compute=True): - if self._isopen and self._autoclose: - # datastore will be reopened during write - self.close() - self.delayed_store = self.writer.sync(compute=compute) - def store_dataset(self, dataset): """ in stores, variables are all variables AND coordinates @@ -345,7 +234,7 @@ def store_dataset(self, dataset): self.store(dataset, dataset.attrs) def store(self, variables, attributes, check_encoding_set=frozenset(), - unlimited_dims=None): + writer=None, unlimited_dims=None): """ Top level method for putting data on this store, this method: - encodes variables/attributes @@ -361,16 +250,19 @@ def store(self, variables, attributes, check_encoding_set=frozenset(), check_encoding_set : list-like List of variables that should be checked for invalid encoding values + writer : ArrayWriter unlimited_dims : list-like List of dimension names that should be treated as unlimited dimensions. 
""" + if writer is None: + writer = ArrayWriter() variables, attributes = self.encode(variables, attributes) self.set_attributes(attributes) self.set_dimensions(variables, unlimited_dims=unlimited_dims) - self.set_variables(variables, check_encoding_set, + self.set_variables(variables, check_encoding_set, writer, unlimited_dims=unlimited_dims) def set_attributes(self, attributes): @@ -386,7 +278,7 @@ def set_attributes(self, attributes): for k, v in iteritems(attributes): self.set_attribute(k, v) - def set_variables(self, variables, check_encoding_set, + def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None): """ This provides a centralized method to set the variables on the data @@ -399,6 +291,7 @@ def set_variables(self, variables, check_encoding_set, check_encoding_set : list-like List of variables that should be checked for invalid encoding values + writer : ArrayWriter unlimited_dims : list-like List of dimension names that should be treated as unlimited dimensions. @@ -410,7 +303,7 @@ def set_variables(self, variables, check_encoding_set, target, source = self.prepare_variable( name, v, check, unlimited_dims=unlimited_dims) - self.writer.add(source, target) + writer.add(source, target) def set_dimensions(self, variables, unlimited_dims=None): """ @@ -457,87 +350,3 @@ def encode(self, variables, attributes): attributes = OrderedDict([(k, self.encode_attribute(v)) for k, v in attributes.items()]) return variables, attributes - - -class DataStorePickleMixin(object): - """Subclasses must define `ds`, `_opener` and `_mode` attributes. - - Do not subclass this class: it is not part of xarray's external API. - """ - - def __getstate__(self): - state = self.__dict__.copy() - del state['_ds'] - del state['_isopen'] - if self._mode == 'w': - # file has already been created, don't override when restoring - state['_mode'] = 'a' - return state - - def __setstate__(self, state): - self.__dict__.update(state) - self._ds = None - self._isopen = False - - @property - def ds(self): - if self._ds is not None and self._isopen: - return self._ds - ds = self._opener(mode=self._mode) - self._isopen = True - return ds - - @contextlib.contextmanager - def ensure_open(self, autoclose=None): - """ - Helper function to make sure datasets are closed and opened - at appropriate times to avoid too many open file errors. - - Use requires `autoclose=True` argument to `open_mfdataset`. 
- """ - - if autoclose is None: - autoclose = self._autoclose - - if not self._isopen: - try: - self._ds = self._opener() - self._isopen = True - yield - finally: - if autoclose: - self.close() - else: - yield - - def assert_open(self): - if not self._isopen: - raise AssertionError('internal failure: file must be open ' - 'if `autoclose=True` is used.') - - -class PickleByReconstructionWrapper(object): - - def __init__(self, opener, file, mode='r', **kwargs): - self.opener = partial(opener, file, mode=mode, **kwargs) - self.mode = mode - self._ds = None - - @property - def value(self): - self._ds = self.opener() - return self._ds - - def __getstate__(self): - state = self.__dict__.copy() - del state['_ds'] - if self.mode == 'w': - # file has already been created, don't override when restoring - state['mode'] = 'a' - return state - - def __setstate__(self, state): - self.__dict__.update(state) - - def close(self): - self._ds.close() diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py new file mode 100644 index 00000000000..a93285370b2 --- /dev/null +++ b/xarray/backends/file_manager.py @@ -0,0 +1,206 @@ +import threading + +from ..core import utils +from ..core.options import OPTIONS +from .lru_cache import LRUCache + + +# Global cache for storing open files. +FILE_CACHE = LRUCache( + OPTIONS['file_cache_maxsize'], on_evict=lambda k, v: v.close()) +assert FILE_CACHE.maxsize, 'file cache must be at least size one' + + +_DEFAULT_MODE = utils.ReprObject('') + + +class FileManager(object): + """Manager for acquiring and closing a file object. + + Use FileManager subclasses (CachingFileManager in particular) on backend + storage classes to automatically handle issues related to keeping track of + many open files and transferring them between multiple processes. + """ + + def acquire(self): + """Acquire the file object from this manager.""" + raise NotImplementedError + + def close(self, needs_lock=True): + """Close the file object associated with this manager, if needed.""" + raise NotImplementedError + + +class CachingFileManager(FileManager): + """Wrapper for automatically opening and closing file objects. + + Unlike files, CachingFileManager objects can be safely pickled and passed + between processes. They should be explicitly closed to release resources, + but a per-process least-recently-used cache for open files ensures that you + can safely create arbitrarily large numbers of FileManager objects. + + Don't directly close files acquired from a FileManager. Instead, call + FileManager.close(), which ensures that closed files are removed from the + cache as well. + + Example usage: + + manager = FileManager(open, 'example.txt', mode='w') + f = manager.acquire() + f.write(...) + manager.close() # ensures file is closed + + Note that as long as previous files are still cached, acquiring a file + multiple times from the same FileManager is essentially free: + + f1 = manager.acquire() + f2 = manager.acquire() + assert f1 is f2 + + """ + + def __init__(self, opener, *args, **keywords): + """Initialize a FileManager. + + Parameters + ---------- + opener : callable + Function that when called like ``opener(*args, **kwargs)`` returns + an open file object. The file object must implement a ``close()`` + method. + *args + Positional arguments for opener. A ``mode`` argument should be + provided as a keyword argument (see below). All arguments must be + hashable. + mode : optional + If provided, passed as a keyword argument to ``opener`` along with + ``**kwargs``. 
``mode='w' `` has special treatment: after the first + call it is replaced by ``mode='a'`` in all subsequent function to + avoid overriding the newly created file. + kwargs : dict, optional + Keyword arguments for opener, excluding ``mode``. All values must + be hashable. + lock : duck-compatible threading.Lock, optional + Lock to use when modifying the cache inside acquire() and close(). + By default, uses a new threading.Lock() object. If set, this object + should be pickleable. + cache : MutableMapping, optional + Mapping to use as a cache for open files. By default, uses xarray's + global LRU file cache. Because ``cache`` typically points to a + global variable and contains non-picklable file objects, an + unpickled FileManager objects will be restored with the default + cache. + """ + # TODO: replace with real keyword arguments when we drop Python 2 + # support + mode = keywords.pop('mode', _DEFAULT_MODE) + kwargs = keywords.pop('kwargs', None) + lock = keywords.pop('lock', None) + cache = keywords.pop('cache', FILE_CACHE) + if keywords: + raise TypeError('FileManager() got unexpected keyword arguments: ' + '%s' % list(keywords)) + + self._opener = opener + self._args = args + self._mode = mode + self._kwargs = {} if kwargs is None else dict(kwargs) + self._default_lock = lock is None or lock is False + self._lock = threading.Lock() if self._default_lock else lock + self._cache = cache + self._key = self._make_key() + + def _make_key(self): + """Make a key for caching files in the LRU cache.""" + value = (self._opener, + self._args, + self._mode, + tuple(sorted(self._kwargs.items()))) + return _HashedSequence(value) + + def acquire(self): + """Acquiring a file object from the manager. + + A new file is only opened if it has expired from the + least-recently-used cache. + + This method uses a reentrant lock, which ensures that it is + thread-safe. You can safely acquire a file in multiple threads at the + same time, as long as the underlying file object is thread-safe. + + Returns + ------- + An open file object, as returned by ``opener(*args, **kwargs)``. + """ + with self._lock: + try: + file = self._cache[self._key] + except KeyError: + kwargs = self._kwargs + if self._mode is not _DEFAULT_MODE: + kwargs = kwargs.copy() + kwargs['mode'] = self._mode + file = self._opener(*self._args, **kwargs) + if self._mode == 'w': + # ensure file doesn't get overriden when opened again + self._mode = 'a' + self._key = self._make_key() + self._cache[self._key] = file + return file + + def _close(self): + default = None + file = self._cache.pop(self._key, default) + if file is not None: + file.close() + + def close(self, needs_lock=True): + """Explicitly close any associated file object (if necessary).""" + # TODO: remove needs_lock if/when we have a reentrant lock in + # dask.distributed: https://github.com/dask/dask/issues/3832 + if needs_lock: + with self._lock: + self._close() + else: + self._close() + + def __getstate__(self): + """State for pickling.""" + lock = None if self._default_lock else self._lock + return (self._opener, self._args, self._mode, self._kwargs, lock) + + def __setstate__(self, state): + """Restore from a pickle.""" + opener, args, mode, kwargs, lock = state + self.__init__(opener, *args, mode=mode, kwargs=kwargs, lock=lock) + + +class _HashedSequence(list): + """Speedup repeated look-ups by caching hash values. + + Based on what Python uses internally in functools.lru_cache. 
+ + Python doesn't perform this optimization automatically: + https://bugs.python.org/issue1462796 + """ + + def __init__(self, tuple_value): + self[:] = tuple_value + self.hashvalue = hash(tuple_value) + + def __hash__(self): + return self.hashvalue + + +class DummyFileManager(FileManager): + """FileManager that simply wraps an open file in the FileManager interface. + """ + def __init__(self, value): + self._value = value + + def acquire(self): + return self._value + + def close(self, needs_lock=True): + del needs_lock # ignored + self._value.close() diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 959cd221734..59cd4e84793 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -8,11 +8,12 @@ from ..core import indexing from ..core.pycompat import OrderedDict, bytes_type, iteritems, unicode_type from ..core.utils import FrozenOrderedDict, close_on_error -from .common import ( - HDF5_LOCK, DataStorePickleMixin, WritableCFDataStore, find_root) +from .common import WritableCFDataStore +from .file_manager import CachingFileManager +from .locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock from .netCDF4_ import ( - BaseNetCDF4Array, _encode_nc4_variable, _extract_nc4_variable_encoding, - _get_datatype, _nc4_require_group) + BaseNetCDF4Array, GroupWrapper, _encode_nc4_variable, + _extract_nc4_variable_encoding, _get_datatype, _nc4_require_group) class H5NetCDFArrayWrapper(BaseNetCDF4Array): @@ -25,8 +26,9 @@ def _getitem(self, key): # h5py requires using lists for fancy indexing: # https://github.com/h5py/h5py/issues/992 key = tuple(list(k) if isinstance(k, np.ndarray) else k for k in key) - with self.datastore.ensure_open(autoclose=True): - return self.get_array()[key] + array = self.get_array() + with self.datastore.lock: + return array[key] def maybe_decode_bytes(txt): @@ -61,104 +63,102 @@ def _open_h5netcdf_group(filename, mode, group): import h5netcdf ds = h5netcdf.File(filename, mode=mode) with close_on_error(ds): - return _nc4_require_group( + ds = _nc4_require_group( ds, group, mode, create_group=_h5netcdf_create_group) + return GroupWrapper(ds) -class H5NetCDFStore(WritableCFDataStore, DataStorePickleMixin): +class H5NetCDFStore(WritableCFDataStore): """Store for reading and writing data via h5netcdf """ def __init__(self, filename, mode='r', format=None, group=None, - writer=None, autoclose=False, lock=HDF5_LOCK): + lock=None, autoclose=False): if format not in [None, 'NETCDF4']: raise ValueError('invalid format for h5netcdf backend') - opener = functools.partial(_open_h5netcdf_group, filename, mode=mode, - group=group) - self._ds = opener() - if autoclose: - raise NotImplementedError('autoclose=True is not implemented ' - 'for the h5netcdf backend pending ' - 'further exploration, e.g., bug fixes ' - '(in h5netcdf?)') - self._autoclose = False - self._isopen = True + self._manager = CachingFileManager( + _open_h5netcdf_group, filename, mode=mode, + kwargs=dict(group=group)) + + if lock is None: + if mode == 'r': + lock = HDF5_LOCK + else: + lock = combine_locks([HDF5_LOCK, get_write_lock(filename)]) + self.format = format - self._opener = opener self._filename = filename self._mode = mode - super(H5NetCDFStore, self).__init__(writer, lock=lock) + self.lock = ensure_lock(lock) + self.autoclose = autoclose + + @property + def ds(self): + return self._manager.acquire().value def open_store_variable(self, name, var): import h5py - with self.ensure_open(autoclose=False): - dimensions = var.dimensions - data = 
indexing.LazilyOuterIndexedArray( - H5NetCDFArrayWrapper(name, self)) - attrs = _read_attributes(var) - - # netCDF4 specific encoding - encoding = { - 'chunksizes': var.chunks, - 'fletcher32': var.fletcher32, - 'shuffle': var.shuffle, - } - # Convert h5py-style compression options to NetCDF4-Python - # style, if possible - if var.compression == 'gzip': - encoding['zlib'] = True - encoding['complevel'] = var.compression_opts - elif var.compression is not None: - encoding['compression'] = var.compression - encoding['compression_opts'] = var.compression_opts - - # save source so __repr__ can detect if it's local or not - encoding['source'] = self._filename - encoding['original_shape'] = var.shape - - vlen_dtype = h5py.check_dtype(vlen=var.dtype) - if vlen_dtype is unicode_type: - encoding['dtype'] = str - elif vlen_dtype is not None: # pragma: no cover - # xarray doesn't support writing arbitrary vlen dtypes yet. - pass - else: - encoding['dtype'] = var.dtype + dimensions = var.dimensions + data = indexing.LazilyOuterIndexedArray( + H5NetCDFArrayWrapper(name, self)) + attrs = _read_attributes(var) + + # netCDF4 specific encoding + encoding = { + 'chunksizes': var.chunks, + 'fletcher32': var.fletcher32, + 'shuffle': var.shuffle, + } + # Convert h5py-style compression options to NetCDF4-Python + # style, if possible + if var.compression == 'gzip': + encoding['zlib'] = True + encoding['complevel'] = var.compression_opts + elif var.compression is not None: + encoding['compression'] = var.compression + encoding['compression_opts'] = var.compression_opts + + # save source so __repr__ can detect if it's local or not + encoding['source'] = self._filename + encoding['original_shape'] = var.shape + + vlen_dtype = h5py.check_dtype(vlen=var.dtype) + if vlen_dtype is unicode_type: + encoding['dtype'] = str + elif vlen_dtype is not None: # pragma: no cover + # xarray doesn't support writing arbitrary vlen dtypes yet. 
+ pass + else: + encoding['dtype'] = var.dtype return Variable(dimensions, data, attrs, encoding) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in iteritems(self.ds.variables)) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in iteritems(self.ds.variables)) def get_attrs(self): - with self.ensure_open(autoclose=True): - return FrozenOrderedDict(_read_attributes(self.ds)) + return FrozenOrderedDict(_read_attributes(self.ds)) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return self.ds.dimensions + return self.ds.dimensions def get_encoding(self): - with self.ensure_open(autoclose=True): - encoding = {} - encoding['unlimited_dims'] = { - k for k, v in self.ds.dimensions.items() if v is None} + encoding = {} + encoding['unlimited_dims'] = { + k for k, v in self.ds.dimensions.items() if v is None} return encoding def set_dimension(self, name, length, is_unlimited=False): - with self.ensure_open(autoclose=False): - if is_unlimited: - self.ds.dimensions[name] = None - self.ds.resize_dimension(name, length) - else: - self.ds.dimensions[name] = length + if is_unlimited: + self.ds.dimensions[name] = None + self.ds.resize_dimension(name, length) + else: + self.ds.dimensions[name] = length def set_attribute(self, key, value): - with self.ensure_open(autoclose=False): - self.ds.attrs[key] = value + self.ds.attrs[key] = value def encode_variable(self, variable): return _encode_nc4_variable(variable) @@ -226,18 +226,11 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, variable.data - def sync(self, compute=True): - if not compute: - raise NotImplementedError( - 'compute=False is not supported for the h5netcdf backend yet') - with self.ensure_open(autoclose=True): - super(H5NetCDFStore, self).sync(compute=compute) - self.ds.sync() - - def close(self): - if self._isopen: - # netCDF4 only allows closing the root group - ds = find_root(self.ds) - if not ds._closed: - ds.close() - self._isopen = False + def sync(self): + self.ds.sync() + # if self.autoclose: + # self.close() + # super(H5NetCDFStore, self).sync(compute=compute) + + def close(self, **kwargs): + self._manager.close(**kwargs) diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py new file mode 100644 index 00000000000..f633280ef1d --- /dev/null +++ b/xarray/backends/locks.py @@ -0,0 +1,191 @@ +import multiprocessing +import threading +import weakref + +try: + from dask.utils import SerializableLock +except ImportError: + # no need to worry about serializing the lock + SerializableLock = threading.Lock + + +# Locks used by multiple backends. +# Neither HDF5 nor the netCDF-C library are thread-safe. +HDF5_LOCK = SerializableLock() +NETCDFC_LOCK = SerializableLock() + + +_FILE_LOCKS = weakref.WeakValueDictionary() + + +def _get_threaded_lock(key): + try: + lock = _FILE_LOCKS[key] + except KeyError: + lock = _FILE_LOCKS[key] = threading.Lock() + return lock + + +def _get_multiprocessing_lock(key): + # TODO: make use of the key -- maybe use locket.py? 
+ # https://github.com/mwilliamson/locket.py + del key # unused + return multiprocessing.Lock() + + +def _get_distributed_lock(key): + from dask.distributed import Lock + return Lock(key) + + +_LOCK_MAKERS = { + None: _get_threaded_lock, + 'threaded': _get_threaded_lock, + 'multiprocessing': _get_multiprocessing_lock, + 'distributed': _get_distributed_lock, +} + + +def _get_lock_maker(scheduler=None): + """Returns an appropriate function for creating resource locks. + + Parameters + ---------- + scheduler : str or None + Dask scheduler being used. + + See Also + -------- + dask.utils.get_scheduler_lock + """ + return _LOCK_MAKERS[scheduler] + + +def _get_scheduler(get=None, collection=None): + """Determine the dask scheduler that is being used. + + None is returned if no dask scheduler is active. + + See also + -------- + dask.base.get_scheduler + """ + try: + # dask 0.18.1 and later + from dask.base import get_scheduler + actual_get = get_scheduler(get, collection) + except ImportError: + try: + from dask.utils import effective_get + actual_get = effective_get(get, collection) + except ImportError: + return None + + try: + from dask.distributed import Client + if isinstance(actual_get.__self__, Client): + return 'distributed' + except (ImportError, AttributeError): + try: + import dask.multiprocessing + if actual_get == dask.multiprocessing.get: + return 'multiprocessing' + else: + return 'threaded' + except ImportError: + return 'threaded' + + +def get_write_lock(key): + """Get a scheduler appropriate lock for writing to the given resource. + + Parameters + ---------- + key : str + Name of the resource for which to acquire a lock. Typically a filename. + + Returns + ------- + Lock object that can be used like a threading.Lock object. + """ + scheduler = _get_scheduler() + lock_maker = _get_lock_maker(scheduler) + return lock_maker(key) + + +class CombinedLock(object): + """A combination of multiple locks. + + Like a locked door, a CombinedLock is locked if any of its constituent + locks are locked. 
+ """ + + def __init__(self, locks): + self.locks = tuple(set(locks)) # remove duplicates + + def acquire(self, *args): + return all(lock.acquire(*args) for lock in self.locks) + + def release(self, *args): + for lock in self.locks: + lock.release(*args) + + def __enter__(self): + for lock in self.locks: + lock.__enter__() + + def __exit__(self, *args): + for lock in self.locks: + lock.__exit__(*args) + + @property + def locked(self): + return any(lock.locked for lock in self.locks) + + def __repr__(self): + return "CombinedLock(%r)" % list(self.locks) + + +class DummyLock(object): + """DummyLock provides the lock API without any actual locking.""" + + def acquire(self, *args): + pass + + def release(self, *args): + pass + + def __enter__(self): + pass + + def __exit__(self, *args): + pass + + @property + def locked(self): + return False + + +def combine_locks(locks): + """Combine a sequence of locks into a single lock.""" + all_locks = [] + for lock in locks: + if isinstance(lock, CombinedLock): + all_locks.extend(lock.locks) + elif lock is not None: + all_locks.append(lock) + + num_locks = len(all_locks) + if num_locks > 1: + return CombinedLock(all_locks) + elif num_locks == 1: + return all_locks[0] + else: + return DummyLock() + + +def ensure_lock(lock): + """Ensure that the given object is a lock.""" + if lock is None or lock is False: + return DummyLock() + return lock diff --git a/xarray/backends/lru_cache.py b/xarray/backends/lru_cache.py new file mode 100644 index 00000000000..321a1ca4da4 --- /dev/null +++ b/xarray/backends/lru_cache.py @@ -0,0 +1,91 @@ +import collections +import threading + +from ..core.pycompat import move_to_end + + +class LRUCache(collections.MutableMapping): + """Thread-safe LRUCache based on an OrderedDict. + + All dict operations (__getitem__, __setitem__, __contains__) update the + priority of the relevant key and take O(1) time. The dict is iterated over + in order from the oldest to newest key, which means that a complete pass + over the dict should not affect the order of any entries. + + When a new item is set and the maximum size of the cache is exceeded, the + oldest item is dropped and called with ``on_evict(key, value)``. + + The ``maxsize`` property can be used to view or adjust the capacity of + the cache, e.g., ``cache.maxsize = new_size``. + """ + def __init__(self, maxsize, on_evict=None): + """ + Parameters + ---------- + maxsize : int + Integer maximum number of items to hold in the cache. + on_evict: callable, optional + Function to call like ``on_evict(key, value)`` when items are + evicted. 
+ """ + if not isinstance(maxsize, int): + raise TypeError('maxsize must be an integer') + if maxsize < 0: + raise ValueError('maxsize must be non-negative') + self._maxsize = maxsize + self._on_evict = on_evict + self._cache = collections.OrderedDict() + self._lock = threading.RLock() + + def __getitem__(self, key): + # record recent use of the key by moving it to the front of the list + with self._lock: + value = self._cache[key] + move_to_end(self._cache, key) + return value + + def _enforce_size_limit(self, capacity): + """Shrink the cache if necessary, evicting the oldest items.""" + while len(self._cache) > capacity: + key, value = self._cache.popitem(last=False) + if self._on_evict is not None: + self._on_evict(key, value) + + def __setitem__(self, key, value): + with self._lock: + if key in self._cache: + # insert the new value at the end + del self._cache[key] + self._cache[key] = value + elif self._maxsize: + # make room if necessary + self._enforce_size_limit(self._maxsize - 1) + self._cache[key] = value + elif self._on_evict is not None: + # not saving, immediately evict + self._on_evict(key, value) + + def __delitem__(self, key): + del self._cache[key] + + def __iter__(self): + # create a list, so accessing the cache during iteration cannot change + # the iteration order + return iter(list(self._cache)) + + def __len__(self): + return len(self._cache) + + @property + def maxsize(self): + """Maximum number of items can be held in the cache.""" + return self._maxsize + + @maxsize.setter + def maxsize(self, size): + """Resize the cache, evicting the oldest items if necessary.""" + if size < 0: + raise ValueError('maxsize must be non-negative') + with self._lock: + self._enforce_size_limit(size) + self._maxsize = size diff --git a/xarray/backends/memory.py b/xarray/backends/memory.py index dcf092557b8..195d4647534 100644 --- a/xarray/backends/memory.py +++ b/xarray/backends/memory.py @@ -17,10 +17,9 @@ class InMemoryDataStore(AbstractWritableDataStore): This store exists purely for internal testing purposes. 
""" - def __init__(self, variables=None, attributes=None, writer=None): + def __init__(self, variables=None, attributes=None): self._variables = OrderedDict() if variables is None else variables self._attributes = OrderedDict() if attributes is None else attributes - super(InMemoryDataStore, self).__init__(writer) def get_attrs(self): return self._attributes diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index aa19633020b..08ba085b77e 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -13,8 +13,10 @@ from ..core.pycompat import PY3, OrderedDict, basestring, iteritems, suppress from ..core.utils import FrozenOrderedDict, close_on_error, is_remote_uri from .common import ( - HDF5_LOCK, BackendArray, DataStorePickleMixin, WritableCFDataStore, - find_root, robust_getitem) + BackendArray, WritableCFDataStore, find_root, robust_getitem) +from .locks import (NETCDFC_LOCK, HDF5_LOCK, + combine_locks, ensure_lock, get_write_lock) +from .file_manager import CachingFileManager, DummyFileManager from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable # This lookup table maps from dtype.byteorder to a readable endian @@ -25,6 +27,9 @@ '|': 'native'} +NETCDF4_PYTHON_LOCK = combine_locks([NETCDFC_LOCK, HDF5_LOCK]) + + class BaseNetCDF4Array(BackendArray): def __init__(self, variable_name, datastore): self.datastore = datastore @@ -42,12 +47,13 @@ def __init__(self, variable_name, datastore): self.dtype = dtype def __setitem__(self, key, value): - with self.datastore.ensure_open(autoclose=True): + with self.datastore.lock: data = self.get_array() data[key] = value + if self.datastore.autoclose: + self.datastore.close(needs_lock=False) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name] @@ -63,20 +69,22 @@ def _getitem(self, key): else: getitem = operator.getitem - with self.datastore.ensure_open(autoclose=True): - try: - array = getitem(self.get_array(), key) - except IndexError: - # Catch IndexError in netCDF4 and return a more informative - # error message. This is most often called when an unsorted - # indexer is used before the data is loaded from disk. - msg = ('The indexing operation you are attempting to perform ' - 'is not valid on netCDF4.Variable object. Try loading ' - 'your data into memory first by calling .load().') - if not PY3: - import traceback - msg += '\n\nOriginal traceback:\n' + traceback.format_exc() - raise IndexError(msg) + original_array = self.get_array() + + try: + with self.datastore.lock: + array = getitem(original_array, key) + except IndexError: + # Catch IndexError in netCDF4 and return a more informative + # error message. This is most often called when an unsorted + # indexer is used before the data is loaded from disk. + msg = ('The indexing operation you are attempting to perform ' + 'is not valid on netCDF4.Variable object. 
Try loading ' + 'your data into memory first by calling .load().') + if not PY3: + import traceback + msg += '\n\nOriginal traceback:\n' + traceback.format_exc() + raise IndexError(msg) return array @@ -223,7 +231,17 @@ def _extract_nc4_variable_encoding(variable, raise_on_invalid=False, return encoding -def _open_netcdf4_group(filename, mode, group=None, **kwargs): +class GroupWrapper(object): + """Wrap netCDF4.Group objects so closing them closes the root group.""" + def __init__(self, value): + self.value = value + + def close(self): + # netCDF4 only allows closing the root group + find_root(self.value).close() + + +def _open_netcdf4_group(filename, lock, mode, group=None, **kwargs): import netCDF4 as nc4 ds = nc4.Dataset(filename, mode=mode, **kwargs) @@ -233,7 +251,7 @@ def _open_netcdf4_group(filename, mode, group=None, **kwargs): _disable_auto_decode_group(ds) - return ds + return GroupWrapper(ds) def _disable_auto_decode_variable(var): @@ -279,40 +297,33 @@ def _set_nc_attribute(obj, key, value): obj.setncattr(key, value) -class NetCDF4DataStore(WritableCFDataStore, DataStorePickleMixin): +class NetCDF4DataStore(WritableCFDataStore): """Store for reading and writing data via the Python-NetCDF4 library. This store supports NetCDF3, NetCDF4 and OpenDAP datasets. """ - def __init__(self, netcdf4_dataset, mode='r', writer=None, opener=None, - autoclose=False, lock=HDF5_LOCK): - - if autoclose and opener is None: - raise ValueError('autoclose requires an opener') + def __init__(self, manager, lock=NETCDF4_PYTHON_LOCK, autoclose=False): + import netCDF4 - _disable_auto_decode_group(netcdf4_dataset) + if isinstance(manager, netCDF4.Dataset): + _disable_auto_decode_group(manager) + manager = DummyFileManager(GroupWrapper(manager)) - self._ds = netcdf4_dataset - self._autoclose = autoclose - self._isopen = True + self._manager = manager self.format = self.ds.data_model self._filename = self.ds.filepath() self.is_remote = is_remote_uri(self._filename) - self._mode = mode = 'a' if mode == 'w' else mode - if opener: - self._opener = functools.partial(opener, mode=self._mode) - else: - self._opener = opener - super(NetCDF4DataStore, self).__init__(writer, lock=lock) + self.lock = ensure_lock(lock) + self.autoclose = autoclose @classmethod def open(cls, filename, mode='r', format='NETCDF4', group=None, - writer=None, clobber=True, diskless=False, persist=False, - autoclose=False, lock=HDF5_LOCK): - import netCDF4 as nc4 + clobber=True, diskless=False, persist=False, + lock=None, lock_maker=None, autoclose=False): + import netCDF4 if (len(filename) == 88 and - LooseVersion(nc4.__version__) < "1.3.1"): + LooseVersion(netCDF4.__version__) < "1.3.1"): warnings.warn( 'A segmentation fault may occur when the ' 'file path has exactly 88 characters as it does ' @@ -323,86 +334,91 @@ def open(cls, filename, mode='r', format='NETCDF4', group=None, 'https://github.com/pydata/xarray/issues/1745') if format is None: format = 'NETCDF4' - opener = functools.partial(_open_netcdf4_group, filename, mode=mode, - group=group, clobber=clobber, - diskless=diskless, persist=persist, - format=format) - ds = opener() - return cls(ds, mode=mode, writer=writer, opener=opener, - autoclose=autoclose, lock=lock) - def open_store_variable(self, name, var): - with self.ensure_open(autoclose=False): - dimensions = var.dimensions - data = indexing.LazilyOuterIndexedArray( - NetCDF4ArrayWrapper(name, self)) - attributes = OrderedDict((k, var.getncattr(k)) - for k in var.ncattrs()) - _ensure_fill_value_valid(data, 
attributes) - # netCDF4 specific encoding; save _FillValue for later - encoding = {} - filters = var.filters() - if filters is not None: - encoding.update(filters) - chunking = var.chunking() - if chunking is not None: - if chunking == 'contiguous': - encoding['contiguous'] = True - encoding['chunksizes'] = None + if lock is None: + if mode == 'r': + if is_remote_uri(filename): + lock = NETCDFC_LOCK + else: + lock = NETCDF4_PYTHON_LOCK + else: + if format is None or format.startswith('NETCDF4'): + base_lock = NETCDF4_PYTHON_LOCK else: - encoding['contiguous'] = False - encoding['chunksizes'] = tuple(chunking) - # TODO: figure out how to round-trip "endian-ness" without raising - # warnings from netCDF4 - # encoding['endian'] = var.endian() - pop_to(attributes, encoding, 'least_significant_digit') - # save source so __repr__ can detect if it's local or not - encoding['source'] = self._filename - encoding['original_shape'] = var.shape - encoding['dtype'] = var.dtype + base_lock = NETCDFC_LOCK + lock = combine_locks([base_lock, get_write_lock(filename)]) + + manager = CachingFileManager( + _open_netcdf4_group, filename, lock, mode=mode, + kwargs=dict(group=group, clobber=clobber, diskless=diskless, + persist=persist, format=format)) + return cls(manager, lock=lock, autoclose=autoclose) + + @property + def ds(self): + return self._manager.acquire().value + + def open_store_variable(self, name, var): + dimensions = var.dimensions + data = indexing.LazilyOuterIndexedArray( + NetCDF4ArrayWrapper(name, self)) + attributes = OrderedDict((k, var.getncattr(k)) + for k in var.ncattrs()) + _ensure_fill_value_valid(data, attributes) + # netCDF4 specific encoding; save _FillValue for later + encoding = {} + filters = var.filters() + if filters is not None: + encoding.update(filters) + chunking = var.chunking() + if chunking is not None: + if chunking == 'contiguous': + encoding['contiguous'] = True + encoding['chunksizes'] = None + else: + encoding['contiguous'] = False + encoding['chunksizes'] = tuple(chunking) + # TODO: figure out how to round-trip "endian-ness" without raising + # warnings from netCDF4 + # encoding['endian'] = var.endian() + pop_to(attributes, encoding, 'least_significant_digit') + # save source so __repr__ can detect if it's local or not + encoding['source'] = self._filename + encoding['original_shape'] = var.shape + encoding['dtype'] = var.dtype return Variable(dimensions, data, attributes, encoding) def get_variables(self): - with self.ensure_open(autoclose=False): - dsvars = FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in - iteritems(self.ds.variables)) + dsvars = FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in + iteritems(self.ds.variables)) return dsvars def get_attrs(self): - with self.ensure_open(autoclose=True): - attrs = FrozenOrderedDict((k, self.ds.getncattr(k)) - for k in self.ds.ncattrs()) + attrs = FrozenOrderedDict((k, self.ds.getncattr(k)) + for k in self.ds.ncattrs()) return attrs def get_dimensions(self): - with self.ensure_open(autoclose=True): - dims = FrozenOrderedDict((k, len(v)) - for k, v in iteritems(self.ds.dimensions)) + dims = FrozenOrderedDict((k, len(v)) + for k, v in iteritems(self.ds.dimensions)) return dims def get_encoding(self): - with self.ensure_open(autoclose=True): - encoding = {} - encoding['unlimited_dims'] = { - k for k, v in self.ds.dimensions.items() if v.isunlimited()} + encoding = {} + encoding['unlimited_dims'] = { + k for k, v in self.ds.dimensions.items() if v.isunlimited()} return encoding def 
set_dimension(self, name, length, is_unlimited=False): - with self.ensure_open(autoclose=False): - dim_length = length if not is_unlimited else None - self.ds.createDimension(name, size=dim_length) + dim_length = length if not is_unlimited else None + self.ds.createDimension(name, size=dim_length) def set_attribute(self, key, value): - with self.ensure_open(autoclose=False): - if self.format != 'NETCDF4': - value = encode_nc3_attr_value(value) - _set_nc_attribute(self.ds, key, value) - - def set_variables(self, *args, **kwargs): - with self.ensure_open(autoclose=False): - super(NetCDF4DataStore, self).set_variables(*args, **kwargs) + if self.format != 'NETCDF4': + value = encode_nc3_attr_value(value) + _set_nc_attribute(self.ds, key, value) def encode_variable(self, variable): variable = _force_native_endianness(variable) @@ -460,15 +476,8 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, variable.data - def sync(self, compute=True): - with self.ensure_open(autoclose=True): - super(NetCDF4DataStore, self).sync(compute=compute) - self.ds.sync() + def sync(self): + self.ds.sync() - def close(self): - if self._isopen: - # netCDF4 only allows closing the root group - ds = find_root(self.ds) - if ds._isopen: - ds.close() - self._isopen = False + def close(self, **kwargs): + self._manager.close(**kwargs) diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index 3d846916740..e4691d1f7e1 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -1,14 +1,18 @@ from __future__ import absolute_import, division, print_function -import functools - import numpy as np from .. import Variable from ..core import indexing from ..core.pycompat import OrderedDict -from ..core.utils import Frozen, FrozenOrderedDict -from .common import AbstractDataStore, BackendArray, DataStorePickleMixin +from ..core.utils import Frozen +from .common import AbstractDataStore, BackendArray +from .file_manager import CachingFileManager +from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock + + +# psuedonetcdf can invoke netCDF libraries internally +PNETCDF_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK]) class PncArrayWrapper(BackendArray): @@ -21,7 +25,6 @@ def __init__(self, variable_name, datastore): self.dtype = np.dtype(array.dtype) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name] def __getitem__(self, key): @@ -30,57 +33,55 @@ def __getitem__(self, key): self._getitem) def _getitem(self, key): - with self.datastore.ensure_open(autoclose=True): - return self.get_array()[key] + array = self.get_array() + with self.datastore.lock: + return array[key] -class PseudoNetCDFDataStore(AbstractDataStore, DataStorePickleMixin): +class PseudoNetCDFDataStore(AbstractDataStore): """Store for accessing datasets via PseudoNetCDF """ @classmethod - def open(cls, filename, format=None, writer=None, - autoclose=False, **format_kwds): + def open(cls, filename, lock=None, **format_kwds): from PseudoNetCDF import pncopen - opener = functools.partial(pncopen, filename, **format_kwds) - ds = opener() - mode = format_kwds.get('mode', 'r') - return cls(ds, mode=mode, writer=writer, opener=opener, - autoclose=autoclose) - def __init__(self, pnc_dataset, mode='r', writer=None, opener=None, - autoclose=False): + keywords = dict(kwargs=format_kwds) + # only include mode if explicitly passed + mode = format_kwds.pop('mode', None) + if mode is not None: + keywords['mode'] = mode + + if lock 
is None: + lock = PNETCDF_LOCK + + manager = CachingFileManager(pncopen, filename, lock=lock, **keywords) + return cls(manager, lock) - if autoclose and opener is None: - raise ValueError('autoclose requires an opener') + def __init__(self, manager, lock=None): + self._manager = manager + self.lock = ensure_lock(lock) - self._ds = pnc_dataset - self._autoclose = autoclose - self._isopen = True - self._opener = opener - self._mode = mode - super(PseudoNetCDFDataStore, self).__init__() + @property + def ds(self): + return self._manager.acquire() def open_store_variable(self, name, var): - with self.ensure_open(autoclose=False): - data = indexing.LazilyOuterIndexedArray( - PncArrayWrapper(name, self) - ) + data = indexing.LazilyOuterIndexedArray( + PncArrayWrapper(name, self) + ) attrs = OrderedDict((k, getattr(var, k)) for k in var.ncattrs()) return Variable(var.dimensions, data, attrs) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in self.ds.variables.items()) + return ((k, self.open_store_variable(k, v)) + for k, v in self.ds.variables.items()) def get_attrs(self): - with self.ensure_open(autoclose=True): - return Frozen(dict([(k, getattr(self.ds, k)) - for k in self.ds.ncattrs()])) + return Frozen(dict([(k, getattr(self.ds, k)) + for k in self.ds.ncattrs()])) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.dimensions) + return Frozen(self.ds.dimensions) def get_encoding(self): encoding = {} @@ -90,6 +91,4 @@ def get_encoding(self): return encoding def close(self): - if self._isopen: - self.ds.close() - self._isopen = False + self._manager.close() diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index 98b76928597..574fff744e3 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -1,13 +1,20 @@ from __future__ import absolute_import, division, print_function -import functools - import numpy as np from .. import Variable from ..core import indexing from ..core.utils import Frozen, FrozenOrderedDict -from .common import AbstractDataStore, BackendArray, DataStorePickleMixin +from .common import AbstractDataStore, BackendArray +from .file_manager import CachingFileManager +from .locks import ( + HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock, SerializableLock) + + +# PyNIO can invoke netCDF libraries internally +# Add a dedicated lock just in case NCL as well isn't thread-safe. 
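As the comments above note, PseudoNetCDF and PyNIO can both call into the netCDF-C/HDF5 libraries, so their backend locks (``PNETCDF_LOCK`` above, and ``NCL_LOCK``/``PYNIO_LOCK`` defined immediately below) are built with ``combine_locks``. The actual implementation lives in the new ``xarray/backends/locks.py``, which is outside this hunk; the following is only a rough sketch, under the assumption that a combined lock simply acquires all of its member locks and releases them in reverse order. Every name ending in ``_SKETCH`` is hypothetical.

```python
import threading


class CombinedLockSketch(object):
    """Illustrative stand-in for a combined lock: hold several locks at once."""

    def __init__(self, locks):
        self.locks = list(locks)

    def __enter__(self):
        for lock in self.locks:
            lock.acquire()
        return self

    def __exit__(self, *exc_info):
        for lock in reversed(self.locks):
            lock.release()


# Hypothetical usage: serialize calls into two libraries that share state.
HDF5_LOCK_SKETCH = threading.Lock()
NETCDFC_LOCK_SKETCH = threading.Lock()
PYNIO_LOCK_SKETCH = CombinedLockSketch([HDF5_LOCK_SKETCH, NETCDFC_LOCK_SKETCH])

with PYNIO_LOCK_SKETCH:
    pass  # any call into the underlying C libraries would go here
```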
+NCL_LOCK = SerializableLock() +PYNIO_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK, NCL_LOCK]) class NioArrayWrapper(BackendArray): @@ -20,7 +27,6 @@ def __init__(self, variable_name, datastore): self.dtype = np.dtype(array.typecode()) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name] def __getitem__(self, key): @@ -28,46 +34,45 @@ def __getitem__(self, key): key, self.shape, indexing.IndexingSupport.BASIC, self._getitem) def _getitem(self, key): - with self.datastore.ensure_open(autoclose=True): - array = self.get_array() + array = self.get_array() + with self.datastore.lock: if key == () and self.ndim == 0: return array.get_value() - return array[key] -class NioDataStore(AbstractDataStore, DataStorePickleMixin): +class NioDataStore(AbstractDataStore): """Store for accessing datasets via PyNIO """ - def __init__(self, filename, mode='r', autoclose=False): + def __init__(self, filename, mode='r', lock=None): import Nio - opener = functools.partial(Nio.open_file, filename, mode=mode) - self._ds = opener() - self._autoclose = autoclose - self._isopen = True - self._opener = opener - self._mode = mode + if lock is None: + lock = PYNIO_LOCK + self.lock = ensure_lock(lock) + self._manager = CachingFileManager( + Nio.open_file, filename, lock=lock, mode=mode) # xarray provides its own support for FillValue, # so turn off PyNIO's support for the same. self.ds.set_option('MaskedArrayMode', 'MaskedNever') + @property + def ds(self): + return self._manager.acquire() + def open_store_variable(self, name, var): data = indexing.LazilyOuterIndexedArray(NioArrayWrapper(name, self)) return Variable(var.dimensions, data, var.attributes) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in self.ds.variables.items()) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in self.ds.variables.items()) def get_attrs(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.attributes) + return Frozen(self.ds.attributes) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.dimensions) + return Frozen(self.ds.dimensions) def get_encoding(self): encoding = {} @@ -76,6 +81,4 @@ def get_encoding(self): return encoding def close(self): - if self._isopen: - self.ds.close() - self._isopen = False + self._manager.close() diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 44cca9aaaf8..5746b4e748d 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -8,14 +8,13 @@ from .. import DataArray from ..core import indexing from ..core.utils import is_scalar -from .common import BackendArray, PickleByReconstructionWrapper +from .common import BackendArray +from .file_manager import CachingFileManager +from .locks import SerializableLock -try: - from dask.utils import SerializableLock as Lock -except ImportError: - from threading import Lock -RASTERIO_LOCK = Lock() +# TODO: should this be GDAL_LOCK instead? +RASTERIO_LOCK = SerializableLock() _ERROR_MSG = ('The kind of indexing operation you are trying to do is not ' 'valid on rasterio files. 
Try to load your data with ds.load()' @@ -25,18 +24,22 @@ class RasterioArrayWrapper(BackendArray): """A wrapper around rasterio dataset objects""" - def __init__(self, riods): - self.riods = riods - self._shape = (riods.value.count, riods.value.height, - riods.value.width) - self._ndims = len(self.shape) + def __init__(self, manager): + self.manager = manager - @property - def dtype(self): - dtypes = self.riods.value.dtypes + # cannot save riods as an attribute: this would break pickleability + riods = manager.acquire() + + self._shape = (riods.count, riods.height, riods.width) + + dtypes = riods.dtypes if not np.all(np.asarray(dtypes) == dtypes[0]): raise ValueError('All bands should have the same dtype') - return np.dtype(dtypes[0]) + self._dtype = np.dtype(dtypes[0]) + + @property + def dtype(self): + return self._dtype @property def shape(self): @@ -108,7 +111,8 @@ def _getitem(self, key): stop - start for (start, stop) in window) out = np.zeros(shape, dtype=self.dtype) else: - out = self.riods.value.read(band_key, window=window) + riods = self.manager.acquire() + out = riods.read(band_key, window=window) if squeeze_axis: out = np.squeeze(out, axis=squeeze_axis) @@ -203,7 +207,8 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, import rasterio - riods = PickleByReconstructionWrapper(rasterio.open, filename, mode='r') + manager = CachingFileManager(rasterio.open, filename, mode='r') + riods = manager.acquire() if cache is None: cache = chunks is None @@ -211,20 +216,20 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, coords = OrderedDict() # Get bands - if riods.value.count < 1: + if riods.count < 1: raise ValueError('Unknown dims') - coords['band'] = np.asarray(riods.value.indexes) + coords['band'] = np.asarray(riods.indexes) # Get coordinates if LooseVersion(rasterio.__version__) < '1.0': - transform = riods.value.affine + transform = riods.affine else: - transform = riods.value.transform + transform = riods.transform if transform.is_rectilinear: # 1d coordinates parse = True if parse_coordinates is None else parse_coordinates if parse: - nx, ny = riods.value.width, riods.value.height + nx, ny = riods.width, riods.height # xarray coordinates are pixel centered x, _ = (np.arange(nx) + 0.5, np.zeros(nx) + 0.5) * transform _, y = (np.zeros(ny) + 0.5, np.arange(ny) + 0.5) * transform @@ -234,57 +239,60 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, # 2d coordinates parse = False if (parse_coordinates is None) else parse_coordinates if parse: - warnings.warn("The file coordinates' transformation isn't " - "rectilinear: xarray won't parse the coordinates " - "in this case. Set `parse_coordinates=False` to " - "suppress this warning.", - RuntimeWarning, stacklevel=3) + warnings.warn( + "The file coordinates' transformation isn't " + "rectilinear: xarray won't parse the coordinates " + "in this case. 
Set `parse_coordinates=False` to " + "suppress this warning.", + RuntimeWarning, stacklevel=3) # Attributes attrs = dict() # Affine transformation matrix (always available) # This describes coefficients mapping pixel coordinates to CRS # For serialization store as tuple of 6 floats, the last row being - # always (0, 0, 1) per definition (see https://github.com/sgillies/affine) + # always (0, 0, 1) per definition (see + # https://github.com/sgillies/affine) attrs['transform'] = tuple(transform)[:6] - if hasattr(riods.value, 'crs') and riods.value.crs: + if hasattr(riods, 'crs') and riods.crs: # CRS is a dict-like object specific to rasterio # If CRS is not None, we convert it back to a PROJ4 string using # rasterio itself - attrs['crs'] = riods.value.crs.to_string() - if hasattr(riods.value, 'res'): + attrs['crs'] = riods.crs.to_string() + if hasattr(riods, 'res'): # (width, height) tuple of pixels in units of CRS - attrs['res'] = riods.value.res - if hasattr(riods.value, 'is_tiled'): + attrs['res'] = riods.res + if hasattr(riods, 'is_tiled'): # Is the TIF tiled? (bool) # We cast it to an int for netCDF compatibility - attrs['is_tiled'] = np.uint8(riods.value.is_tiled) - if hasattr(riods.value, 'nodatavals'): + attrs['is_tiled'] = np.uint8(riods.is_tiled) + if hasattr(riods, 'nodatavals'): # The nodata values for the raster bands - attrs['nodatavals'] = tuple([np.nan if nodataval is None else nodataval - for nodataval in riods.value.nodatavals]) + attrs['nodatavals'] = tuple( + np.nan if nodataval is None else nodataval + for nodataval in riods.nodatavals) # Parse extra metadata from tags, if supported parsers = {'ENVI': _parse_envi} - driver = riods.value.driver + driver = riods.driver if driver in parsers: - meta = parsers[driver](riods.value.tags(ns=driver)) + meta = parsers[driver](riods.tags(ns=driver)) for k, v in meta.items(): # Add values as coordinates if they match the band count, # as attributes otherwise if (isinstance(v, (list, np.ndarray)) and - len(v) == riods.value.count): + len(v) == riods.count): coords[k] = ('band', np.asarray(v)) else: attrs[k] = v - data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(riods)) + data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(manager)) # this lets you write arrays loaded with rasterio data = indexing.CopyOnWriteArray(data) - if cache and (chunks is None): + if cache and chunks is None: data = indexing.MemoryCachedArray(data) result = DataArray(data=data, dims=('band', 'y', 'x'), @@ -306,6 +314,6 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, lock=lock) # Make the file closeable - result._file_obj = riods + result._file_obj = manager return result diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index cd84431f6b7..b009342efb6 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -1,6 +1,5 @@ from __future__ import absolute_import, division, print_function -import functools import warnings from distutils.version import LooseVersion from io import BytesIO @@ -11,7 +10,9 @@ from ..core.indexing import NumpyIndexingAdapter from ..core.pycompat import OrderedDict, basestring, iteritems from ..core.utils import Frozen, FrozenOrderedDict -from .common import BackendArray, DataStorePickleMixin, WritableCFDataStore +from .common import BackendArray, WritableCFDataStore +from .locks import get_write_lock +from .file_manager import CachingFileManager, DummyFileManager from .netcdf3 import ( encode_nc3_attr_value, encode_nc3_variable, 
is_valid_nc3_name) @@ -40,31 +41,26 @@ def __init__(self, variable_name, datastore): str(array.dtype.itemsize)) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name].data def __getitem__(self, key): - with self.datastore.ensure_open(autoclose=True): - data = NumpyIndexingAdapter(self.get_array())[key] - # Copy data if the source file is mmapped. - # This makes things consistent - # with the netCDF4 library by ensuring - # we can safely read arrays even - # after closing associated files. - copy = self.datastore.ds.use_mmap - return np.array(data, dtype=self.dtype, copy=copy) + data = NumpyIndexingAdapter(self.get_array())[key] + # Copy data if the source file is mmapped. This makes things consistent + # with the netCDF4 library by ensuring we can safely read arrays even + # after closing associated files. + copy = self.datastore.ds.use_mmap + return np.array(data, dtype=self.dtype, copy=copy) def __setitem__(self, key, value): - with self.datastore.ensure_open(autoclose=True): - data = self.datastore.ds.variables[self.variable_name] - try: - data[key] = value - except TypeError: - if key is Ellipsis: - # workaround for GH: scipy/scipy#6880 - data[:] = value - else: - raise + data = self.datastore.ds.variables[self.variable_name] + try: + data[key] = value + except TypeError: + if key is Ellipsis: + # workaround for GH: scipy/scipy#6880 + data[:] = value + else: + raise def _open_scipy_netcdf(filename, mode, mmap, version): @@ -106,7 +102,7 @@ def _open_scipy_netcdf(filename, mode, mmap, version): raise -class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin): +class ScipyDataStore(WritableCFDataStore): """Store for reading and writing data via scipy.io.netcdf. This store has the advantage of being able to be initialized with a @@ -116,7 +112,7 @@ class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin): """ def __init__(self, filename_or_obj, mode='r', format=None, group=None, - writer=None, mmap=None, autoclose=False, lock=None): + mmap=None, lock=None): import scipy import scipy.io @@ -140,34 +136,38 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, raise ValueError('invalid format for scipy.io.netcdf backend: %r' % format) - opener = functools.partial(_open_scipy_netcdf, - filename=filename_or_obj, - mode=mode, mmap=mmap, version=version) - self._ds = opener() - self._autoclose = autoclose - self._isopen = True - self._opener = opener - self._mode = mode + if (lock is None and mode != 'r' and + isinstance(filename_or_obj, basestring)): + lock = get_write_lock(filename_or_obj) + + if isinstance(filename_or_obj, basestring): + manager = CachingFileManager( + _open_scipy_netcdf, filename_or_obj, mode=mode, lock=lock, + kwargs=dict(mmap=mmap, version=version)) + else: + scipy_dataset = _open_scipy_netcdf( + filename_or_obj, mode=mode, mmap=mmap, version=version) + manager = DummyFileManager(scipy_dataset) + + self._manager = manager - super(ScipyDataStore, self).__init__(writer, lock=lock) + @property + def ds(self): + return self._manager.acquire() def open_store_variable(self, name, var): - with self.ensure_open(autoclose=False): - return Variable(var.dimensions, ScipyArrayWrapper(name, self), - _decode_attrs(var._attributes)) + return Variable(var.dimensions, ScipyArrayWrapper(name, self), + _decode_attrs(var._attributes)) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in iteritems(self.ds.variables)) + 
return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in iteritems(self.ds.variables)) def get_attrs(self): - with self.ensure_open(autoclose=True): - return Frozen(_decode_attrs(self.ds._attributes)) + return Frozen(_decode_attrs(self.ds._attributes)) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.dimensions) + return Frozen(self.ds.dimensions) def get_encoding(self): encoding = {} @@ -176,22 +176,20 @@ def get_encoding(self): return encoding def set_dimension(self, name, length, is_unlimited=False): - with self.ensure_open(autoclose=False): - if name in self.ds.dimensions: - raise ValueError('%s does not support modifying dimensions' - % type(self).__name__) - dim_length = length if not is_unlimited else None - self.ds.createDimension(name, dim_length) + if name in self.ds.dimensions: + raise ValueError('%s does not support modifying dimensions' + % type(self).__name__) + dim_length = length if not is_unlimited else None + self.ds.createDimension(name, dim_length) def _validate_attr_key(self, key): if not is_valid_nc3_name(key): raise ValueError("Not a valid attribute name") def set_attribute(self, key, value): - with self.ensure_open(autoclose=False): - self._validate_attr_key(key) - value = encode_nc3_attr_value(value) - setattr(self.ds, key, value) + self._validate_attr_key(key) + value = encode_nc3_attr_value(value) + setattr(self.ds, key, value) def encode_variable(self, variable): variable = encode_nc3_variable(variable) @@ -219,27 +217,8 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, data - def sync(self, compute=True): - if not compute: - raise NotImplementedError( - 'compute=False is not supported for the scipy backend yet') - with self.ensure_open(autoclose=True): - super(ScipyDataStore, self).sync(compute=compute) - self.ds.flush() + def sync(self): + self.ds.sync() def close(self): - self.ds.close() - self._isopen = False - - def __exit__(self, type, value, tb): - self.close() - - def __setstate__(self, state): - filename = state['_opener'].keywords['filename'] - if hasattr(filename, 'seek'): - # it's a file-like object - # seek to the start of the file so scipy can read it - filename.seek(0) - super(ScipyDataStore, self).__setstate__(state) - self._ds = None - self._isopen = False + self._manager.close() diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 47b90c8a617..5f19c826289 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -217,8 +217,7 @@ class ZarrStore(AbstractWritableDataStore): """ @classmethod - def open_group(cls, store, mode='r', synchronizer=None, group=None, - writer=None): + def open_group(cls, store, mode='r', synchronizer=None, group=None): import zarr min_zarr = '2.2' @@ -230,24 +229,14 @@ def open_group(cls, store, mode='r', synchronizer=None, group=None, "#installation" % min_zarr) zarr_group = zarr.open_group(store=store, mode=mode, synchronizer=synchronizer, path=group) - return cls(zarr_group, writer=writer) + return cls(zarr_group) - def __init__(self, zarr_group, writer=None): + def __init__(self, zarr_group): self.ds = zarr_group self._read_only = self.ds.read_only self._synchronizer = self.ds.synchronizer self._group = self.ds.path - if writer is None: - # by default, we should not need a lock for writing zarr because - # we do not (yet) allow overlapping chunks during write - zarr_writer = ArrayWriter(lock=False) - else: - zarr_writer = writer - - # do we need to define attributes for all of the opener keyword 
args? - super(ZarrStore, self).__init__(zarr_writer) - def open_store_variable(self, name, zarr_array): data = indexing.LazilyOuterIndexedArray(ZarrArrayWrapper(name, self)) dimensions, attributes = _get_zarr_dims_and_attrs(zarr_array, @@ -334,8 +323,8 @@ def store(self, variables, attributes, *args, **kwargs): AbstractWritableDataStore.store(self, variables, attributes, *args, **kwargs) - def sync(self, compute=True): - self.delayed_store = self.writer.sync(compute=compute) + def sync(self): + pass def open_zarr(store, group=None, synchronizer=None, auto_chunk=True, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4ade15825c6..c8586d1d408 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1161,27 +1161,12 @@ def reset_coords(self, names=None, drop=False, inplace=False): del obj._variables[name] return obj - def dump_to_store(self, store, encoder=None, sync=True, encoding=None, - unlimited_dims=None, compute=True): + def dump_to_store(self, store, **kwargs): """Store dataset contents to a backends.*DataStore object.""" - if encoding is None: - encoding = {} - variables, attrs = conventions.encode_dataset_coordinates(self) - - check_encoding = set() - for k, enc in encoding.items(): - # no need to shallow copy the variable again; that already happened - # in encode_dataset_coordinates - variables[k].encoding = enc - check_encoding.add(k) - - if encoder: - variables, attrs = encoder(variables, attrs) - - store.store(variables, attrs, check_encoding, - unlimited_dims=unlimited_dims) - if sync: - store.sync(compute=compute) + from ..backends.api import dump_to_store + # TODO: rename and/or cleanup this method to make it more consistent + # with to_netcdf() + return dump_to_store(self, store, **kwargs) def to_netcdf(self, path=None, mode='w', format=None, group=None, engine=None, encoding=None, unlimited_dims=None, diff --git a/xarray/core/options.py b/xarray/core/options.py index a6118f02ed3..04ea0be7172 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -1,11 +1,43 @@ from __future__ import absolute_import, division, print_function +DISPLAY_WIDTH = 'display_width' +ARITHMETIC_JOIN = 'arithmetic_join' +ENABLE_CFTIMEINDEX = 'enable_cftimeindex' +FILE_CACHE_MAXSIZE = 'file_cache_maxsize' +CMAP_SEQUENTIAL = 'cmap_sequential' +CMAP_DIVERGENT = 'cmap_divergent' + OPTIONS = { - 'display_width': 80, - 'arithmetic_join': 'inner', - 'enable_cftimeindex': False, - 'cmap_sequential': 'viridis', - 'cmap_divergent': 'RdBu_r', + DISPLAY_WIDTH: 80, + ARITHMETIC_JOIN: 'inner', + ENABLE_CFTIMEINDEX: False, + FILE_CACHE_MAXSIZE: 128, + CMAP_SEQUENTIAL: 'viridis', + CMAP_DIVERGENT: 'RdBu_r', +} + +_JOIN_OPTIONS = frozenset(['inner', 'outer', 'left', 'right', 'exact']) + + +def _positive_integer(value): + return isinstance(value, int) and value > 0 + + +_VALIDATORS = { + DISPLAY_WIDTH: _positive_integer, + ARITHMETIC_JOIN: _JOIN_OPTIONS.__contains__, + ENABLE_CFTIMEINDEX: lambda value: isinstance(value, bool), + FILE_CACHE_MAXSIZE: _positive_integer, +} + + +def _set_file_cache_maxsize(value): + from ..backends.file_manager import FILE_CACHE + FILE_CACHE.maxsize = value + + +_SETTERS = { + FILE_CACHE_MAXSIZE: _set_file_cache_maxsize, } @@ -21,6 +53,10 @@ class set_options(object): - ``enable_cftimeindex``: flag to enable using a ``CFTimeIndex`` for time indexes with non-standard calendars or dates outside the Timestamp-valid range. Default: ``False``. 
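The ``_VALIDATORS``/``_SETTERS`` tables added above are what make the new ``file_cache_maxsize`` option (documented just below) take effect immediately: ``_positive_integer`` rejects invalid values, and ``_set_file_cache_maxsize`` resizes the global ``FILE_CACHE`` in place. A brief usage sketch, assuming only the public ``xarray.set_options`` API shown in this patch; the specific numbers are arbitrary:

```python
import xarray as xr

# Raise the open-file budget for a session that touches many netCDF files.
# The setter hook resizes the global LRU cache of open file objects in place.
xr.set_options(file_cache_maxsize=512)

# Or scope the change: on exiting the block, the previous size is restored
# through the same setter hook.
with xr.set_options(file_cache_maxsize=16):
    pass  # datasets opened here keep at most 16 files open at once

# Invalid values are rejected via the validator table.
try:
    xr.set_options(file_cache_maxsize=0)
except ValueError:
    pass
```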
+ - ``file_cache_maxsize``: maximum number of open files to hold in xarray's + global least-recently-usage cached. This should be smaller than your + system's per-process file descriptor limit, e.g., ``ulimit -n`` on Linux. + Default: 128. - ``cmap_sequential``: colormap to use for nondivergent data plots. Default: ``viridis``. If string, must be matplotlib built-in colormap. Can also be a Colormap object (e.g. mpl.cm.magma) @@ -28,8 +64,7 @@ class set_options(object): Default: ``RdBu_r``. If string, must be matplotlib built-in colormap. Can also be a Colormap object (e.g. mpl.cm.magma) - - You can use ``set_options`` either as a context manager: +f You can use ``set_options`` either as a context manager: >>> ds = xr.Dataset({'x': np.arange(1000)}) >>> with xr.set_options(display_width=40): @@ -47,16 +82,26 @@ class set_options(object): """ def __init__(self, **kwargs): - invalid_options = {k for k in kwargs if k not in OPTIONS} - if invalid_options: - raise ValueError('argument names %r are not in the set of valid ' - 'options %r' % (invalid_options, set(OPTIONS))) self.old = OPTIONS.copy() - OPTIONS.update(kwargs) + for k, v in kwargs.items(): + if k not in OPTIONS: + raise ValueError( + 'argument name %r is not in the set of valid options %r' + % (k, set(OPTIONS))) + if k in _VALIDATORS and not _VALIDATORS[k](v): + raise ValueError( + 'option %r given an invalid value: %r' % (k, v)) + self._apply_update(kwargs) + + def _apply_update(self, options_dict): + for k, v in options_dict.items(): + if k in _SETTERS: + _SETTERS[k](v) + OPTIONS.update(options_dict) def __enter__(self): return def __exit__(self, type, value, traceback): OPTIONS.clear() - OPTIONS.update(self.old) + self._apply_update(self.old) diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index 78c26f1e92f..b980bc279b0 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -28,6 +28,9 @@ def itervalues(d): import builtins from urllib.request import urlretrieve from inspect import getfullargspec as getargspec + + def move_to_end(ordered_dict, key): + ordered_dict.move_to_end(key) else: # pragma: no cover # Python 2 basestring = basestring # noqa @@ -50,6 +53,11 @@ def itervalues(d): from urllib import urlretrieve from inspect import getargspec + def move_to_end(ordered_dict, key): + value = ordered_dict[key] + del ordered_dict[key] + ordered_dict[key] = value + integer_types = native_int_types + (np.integer,) try: @@ -76,7 +84,6 @@ def itervalues(d): except ImportError as e: path_type = () - try: from contextlib import suppress except ImportError: diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 0d97ed70fa3..43811942d5f 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -8,7 +8,6 @@ import shutil import sys import tempfile -import unittest import warnings from io import BytesIO @@ -20,13 +19,13 @@ from xarray import ( DataArray, Dataset, backends, open_dataarray, open_dataset, open_mfdataset, save_mfdataset) -from xarray.backends.common import ( - PickleByReconstructionWrapper, robust_getitem) +from xarray.backends.common import robust_getitem from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding from xarray.backends.pydap_ import PydapDataStore from xarray.core import indexing from xarray.core.pycompat import ( - PY2, ExitStack, basestring, dask_array_type, iteritems) + ExitStack, basestring, dask_array_type, iteritems) +from xarray.core.options import set_options from xarray.tests import mock from . 
import ( @@ -138,7 +137,6 @@ class NetCDF3Only(object): class DatasetIOTestCases(object): - autoclose = False engine = None file_format = None @@ -172,8 +170,7 @@ def save(self, dataset, path, **kwargs): @contextlib.contextmanager def open(self, path, **kwargs): - with open_dataset(path, engine=self.engine, autoclose=self.autoclose, - **kwargs) as ds: + with open_dataset(path, engine=self.engine, **kwargs) as ds: yield ds def test_zero_dimensional_variable(self): @@ -1159,10 +1156,10 @@ def test_already_open_dataset(self): v[...] = 42 nc = nc4.Dataset(tmp_file, mode='r') - with backends.NetCDF4DataStore(nc, autoclose=False) as store: - with open_dataset(store) as ds: - expected = Dataset({'x': ((), 42)}) - assert_identical(expected, ds) + store = backends.NetCDF4DataStore(nc) + with open_dataset(store) as ds: + expected = Dataset({'x': ((), 42)}) + assert_identical(expected, ds) def test_read_variable_len_strings(self): with create_tmp_file() as tmp_file: @@ -1181,7 +1178,6 @@ def test_read_variable_len_strings(self): @requires_netCDF4 class NetCDF4DataTest(BaseNetCDF4Test): - autoclose = False @contextlib.contextmanager def create_store(self): @@ -1247,9 +1243,13 @@ def test_setncattr_string(self): totest.attrs['bar']) assert one_string == totest.attrs['baz'] - -class NetCDF4DataStoreAutocloseTrue(NetCDF4DataTest): - autoclose = True + def test_autoclose_future_warning(self): + data = create_test_data() + with create_tmp_file() as tmp_file: + self.save(data, tmp_file) + with pytest.warns(FutureWarning): + with self.open(tmp_file, autoclose=True) as actual: + assert_identical(data, actual) @requires_netCDF4 @@ -1290,10 +1290,6 @@ def test_write_inconsistent_chunks(self): assert actual['y'].encoding['chunksizes'] == (100, 50) -class NetCDF4ViaDaskDataTestAutocloseTrue(NetCDF4ViaDaskDataTest): - autoclose = True - - @requires_zarr class BaseZarrTest(CFEncodedDataTest): @@ -1571,19 +1567,14 @@ def test_to_netcdf_explicit_engine(self): # regression test for GH1321 Dataset({'foo': 42}).to_netcdf(engine='scipy') - @pytest.mark.skipif(PY2, reason='cannot pickle BytesIO on Python 2') - def test_bytesio_pickle(self): + def test_bytes_pickle(self): data = Dataset({'foo': ('x', [1, 2, 3])}) - fobj = BytesIO(data.to_netcdf()) - with open_dataset(fobj, autoclose=self.autoclose) as ds: + fobj = data.to_netcdf() + with self.open(fobj) as ds: unpickled = pickle.loads(pickle.dumps(ds)) assert_identical(unpickled, data) -class ScipyInMemoryDataTestAutocloseTrue(ScipyInMemoryDataTest): - autoclose = True - - @requires_scipy class ScipyFileObjectTest(ScipyWriteTest): engine = 'scipy' @@ -1649,10 +1640,6 @@ def test_nc4_scipy(self): open_dataset(tmp_file, engine='scipy') -class ScipyFilePathTestAutocloseTrue(ScipyFilePathTest): - autoclose = True - - @requires_netCDF4 class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, NetCDF3Only): engine = 'netcdf4' @@ -1673,10 +1660,6 @@ def test_encoding_kwarg_vlen_string(self): pass -class NetCDF3ViaNetCDF4DataTestAutocloseTrue(NetCDF3ViaNetCDF4DataTest): - autoclose = True - - @requires_netCDF4 class NetCDF4ClassicViaNetCDF4DataTest(CFEncodedDataTest, NetCDF3Only, object): @@ -1691,11 +1674,6 @@ def create_store(self): yield store -class NetCDF4ClassicViaNetCDF4DataTestAutocloseTrue( - NetCDF4ClassicViaNetCDF4DataTest): - autoclose = True - - @requires_scipy_or_netCDF4 class GenericNetCDFDataTest(CFEncodedDataTest, NetCDF3Only): # verify that we can read and write netCDF3 files as long as we have scipy @@ -1772,10 +1750,6 @@ def test_encoding_unlimited_dims(self): 
assert_equal(ds, actual) -class GenericNetCDFDataTestAutocloseTrue(GenericNetCDFDataTest): - autoclose = True - - @requires_h5netcdf @requires_netCDF4 class H5NetCDFDataTest(BaseNetCDF4Test): @@ -1789,8 +1763,11 @@ def create_store(self): @pytest.mark.filterwarnings('ignore:complex dtypes are supported by h5py') def test_complex(self): expected = Dataset({'x': ('y', np.ones(5) + 1j * np.ones(5))}) - with self.roundtrip(expected) as actual: - assert_equal(expected, actual) + with pytest.warns(FutureWarning): + # TODO: make it possible to write invalid netCDF files from xarray + # without a warning + with self.roundtrip(expected) as actual: + assert_equal(expected, actual) @pytest.mark.xfail(reason='https://github.com/pydata/xarray/issues/535') def test_cross_engine_read_write_netcdf4(self): @@ -1905,25 +1882,24 @@ def test_dump_encodings_h5py(self): assert actual.x.encoding['compression_opts'] is None -# tests pending h5netcdf fix -@unittest.skip -class H5NetCDFDataTestAutocloseTrue(H5NetCDFDataTest): - autoclose = True - - @pytest.fixture(params=['scipy', 'netcdf4', 'h5netcdf', 'pynio']) def readengine(request): return request.param -@pytest.fixture(params=[1, 100]) +@pytest.fixture(params=[1, 20]) def nfiles(request): return request.param -@pytest.fixture(params=[True, False]) -def autoclose(request): - return request.param +@pytest.fixture(params=[5, None]) +def file_cache_maxsize(request): + maxsize = request.param + if maxsize is not None: + with set_options(file_cache_maxsize=maxsize): + yield maxsize + else: + yield maxsize @pytest.fixture(params=[True, False]) @@ -1946,8 +1922,8 @@ def skip_if_not_engine(engine): pytest.importorskip(engine) -def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel, - chunks): +def test_open_mfdataset_manyfiles(readengine, nfiles, parallel, chunks, + file_cache_maxsize): # skip certain combinations skip_if_not_engine(readengine) @@ -1955,9 +1931,6 @@ def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel, if not has_dask and parallel: pytest.skip('parallel requires dask') - if readengine == 'h5netcdf' and autoclose: - pytest.skip('h5netcdf does not support autoclose yet') - if ON_WINDOWS: pytest.skip('Skipping on Windows') @@ -1973,7 +1946,7 @@ def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel, # check that calculation on opened datasets works properly actual = open_mfdataset(tmpfiles, engine=readengine, parallel=parallel, - autoclose=autoclose, chunks=chunks) + chunks=chunks) # check that using open_mfdataset returns dask arrays for variables assert isinstance(actual['foo'].data, dask_array_type) @@ -2172,22 +2145,20 @@ def test_open_mfdataset(self): with create_tmp_file() as tmp2: original.isel(x=slice(5)).to_netcdf(tmp1) original.isel(x=slice(5, 10)).to_netcdf(tmp2) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert isinstance(actual.foo.variable.data, da.Array) assert actual.foo.variable.data.chunks == \ ((5, 5),) assert_identical(original, actual) - with open_mfdataset([tmp1, tmp2], chunks={'x': 3}, - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2], chunks={'x': 3}) as actual: assert actual.foo.variable.data.chunks == \ ((3, 2, 3, 2),) with raises_regex(IOError, 'no files to open'): - open_mfdataset('foo-bar-baz-*.nc', autoclose=self.autoclose) + open_mfdataset('foo-bar-baz-*.nc') with raises_regex(ValueError, 'wild-card'): - open_mfdataset('http://some/remote/uri', 
autoclose=self.autoclose) + open_mfdataset('http://some/remote/uri') @requires_pathlib def test_open_mfdataset_pathlib(self): @@ -2198,8 +2169,7 @@ def test_open_mfdataset_pathlib(self): tmp2 = Path(tmp2) original.isel(x=slice(5)).to_netcdf(tmp1) original.isel(x=slice(5, 10)).to_netcdf(tmp2) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(original, actual) def test_attrs_mfdataset(self): @@ -2230,8 +2200,7 @@ def preprocess(ds): return ds.assign_coords(z=0) expected = preprocess(original) - with open_mfdataset(tmp, preprocess=preprocess, - autoclose=self.autoclose) as actual: + with open_mfdataset(tmp, preprocess=preprocess) as actual: assert_identical(expected, actual) def test_save_mfdataset_roundtrip(self): @@ -2241,8 +2210,7 @@ def test_save_mfdataset_roundtrip(self): with create_tmp_file() as tmp1: with create_tmp_file() as tmp2: save_mfdataset(datasets, [tmp1, tmp2]) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(actual, original) def test_save_mfdataset_invalid(self): @@ -2268,15 +2236,14 @@ def test_save_mfdataset_pathlib_roundtrip(self): tmp1 = Path(tmp1) tmp2 = Path(tmp2) save_mfdataset(datasets, [tmp1, tmp2]) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(actual, original) def test_open_and_do_math(self): original = Dataset({'foo': ('x', np.random.randn(10))}) with create_tmp_file() as tmp: original.to_netcdf(tmp) - with open_mfdataset(tmp, autoclose=self.autoclose) as ds: + with open_mfdataset(tmp) as ds: actual = 1.0 * ds assert_allclose(original, actual, decode_bytes=False) @@ -2286,8 +2253,7 @@ def test_open_mfdataset_concat_dim_none(self): data = Dataset({'x': 0}) data.to_netcdf(tmp1) Dataset({'x': np.nan}).to_netcdf(tmp2) - with open_mfdataset([tmp1, tmp2], concat_dim=None, - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2], concat_dim=None) as actual: assert_identical(data, actual) def test_open_dataset(self): @@ -2314,8 +2280,7 @@ def test_open_single_dataset(self): {'baz': [100]}) with create_tmp_file() as tmp: original.to_netcdf(tmp) - with open_mfdataset([tmp], concat_dim=dim, - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp], concat_dim=dim) as actual: assert_identical(expected, actual) def test_dask_roundtrip(self): @@ -2334,10 +2299,10 @@ def test_deterministic_names(self): with create_tmp_file() as tmp: data = create_test_data() data.to_netcdf(tmp) - with open_mfdataset(tmp, autoclose=self.autoclose) as ds: + with open_mfdataset(tmp) as ds: original_names = dict((k, v.data.name) for k, v in ds.data_vars.items()) - with open_mfdataset(tmp, autoclose=self.autoclose) as ds: + with open_mfdataset(tmp) as ds: repeat_names = dict((k, v.data.name) for k, v in ds.data_vars.items()) for var_name, dask_name in original_names.items(): @@ -2355,41 +2320,22 @@ def test_dataarray_compute(self): assert computed._in_memory assert_allclose(actual, computed, decode_bytes=False) - def test_to_netcdf_compute_false_roundtrip(self): - from dask.delayed import Delayed - - original = create_test_data().chunk() - - with create_tmp_file() as tmp_file: - # dataset, path, **kwargs): - delayed_obj = self.save(original, tmp_file, compute=False) - assert isinstance(delayed_obj, Delayed) - delayed_obj.compute() - - with self.open(tmp_file) as actual: - assert_identical(original, actual) 
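These test updates show the user-facing shape of the refactor: ``open_dataset`` and ``open_mfdataset`` are now called without ``autoclose``, and the number of concurrently open files is governed by the global file cache instead. A minimal sketch of the replacement pattern; the file paths and the cache size below are hypothetical, and chunked reads require dask:

```python
import xarray as xr

# Hypothetical paths; any set of netCDF files split along 'x' would do.
files = ['part-0001.nc', 'part-0002.nc']

# Previously this would have been open_mfdataset(files, autoclose=True).
# Now the open-file budget is a global option and the backends close and
# reopen files transparently through the file manager.
with xr.set_options(file_cache_maxsize=64):
    with xr.open_mfdataset(files, chunks={'x': 3}) as ds:
        result = (1.0 * ds).compute()
```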
- def test_save_mfdataset_compute_false_roundtrip(self): from dask.delayed import Delayed original = Dataset({'foo': ('x', np.random.randn(10))}).chunk() datasets = [original.isel(x=slice(5)), original.isel(x=slice(5, 10))] - with create_tmp_file() as tmp1: - with create_tmp_file() as tmp2: + with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp1: + with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp2: delayed_obj = save_mfdataset(datasets, [tmp1, tmp2], engine=self.engine, compute=False) assert isinstance(delayed_obj, Delayed) delayed_obj.compute() - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(actual, original) -class DaskTestAutocloseTrue(DaskTest): - autoclose = True - - @requires_scipy_or_netCDF4 @requires_pydap class PydapTest(object): @@ -2500,8 +2446,7 @@ def test_write_store(self): @contextlib.contextmanager def open(self, path, **kwargs): - with open_dataset(path, engine='pynio', autoclose=self.autoclose, - **kwargs) as ds: + with open_dataset(path, engine='pynio', **kwargs) as ds: yield ds def save(self, dataset, path, **kwargs): @@ -2519,19 +2464,12 @@ def test_weakrefs(self): assert_identical(actual, expected) -class PyNioTestAutocloseTrue(PyNioTest): - autoclose = True - - @requires_pseudonetcdf @pytest.mark.filterwarnings('ignore:IOAPI_ISPH is assumed to be 6370000') class PseudoNetCDFFormatTest(object): - autoclose = True def open(self, path, **kwargs): - return open_dataset(path, engine='pseudonetcdf', - autoclose=self.autoclose, - **kwargs) + return open_dataset(path, engine='pseudonetcdf', **kwargs) @contextlib.contextmanager def roundtrip(self, data, save_kwargs={}, open_kwargs={}, @@ -2548,7 +2486,6 @@ def test_ict_format(self): """ ictfile = open_example_dataset('example.ict', engine='pseudonetcdf', - autoclose=False, backend_kwargs={'format': 'ffi1001'}) stdattr = { 'fill_value': -9999.0, @@ -2646,7 +2583,6 @@ def test_ict_format_write(self): fmtkw = {'format': 'ffi1001'} expected = open_example_dataset('example.ict', engine='pseudonetcdf', - autoclose=False, backend_kwargs=fmtkw) with self.roundtrip(expected, save_kwargs=fmtkw, open_kwargs={'backend_kwargs': fmtkw}) as actual: @@ -2659,7 +2595,6 @@ def test_uamiv_format_read(self): camxfile = open_example_dataset('example.uamiv', engine='pseudonetcdf', - autoclose=True, backend_kwargs={'format': 'uamiv'}) data = np.arange(20, dtype='f').reshape(1, 1, 4, 5) expected = xr.Variable(('TSTEP', 'LAY', 'ROW', 'COL'), data, @@ -2687,7 +2622,6 @@ def test_uamiv_format_mfread(self): ['example.uamiv', 'example.uamiv'], engine='pseudonetcdf', - autoclose=True, concat_dim='TSTEP', backend_kwargs={'format': 'uamiv'}) @@ -2701,11 +2635,11 @@ def test_uamiv_format_mfread(self): data1 = np.array(['2002-06-03'], 'datetime64[ns]') data = np.concatenate([data1] * 2, axis=0) - expected = xr.Variable(('TSTEP',), data, - dict(bounds='time_bounds', - long_name=('synthesized time coordinate ' + - 'from SDATE, STIME, STEP ' + - 'global attributes'))) + attrs = dict(bounds='time_bounds', + long_name=('synthesized time coordinate ' + + 'from SDATE, STIME, STEP ' + + 'global attributes')) + expected = xr.Variable(('TSTEP',), data, attrs) actual = camxfile.variables['time'] assert_allclose(expected, actual) camxfile.close() @@ -2715,7 +2649,6 @@ def test_uamiv_format_write(self): expected = open_example_dataset('example.uamiv', engine='pseudonetcdf', - autoclose=False, backend_kwargs=fmtkw) with self.roundtrip(expected, 
save_kwargs=fmtkw, @@ -3312,32 +3245,6 @@ def test_dataarray_to_netcdf_no_name_pathlib(self): assert_identical(original_da, loaded_da) -def test_pickle_reconstructor(): - - lines = ['foo bar spam eggs'] - - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp: - with open(tmp, 'w') as f: - f.writelines(lines) - - obj = PickleByReconstructionWrapper(open, tmp) - - assert obj.value.readlines() == lines - - p_obj = pickle.dumps(obj) - obj.value.close() # for windows - obj2 = pickle.loads(p_obj) - - assert obj2.value.readlines() == lines - - # roundtrip again to make sure we can fully restore the state - p_obj2 = pickle.dumps(obj2) - obj2.value.close() # for windows - obj3 = pickle.loads(p_obj2) - - assert obj3.value.readlines() == lines - - @requires_scipy_or_netCDF4 def test_no_warning_from_dask_effective_get(): with create_tmp_file() as tmpfile: diff --git a/xarray/tests/test_backends_file_manager.py b/xarray/tests/test_backends_file_manager.py new file mode 100644 index 00000000000..591c981cd45 --- /dev/null +++ b/xarray/tests/test_backends_file_manager.py @@ -0,0 +1,114 @@ +import pickle +import threading +try: + from unittest import mock +except ImportError: + import mock # noqa: F401 + +import pytest + +from xarray.backends.file_manager import CachingFileManager +from xarray.backends.lru_cache import LRUCache + + +@pytest.fixture(params=[1, 2, 3, None]) +def file_cache(request): + maxsize = request.param + if maxsize is None: + yield {} + else: + yield LRUCache(maxsize) + + +def test_file_manager_mock_write(file_cache): + mock_file = mock.Mock() + opener = mock.Mock(spec=open, return_value=mock_file) + lock = mock.MagicMock(spec=threading.Lock()) + + manager = CachingFileManager( + opener, 'filename', lock=lock, cache=file_cache) + f = manager.acquire() + f.write('contents') + manager.close() + + assert not file_cache + opener.assert_called_once_with('filename') + mock_file.write.assert_called_once_with('contents') + mock_file.close.assert_called_once_with() + lock.__enter__.assert_has_calls([mock.call(), mock.call()]) + + +def test_file_manager_write_consecutive(tmpdir, file_cache): + path1 = str(tmpdir.join('testing1.txt')) + path2 = str(tmpdir.join('testing2.txt')) + manager1 = CachingFileManager(open, path1, mode='w', cache=file_cache) + manager2 = CachingFileManager(open, path2, mode='w', cache=file_cache) + f1a = manager1.acquire() + f1a.write('foo') + f1a.flush() + f2 = manager2.acquire() + f2.write('bar') + f2.flush() + f1b = manager1.acquire() + f1b.write('baz') + assert (getattr(file_cache, 'maxsize', float('inf')) > 1) == (f1a is f1b) + manager1.close() + manager2.close() + + with open(path1, 'r') as f: + assert f.read() == 'foobaz' + with open(path2, 'r') as f: + assert f.read() == 'bar' + + +def test_file_manager_write_concurrent(tmpdir, file_cache): + path = str(tmpdir.join('testing.txt')) + manager = CachingFileManager(open, path, mode='w', cache=file_cache) + f1 = manager.acquire() + f2 = manager.acquire() + f3 = manager.acquire() + assert f1 is f2 + assert f2 is f3 + f1.write('foo') + f1.flush() + f2.write('bar') + f2.flush() + f3.write('baz') + f3.flush() + manager.close() + + with open(path, 'r') as f: + assert f.read() == 'foobarbaz' + + +def test_file_manager_write_pickle(tmpdir, file_cache): + path = str(tmpdir.join('testing.txt')) + manager = CachingFileManager(open, path, mode='w', cache=file_cache) + f = manager.acquire() + f.write('foo') + f.flush() + manager2 = pickle.loads(pickle.dumps(manager)) + f2 = manager2.acquire() + f2.write('bar') + 
manager2.close() + manager.close() + + with open(path, 'r') as f: + assert f.read() == 'foobar' + + +def test_file_manager_read(tmpdir, file_cache): + path = str(tmpdir.join('testing.txt')) + + with open(path, 'w') as f: + f.write('foobar') + + manager = CachingFileManager(open, path, cache=file_cache) + f = manager.acquire() + assert f.read() == 'foobar' + manager.close() + + +def test_file_manager_invalid_kwargs(): + with pytest.raises(TypeError): + CachingFileManager(open, 'dummy', mode='w', invalid=True) diff --git a/xarray/tests/test_backends_locks.py b/xarray/tests/test_backends_locks.py new file mode 100644 index 00000000000..5f83321802e --- /dev/null +++ b/xarray/tests/test_backends_locks.py @@ -0,0 +1,13 @@ +import threading + +from xarray.backends import locks + + +def test_threaded_lock(): + lock1 = locks._get_threaded_lock('foo') + assert isinstance(lock1, type(threading.Lock())) + lock2 = locks._get_threaded_lock('foo') + assert lock1 is lock2 + + lock3 = locks._get_threaded_lock('bar') + assert lock1 is not lock3 diff --git a/xarray/tests/test_backends_lru_cache.py b/xarray/tests/test_backends_lru_cache.py new file mode 100644 index 00000000000..03eb6dcf208 --- /dev/null +++ b/xarray/tests/test_backends_lru_cache.py @@ -0,0 +1,91 @@ +try: + from unittest import mock +except ImportError: + import mock # noqa: F401 + +import pytest + +from xarray.backends.lru_cache import LRUCache + + +def test_simple(): + cache = LRUCache(maxsize=2) + cache['x'] = 1 + cache['y'] = 2 + + assert cache['x'] == 1 + assert cache['y'] == 2 + assert len(cache) == 2 + assert dict(cache) == {'x': 1, 'y': 2} + assert list(cache.keys()) == ['x', 'y'] + assert list(cache.items()) == [('x', 1), ('y', 2)] + + cache['z'] = 3 + assert len(cache) == 2 + assert list(cache.items()) == [('y', 2), ('z', 3)] + + +def test_trivial(): + cache = LRUCache(maxsize=0) + cache['x'] = 1 + assert len(cache) == 0 + + +def test_invalid(): + with pytest.raises(TypeError): + LRUCache(maxsize=None) + with pytest.raises(ValueError): + LRUCache(maxsize=-1) + + +def test_update_priority(): + cache = LRUCache(maxsize=2) + cache['x'] = 1 + cache['y'] = 2 + assert list(cache) == ['x', 'y'] + assert 'x' in cache # contains + assert list(cache) == ['y', 'x'] + assert cache['y'] == 2 # getitem + assert list(cache) == ['x', 'y'] + cache['x'] = 3 # setitem + assert list(cache.items()) == [('y', 2), ('x', 3)] + + +def test_del(): + cache = LRUCache(maxsize=2) + cache['x'] = 1 + cache['y'] = 2 + del cache['x'] + assert dict(cache) == {'y': 2} + + +def test_on_evict(): + on_evict = mock.Mock() + cache = LRUCache(maxsize=1, on_evict=on_evict) + cache['x'] = 1 + cache['y'] = 2 + on_evict.assert_called_once_with('x', 1) + + +def test_on_evict_trivial(): + on_evict = mock.Mock() + cache = LRUCache(maxsize=0, on_evict=on_evict) + cache['x'] = 1 + on_evict.assert_called_once_with('x', 1) + + +def test_resize(): + cache = LRUCache(maxsize=2) + assert cache.maxsize == 2 + cache['w'] = 0 + cache['x'] = 1 + cache['y'] = 2 + assert list(cache.items()) == [('x', 1), ('y', 2)] + cache.maxsize = 10 + cache['z'] = 3 + assert list(cache.items()) == [('x', 1), ('y', 2), ('z', 3)] + cache.maxsize = 1 + assert list(cache.items()) == [('z', 3)] + + with pytest.raises(ValueError): + cache.maxsize = -1 diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 9bee965392b..89704653e92 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -64,8 +64,8 @@ def create_test_multiindex(): class 
InaccessibleVariableDataStore(backends.InMemoryDataStore): - def __init__(self, writer=None): - super(InaccessibleVariableDataStore, self).__init__(writer) + def __init__(self): + super(InaccessibleVariableDataStore, self).__init__() self._indexvars = set() def store(self, variables, *args, **kwargs): diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 32035afdc57..7c77a62d3c9 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -15,12 +15,13 @@ from distributed.utils_test import cluster, gen_cluster from distributed.utils_test import loop # flake8: noqa from distributed.client import futures_of +import numpy as np import xarray as xr +from xarray.backends.locks import HDF5_LOCK, CombinedLock from xarray.tests.test_backends import (ON_WINDOWS, create_tmp_file, create_tmp_geotiff) from xarray.tests.test_dataset import create_test_data -from xarray.backends.common import HDF5_LOCK, CombinedLock from . import ( assert_allclose, has_h5netcdf, has_netCDF4, requires_rasterio, has_scipy, @@ -33,6 +34,11 @@ da = pytest.importorskip('dask.array') +@pytest.fixture +def tmp_netcdf_filename(tmpdir): + return str(tmpdir.join('testfile.nc')) + + ENGINES = [] if has_scipy: ENGINES.append('scipy') @@ -45,81 +51,69 @@ 'NETCDF3_64BIT_DATA', 'NETCDF4_CLASSIC', 'NETCDF4'], 'scipy': ['NETCDF3_CLASSIC', 'NETCDF3_64BIT'], 'h5netcdf': ['NETCDF4']} -TEST_FORMATS = ['NETCDF3_CLASSIC', 'NETCDF4_CLASSIC', 'NETCDF4'] +ENGINES_AND_FORMATS = [ + ('netcdf4', 'NETCDF3_CLASSIC'), + ('netcdf4', 'NETCDF4_CLASSIC'), + ('netcdf4', 'NETCDF4'), + ('h5netcdf', 'NETCDF4'), + ('scipy', 'NETCDF3_64BIT'), +] -@pytest.mark.xfail(sys.platform == 'win32', - reason='https://github.com/pydata/xarray/issues/1738') -@pytest.mark.parametrize('engine', ['netcdf4']) -@pytest.mark.parametrize('autoclose', [True, False]) -@pytest.mark.parametrize('nc_format', TEST_FORMATS) -def test_dask_distributed_netcdf_roundtrip(monkeypatch, loop, - engine, autoclose, nc_format): - monkeypatch.setenv('HDF5_USE_FILE_LOCKING', 'FALSE') - - chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} - - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: - with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: +@pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS) +def test_dask_distributed_netcdf_roundtrip( + loop, tmp_netcdf_filename, engine, nc_format): - original = create_test_data().chunk(chunks) - original.to_netcdf(filename, engine=engine, format=nc_format) - - with xr.open_dataset(filename, - chunks=chunks, - engine=engine, - autoclose=autoclose) as restored: - assert isinstance(restored.var1.data, da.Array) - computed = restored.compute() - assert_allclose(original, computed) + if engine not in ENGINES: + pytest.skip('engine not available') + chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} -@pytest.mark.xfail(sys.platform == 'win32', - reason='https://github.com/pydata/xarray/issues/1738') -@pytest.mark.parametrize('engine', ENGINES) -@pytest.mark.parametrize('autoclose', [True, False]) -@pytest.mark.parametrize('nc_format', TEST_FORMATS) -def test_dask_distributed_read_netcdf_integration_test(loop, engine, autoclose, - nc_format): + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: - if engine == 'h5netcdf' and autoclose: - pytest.skip('h5netcdf does not support autoclose') + original = create_test_data().chunk(chunks) - if nc_format not in NC_FORMATS[engine]: - pytest.skip('invalid format for engine') + if engine 
== 'scipy': + with pytest.raises(NotImplementedError): + original.to_netcdf(tmp_netcdf_filename, + engine=engine, format=nc_format) + return - chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} + original.to_netcdf(tmp_netcdf_filename, + engine=engine, format=nc_format) - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: - with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + with xr.open_dataset(tmp_netcdf_filename, + chunks=chunks, engine=engine) as restored: + assert isinstance(restored.var1.data, da.Array) + computed = restored.compute() + assert_allclose(original, computed) - original = create_test_data() - original.to_netcdf(filename, engine=engine, format=nc_format) - with xr.open_dataset(filename, - chunks=chunks, - engine=engine, - autoclose=autoclose) as restored: - assert isinstance(restored.var1.data, da.Array) - computed = restored.compute() - assert_allclose(original, computed) +@pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS) +def test_dask_distributed_read_netcdf_integration_test( + loop, tmp_netcdf_filename, engine, nc_format): + if engine not in ENGINES: + pytest.skip('engine not available') -@pytest.mark.parametrize('engine', ['h5netcdf', 'scipy']) -def test_dask_distributed_netcdf_integration_test_not_implemented(loop, engine): chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: - with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + + original = create_test_data() + original.to_netcdf(tmp_netcdf_filename, + engine=engine, format=nc_format) - original = create_test_data().chunk(chunks) + with xr.open_dataset(tmp_netcdf_filename, + chunks=chunks, + engine=engine) as restored: + assert isinstance(restored.var1.data, da.Array) + computed = restored.compute() + assert_allclose(original, computed) - with raises_regex(NotImplementedError, 'distributed'): - original.to_netcdf(filename, engine=engine) @requires_zarr diff --git a/xarray/tests/test_options.py b/xarray/tests/test_options.py index aed96f1acb6..4441375a1b1 100644 --- a/xarray/tests/test_options.py +++ b/xarray/tests/test_options.py @@ -4,6 +4,7 @@ import xarray from xarray.core.options import OPTIONS +from xarray.backends.file_manager import FILE_CACHE def test_invalid_option_raises(): @@ -11,6 +12,38 @@ def test_invalid_option_raises(): xarray.set_options(not_a_valid_options=True) +def test_display_width(): + with pytest.raises(ValueError): + xarray.set_options(display_width=0) + with pytest.raises(ValueError): + xarray.set_options(display_width=-10) + with pytest.raises(ValueError): + xarray.set_options(display_width=3.5) + + +def test_arithmetic_join(): + with pytest.raises(ValueError): + xarray.set_options(arithmetic_join='invalid') + with xarray.set_options(arithmetic_join='exact'): + assert OPTIONS['arithmetic_join'] == 'exact' + + +def test_enable_cftimeindex(): + with pytest.raises(ValueError): + xarray.set_options(enable_cftimeindex=None) + with xarray.set_options(enable_cftimeindex=True): + assert OPTIONS['enable_cftimeindex'] + + +def test_file_cache_maxsize(): + with pytest.raises(ValueError): + xarray.set_options(file_cache_maxsize=0) + original_size = FILE_CACHE.maxsize + with xarray.set_options(file_cache_maxsize=123): + assert FILE_CACHE.maxsize == 123 + assert FILE_CACHE.maxsize == original_size + + def test_nested_options(): original = OPTIONS['display_width'] with 
xarray.set_options(display_width=1):
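
For reference, a minimal usage sketch of the option exercised by test_file_cache_maxsize above, assuming only the xarray.set_options context-manager behavior shown in that test (the file path is hypothetical):

    import xarray

    # Temporarily raise the open-file limit for the global LRU file cache,
    # e.g. before opening many files; the previous value is restored when
    # the context exits, mirroring the assertions in test_file_cache_maxsize.
    with xarray.set_options(file_cache_maxsize=256):
        ds = xarray.open_dataset('example.nc')  # hypothetical file path
        print(ds)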