From beef5afd83aa611fe4ec4a215ca388eb254046ee Mon Sep 17 00:00:00 2001 From: "Phillip J. Wolfram" Date: Thu, 5 Jan 2017 16:06:37 -0700 Subject: [PATCH] Fixes open_mfdataset too many open files error Adds tests that demonstrate the OSError raised when open_mfdataset opens too many files, and fixes the error for the following backends: * netCDF4 backend * scipy backend * pynio backend Open/close operations on h5netcdf appear to trigger a bug in the h5netcdf library itself (based on correspondence with @shoyer), so autoclose support for the h5netcdf backend is currently disabled. Note that the default remains `autoclose=False` for open_mfdataset, so standard behavior is unchanged unless `autoclose=True` is passed explicitly. This default favors standard xarray performance over unconditionally avoiding the OSError from opening too many files. An illustrative usage sketch appears after the diff below. --- doc/whats-new.rst | 6 ++ xarray/backends/api.py | 66 ++++++++------ xarray/backends/common.py | 37 +++++++- xarray/backends/h5netcdf_.py | 74 ++++++++++----- xarray/backends/netCDF4_.py | 165 ++++++++++++++++++--------------- xarray/backends/pynio_.py | 43 +++++---- xarray/backends/scipy_.py | 105 +++++++++++++-------- xarray/core/pycompat.py | 136 ++++++++++++++++++++++++++++ xarray/core/utils.py | 10 +- xarray/tests/__init__.py | 4 + xarray/tests/test_backends.py | 166 +++++++++++++++++++++++++++++----- 11 files changed, 607 insertions(+), 205 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4ddaf01286c..e59ff1d36fa 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,6 +22,12 @@ v0.9.2 (unreleased) Enhancements ~~~~~~~~~~~~ +- It is now possible to pass the ``autoclose=True`` argument to + :py:func:`~xarray.open_mfdataset` to explicitly close files when they are + not in use, preventing an ``OSError`` caused by too many open files. + Note that the default is ``autoclose=False``, so previous xarray behavior + is unchanged. By `Phillip J. Wolfram `_. + Bug fixes ~~~~~~~~~ diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 770aed952f0..e67db205331 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1,7 +1,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import gzip import os.path from distutils.version import StrictVersion from glob import glob @@ -133,7 +132,7 @@ def _protect_dataset_variables_inplace(dataset, cache): def open_dataset(filename_or_obj, group=None, decode_cf=True, - mask_and_scale=True, decode_times=True, + mask_and_scale=True, decode_times=True, autoclose=False, concat_characters=True, decode_coords=True, engine=None, chunks=None, lock=None, cache=None, drop_variables=None): """Load and decode a dataset from a file or file-like object. @@ -163,6 +162,10 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, decode_times : bool, optional If True, decode times encoded in the standard NetCDF datetime format into datetime objects. Otherwise, leave them encoded as numbers. + autoclose : bool, optional + If True, automatically close files to avoid OS Error of too many files + being open. However, this option doesn't work with streams, e.g., + BytesIO. concat_characters : bool, optional If True, concatenate along the last dimension of character arrays to form string arrays.
Dimensions will only be concatenated over (and @@ -251,6 +254,12 @@ def maybe_decode_store(store, lock=False): else: ds2 = ds + # protect so that dataset store isn't necessarily closed, e.g., + # streams like BytesIO can't be reopened + # datastore backend is responsible for determining this capability + if store._autoclose: + store.close() + return ds2 if isinstance(filename_or_obj, backends.AbstractDataStore): @@ -271,33 +280,30 @@ def maybe_decode_store(store, lock=False): if engine is not None and engine != 'scipy': raise ValueError('can only read gzipped netCDF files with ' "default engine or engine='scipy'") - # if the string ends with .gz, then gunzip and open as netcdf file - try: - store = backends.ScipyDataStore(gzip.open(filename_or_obj)) - except TypeError as e: - # TODO: gzipped loading only works with NetCDF3 files. - if 'is not a valid NetCDF 3 file' in e.message: - raise ValueError('gzipped file loading only supports ' - 'NetCDF 3 files.') - else: - raise - else: - if engine is None: - engine = _get_default_engine(filename_or_obj, - allow_remote=True) - if engine == 'netcdf4': - store = backends.NetCDF4DataStore(filename_or_obj, group=group) - elif engine == 'scipy': - store = backends.ScipyDataStore(filename_or_obj) - elif engine == 'pydap': - store = backends.PydapDataStore(filename_or_obj) - elif engine == 'h5netcdf': - store = backends.H5NetCDFStore(filename_or_obj, group=group) - elif engine == 'pynio': - store = backends.NioDataStore(filename_or_obj) else: - raise ValueError('unrecognized engine for open_dataset: %r' - % engine) + engine = 'scipy' + + if engine is None: + engine = _get_default_engine(filename_or_obj, + allow_remote=True) + if engine == 'netcdf4': + store = backends.NetCDF4DataStore(filename_or_obj, group=group, + autoclose=autoclose) + elif engine == 'scipy': + store = backends.ScipyDataStore(filename_or_obj, + autoclose=autoclose) + elif engine == 'pydap': + store = backends.PydapDataStore(filename_or_obj) + elif engine == 'h5netcdf': + store = backends.H5NetCDFStore(filename_or_obj, group=group, + autoclose=autoclose) + elif engine == 'pynio': + store = backends.NioDataStore(filename_or_obj, + autoclose=autoclose) + else: + raise ValueError('unrecognized engine for open_dataset: %r' + % engine) + if lock is None: lock = _default_lock(filename_or_obj, engine) with close_on_error(store): @@ -479,6 +485,10 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for 'netcdf4'. + autoclose : bool, optional + If True, automatically close files to avoid OS Error of too many files + being open. However, this option doesn't work with streams, e.g., + BytesIO. lock : False, True or threading.Lock, optional This argument is passed on to :py:func:`dask.array.from_array`. By default, a per-variable lock is used when reading data from netCDF diff --git a/xarray/backends/common.py b/xarray/backends/common.py index e7cbd0bd9ae..86c8f783d90 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -5,6 +5,7 @@ import logging import time import traceback +import contextlib from collections import Mapping from distutils.version import StrictVersion @@ -40,6 +41,14 @@ def _decode_variable_name(name): name = None return name +def find_root(ds): + """ + Helper function to find the root of a netcdf or h5netcdf dataset. 
+ """ + while ds.parent is not None: + ds = ds.parent + return ds + def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500): @@ -67,6 +76,7 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, class AbstractDataStore(Mapping): + _autoclose = False def __iter__(self): return iter(self.variables) @@ -107,8 +117,8 @@ def load(self): This function will be called anytime variables or attributes are requested, so care should be taken to make sure its fast. """ - variables = FrozenOrderedDict((_decode_variable_name(k), v) for k, v in - iteritems(self.get_variables())) + variables = FrozenOrderedDict((_decode_variable_name(k), v) + for k, v in self.get_variables().items()) attributes = FrozenOrderedDict(self.get_attrs()) return variables, attributes @@ -252,3 +262,26 @@ def __getstate__(self): def __setstate__(self, state): self.__dict__.update(state) self.ds = self._opener(mode=self._mode) + + @contextlib.contextmanager + def ensure_open(self, autoclose): + """ + Helper function to make sure datasets are closed and opened + at appropriate times to avoid too many open file errors. + + Use requires `autoclose=True` argument to `open_mfdataset`. + """ + if self._autoclose and not self._isopen: + try: + self.ds = self._opener() + self._isopen = True + yield + finally: + if autoclose: + self.close() + else: + yield + + def assert_open(self): + if not self._isopen: + raise AssertionError('internal failure: file must be open if `autoclose=True` is used.') diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index acb46bee14c..d2e1789173d 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -2,18 +2,27 @@ from __future__ import division from __future__ import print_function import functools +import operator import warnings +import numpy as np + from .. 
import Variable from ..core import indexing from ..core.utils import FrozenOrderedDict, close_on_error, Frozen from ..core.pycompat import iteritems, bytes_type, unicode_type, OrderedDict -from .common import WritableCFDataStore, DataStorePickleMixin +from .common import WritableCFDataStore, DataStorePickleMixin, find_root from .netCDF4_ import (_nc4_group, _nc4_values_and_dtype, _extract_nc4_variable_encoding, BaseNetCDF4Array) +class H5NetCDFArrayWrapper(BaseNetCDF4Array): + def __getitem__(self, key): + with self.datastore.ensure_open(autoclose=True): + return self.get_array()[key] + + def maybe_decode_bytes(txt): if isinstance(txt, bytes_type): return txt.decode('utf-8') @@ -49,12 +58,18 @@ class H5NetCDFStore(WritableCFDataStore, DataStorePickleMixin): """Store for reading and writing data via h5netcdf """ def __init__(self, filename, mode='r', format=None, group=None, - writer=None): + writer=None, autoclose=False): if format not in [None, 'NETCDF4']: raise ValueError('invalid format for h5netcdf backend') opener = functools.partial(_open_h5netcdf_group, filename, mode=mode, group=group) self.ds = opener() + if autoclose: + raise NotImplementedError('autoclose=True is not implemented ' + 'for the h5netcdf backend pending further ' + 'exploration, e.g., bug fixes (in h5netcdf?)') + self._autoclose = False + self._isopen = True self.format = format self._opener = opener self._filename = filename @@ -62,36 +77,43 @@ def __init__(self, filename, mode='r', format=None, group=None, super(H5NetCDFStore, self).__init__(writer) def open_store_variable(self, name, var): - dimensions = var.dimensions - data = indexing.LazilyIndexedArray(BaseNetCDF4Array(name, self)) - attrs = _read_attributes(var) + with self.ensure_open(autoclose=False): + dimensions = var.dimensions + data = indexing.LazilyIndexedArray( + H5NetCDFArrayWrapper(name, self)) + attrs = _read_attributes(var) - # netCDF4 specific encoding - encoding = dict(var.filters()) - chunking = var.chunking() - encoding['chunksizes'] = chunking if chunking != 'contiguous' else None + # netCDF4 specific encoding + encoding = dict(var.filters()) + chunking = var.chunking() + encoding['chunksizes'] = chunking if chunking != 'contiguous' else None - # save source so __repr__ can detect if it's local or not - encoding['source'] = self._filename - encoding['original_shape'] = var.shape + # save source so __repr__ can detect if it's local or not + encoding['source'] = self._filename + encoding['original_shape'] = var.shape return Variable(dimensions, data, attrs, encoding) def get_variables(self): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in iteritems(self.ds.variables)) + with self.ensure_open(autoclose=False): + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in iteritems(self.ds.variables)) def get_attrs(self): - return Frozen(_read_attributes(self.ds)) + with self.ensure_open(autoclose=True): + return FrozenOrderedDict(_read_attributes(self.ds)) def get_dimensions(self): - return self.ds.dimensions + with self.ensure_open(autoclose=True): + return self.ds.dimensions def set_dimension(self, name, length): - self.ds.createDimension(name, size=length) + with self.ensure_open(autoclose=False): + self.ds.createDimension(name, size=length) def set_attribute(self, key, value): - self.ds.setncattr(key, value) + with self.ensure_open(autoclose=False): + self.ds.setncattr(key, value) def prepare_variable(self, name, variable, check_encoding=False, unlimited_dims=None): @@ -129,12 +151,14 @@ def
prepare_variable(self, name, variable, check_encoding=False, return nc4_var, variable.data def sync(self): - super(H5NetCDFStore, self).sync() - self.ds.sync() + with self.ensure_open(autoclose=True): + super(H5NetCDFStore, self).sync() + self.ds.sync() def close(self): - ds = self.ds - # netCDF4 only allows closing the root group - while ds.parent is not None: - ds = ds.parent - ds.close() + if self._isopen: + # netCDF4 only allows closing the root group + ds = find_root(self.ds) + if not ds._closed: + ds.close() + self._isopen = False diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index f72a26affaf..3584f30750d 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -9,11 +9,11 @@ from .. import Variable from ..conventions import pop_to from ..core import indexing -from ..core.utils import (FrozenOrderedDict, NDArrayMixin, +from ..core.utils import (FrozenOrderedDict, NdimSizeLenMixin, DunderArrayMixin, close_on_error, is_remote_uri) from ..core.pycompat import iteritems, basestring, OrderedDict, PY3 -from .common import WritableCFDataStore, robust_getitem, DataStorePickleMixin +from .common import WritableCFDataStore, robust_getitem, DataStorePickleMixin, find_root from .netcdf3 import (encode_nc3_attr_value, encode_nc3_variable, maybe_convert_to_char_array) @@ -25,24 +25,25 @@ '|': 'native'} -class BaseNetCDF4Array(NDArrayMixin): +class BaseNetCDF4Array(NdimSizeLenMixin, DunderArrayMixin): def __init__(self, variable_name, datastore): self.datastore = datastore self.variable_name = variable_name - @property - def array(self): - return self.datastore.ds.variables[self.variable_name] + array = self.get_array() + self.shape = array.shape - @property - def dtype(self): - dtype = self.array.dtype + dtype = array.dtype if dtype is str: - # return object dtype because that's the only way in numpy to + # use object dtype because that's the only way in numpy to # represent variable length strings; it also prevents automatic # string concatenation via conventions.decode_cf_variable dtype = np.dtype('O') - return dtype + self.dtype = dtype + + def get_array(self): + self.datastore.assert_open() + return self.datastore.ds.variables[self.variable_name] class NetCDF4ArrayWrapper(BaseNetCDF4Array): @@ -52,19 +53,20 @@ def __getitem__(self, key): else: getitem = operator.getitem - try: - data = getitem(self.array, key) - except IndexError: - # Catch IndexError in netCDF4 and return a more informative error - # message. This is most often called when an unsorted indexer is - # used before the data is loaded from disk. - msg = ('The indexing operation you are attempting to perform is ' - 'not valid on netCDF4.Variable object. Try loading your ' - 'data into memory first by calling .load().') - if not PY3: - import traceback - msg += '\n\nOriginal traceback:\n' + traceback.format_exc() - raise IndexError(msg) + with self.datastore.ensure_open(autoclose=True): + try: + data = getitem(self.get_array(), key) + except IndexError: + # Catch IndexError in netCDF4 and return a more informative + # error message. This is most often called when an unsorted + # indexer is used before the data is loaded from disk. + msg = ('The indexing operation you are attempting to perform ' + 'is not valid on netCDF4.Variable object. 
Try loading ' + 'your data into memory first by calling .load().') + if not PY3: + import traceback + msg += '\n\nOriginal traceback:\n' + traceback.format_exc() + raise IndexError(msg) if self.ndim == 0: # work around for netCDF4-python's broken handling of 0-d @@ -195,7 +197,8 @@ class NetCDF4DataStore(WritableCFDataStore, DataStorePickleMixin): This store supports NetCDF3, NetCDF4 and OpenDAP datasets. """ def __init__(self, filename, mode='r', format='NETCDF4', group=None, - writer=None, clobber=True, diskless=False, persist=False): + writer=None, clobber=True, diskless=False, persist=False, + autoclose=False): if format is None: format = 'NETCDF4' opener = functools.partial(_open_netcdf4_group, filename, mode=mode, @@ -203,66 +206,83 @@ def __init__(self, filename, mode='r', format='NETCDF4', group=None, diskless=diskless, persist=persist, format=format) self.ds = opener() + self._autoclose = autoclose + self._isopen = True self.format = format self.is_remote = is_remote_uri(filename) - self._opener = opener self._filename = filename self._mode = 'a' if mode == 'w' else mode + self._opener = functools.partial(opener, mode=self._mode) super(NetCDF4DataStore, self).__init__(writer) def open_store_variable(self, name, var): - dimensions = var.dimensions - data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self)) - attributes = OrderedDict((k, var.getncattr(k)) - for k in var.ncattrs()) - _ensure_fill_value_valid(data, attributes) - # netCDF4 specific encoding; save _FillValue for later - encoding = {} - filters = var.filters() - if filters is not None: - encoding.update(filters) - chunking = var.chunking() - if chunking is not None: - if chunking == 'contiguous': - encoding['contiguous'] = True - encoding['chunksizes'] = None - else: - encoding['contiguous'] = False - encoding['chunksizes'] = tuple(chunking) - # TODO: figure out how to round-trip "endian-ness" without raising - # warnings from netCDF4 - # encoding['endian'] = var.endian() - pop_to(attributes, encoding, 'least_significant_digit') - # save source so __repr__ can detect if it's local or not - encoding['source'] = self._filename - encoding['original_shape'] = var.shape + with self.ensure_open(autoclose=False): + dimensions = var.dimensions + data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self)) + attributes = OrderedDict((k, var.getncattr(k)) + for k in var.ncattrs()) + _ensure_fill_value_valid(data, attributes) + # netCDF4 specific encoding; save _FillValue for later + encoding = {} + filters = var.filters() + if filters is not None: + encoding.update(filters) + chunking = var.chunking() + if chunking is not None: + if chunking == 'contiguous': + encoding['contiguous'] = True + encoding['chunksizes'] = None + else: + encoding['contiguous'] = False + encoding['chunksizes'] = tuple(chunking) + # TODO: figure out how to round-trip "endian-ness" without raising + # warnings from netCDF4 + # encoding['endian'] = var.endian() + pop_to(attributes, encoding, 'least_significant_digit') + # save source so __repr__ can detect if it's local or not + encoding['source'] = self._filename + encoding['original_shape'] = var.shape + return Variable(dimensions, data, attributes, encoding) def get_variables(self): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in iteritems(self.ds.variables)) + with self.ensure_open(autoclose=False): + dsvars = FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in iteritems(self.ds.variables)) + return dsvars def get_attrs(self): - return 
FrozenOrderedDict((k, self.ds.getncattr(k)) - for k in self.ds.ncattrs()) + with self.ensure_open(autoclose=True): + attrs = FrozenOrderedDict((k, self.ds.getncattr(k)) + for k in self.ds.ncattrs()) + return attrs def get_dimensions(self): - return FrozenOrderedDict((k, len(v)) - for k, v in iteritems(self.ds.dimensions)) + with self.ensure_open(autoclose=True): + dims = FrozenOrderedDict((k, len(v)) + for k, v in iteritems(self.ds.dimensions)) + return dims def get_encoding(self): - encoding = {} - encoding['unlimited_dims'] = { - k for k, v in self.ds.dimensions.items() if v.isunlimited()} + with self.ensure_open(autoclose=True): + encoding = {} + encoding['unlimited_dims'] = { + k for k, v in self.ds.dimensions.items() if v.isunlimited()} return encoding def set_dimension(self, name, length): - self.ds.createDimension(name, size=length) + with self.ensure_open(autoclose=False): + self.ds.createDimension(name, size=length) def set_attribute(self, key, value): - if self.format != 'NETCDF4': - value = encode_nc3_attr_value(value) - self.ds.setncattr(key, value) + with self.ensure_open(autoclose=False): + if self.format != 'NETCDF4': + value = encode_nc3_attr_value(value) + self.ds.setncattr(key, value) + + def set_variables(self, *args, **kwargs): + with self.ensure_open(autoclose=False): + super(NetCDF4DataStore, self).set_variables(*args, **kwargs) def prepare_variable(self, name, variable, check_encoding=False, unlimited_dims=None): @@ -305,15 +325,18 @@ def prepare_variable(self, name, variable, check_encoding=False, # set attributes one-by-one since netCDF4<1.0.10 can't handle # OrderedDict as the input to setncatts nc4_var.setncattr(k, v) + return nc4_var, variable.data def sync(self): - super(NetCDF4DataStore, self).sync() - self.ds.sync() + with self.ensure_open(autoclose=True): + super(NetCDF4DataStore, self).sync() + self.ds.sync() def close(self): - ds = self.ds - # netCDF4 only allows closing the root group - while ds.parent is not None: - ds = ds.parent - ds.close() + if self._isopen: + # netCDF4 only allows closing the root group + ds = find_root(self.ds) + if ds._isopen: + ds.close() + self._isopen = False diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index 8bb759503b9..449971a9145 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -7,39 +7,43 @@ import numpy as np from .. 
import Variable -from ..core.utils import FrozenOrderedDict, Frozen, NDArrayMixin +from ..core.utils import (FrozenOrderedDict, Frozen, + NdimSizeLenMixin, DunderArrayMixin) from ..core import indexing from .common import AbstractDataStore, DataStorePickleMixin -class NioArrayWrapper(NDArrayMixin): +class NioArrayWrapper(NdimSizeLenMixin, DunderArrayMixin): def __init__(self, variable_name, datastore): self.datastore = datastore self.variable_name = variable_name + array = self.get_array() + self.shape = array.shape + self.dtype = np.dtype(array.typecode()) - @property - def array(self): + def get_array(self): + self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name] - @property - def dtype(self): - return np.dtype(self.array.typecode()) - def __getitem__(self, key): - if key == () and self.ndim == 0: - return self.array.get_value() - return self.array[key] + with self.datastore.ensure_open(autoclose=True): + array = self.get_array() + if key == () and self.ndim == 0: + return array.get_value() + return array[key] class NioDataStore(AbstractDataStore, DataStorePickleMixin): """Store for accessing datasets via PyNIO """ - def __init__(self, filename, mode='r'): + def __init__(self, filename, mode='r', autoclose=False): import Nio opener = functools.partial(Nio.open_file, filename, mode=mode) self.ds = opener() + self._autoclose = autoclose + self._isopen = True self._opener = opener self._mode = mode @@ -48,14 +52,17 @@ def open_store_variable(self, name, var): return Variable(var.dimensions, data, var.attributes) def get_variables(self): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in self.ds.variables.iteritems()) + with self.ensure_open(autoclose=False): + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in self.ds.variables.iteritems()) def get_attrs(self): - return Frozen(self.ds.attributes) + with self.ensure_open(autoclose=True): + return Frozen(self.ds.attributes) def get_dimensions(self): - return Frozen(self.ds.dimensions) + with self.ensure_open(autoclose=True): + return Frozen(self.ds.dimensions) def get_encoding(self): encoding = {} @@ -64,4 +71,6 @@ def get_encoding(self): return encoding def close(self): - self.ds.close() + if self._isopen: + self.ds.close() + self._isopen = False diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 281d23fa40a..05ce3a42ce5 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -9,7 +9,8 @@ from .. 
import Variable from ..core.pycompat import iteritems, OrderedDict -from ..core.utils import Frozen, FrozenOrderedDict +from ..core.utils import (Frozen, FrozenOrderedDict, NdimSizeLenMixin, + DunderArrayMixin) from ..core.indexing import NumpyIndexingAdapter from .common import WritableCFDataStore, DataStorePickleMixin @@ -30,39 +31,46 @@ def _decode_attrs(d): for (k, v) in iteritems(d)) -class ScipyArrayWrapper(NumpyIndexingAdapter): +class ScipyArrayWrapper(NdimSizeLenMixin, DunderArrayMixin): def __init__(self, variable_name, datastore): self.datastore = datastore self.variable_name = variable_name + array = self.get_array() + self.shape = array.shape + self.dtype = np.dtype(array.dtype.kind + + str(array.dtype.itemsize)) - @property - def array(self): + def get_array(self): + self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name].data - @property - def dtype(self): - # always use native endianness - return np.dtype(self.array.dtype.kind + str(self.array.dtype.itemsize)) - def __getitem__(self, key): - data = super(ScipyArrayWrapper, self).__getitem__(key) - # Copy data if the source file is mmapped. This makes things consistent - # with the netCDF4 library by ensuring we can safely read arrays even - # after closing associated files. - copy = self.datastore.ds.use_mmap - data = np.array(data, dtype=self.dtype, copy=copy) - return data + def __getitem__(self, key): + with self.datastore.ensure_open(autoclose=True): + data = NumpyIndexingAdapter(self.get_array())[key] + # Copy data if the source file is mmapped. + # This makes things consistent + # with the netCDF4 library by ensuring + # we can safely read arrays even + # after closing associated files. + copy = self.datastore.ds.use_mmap + return np.array(data, dtype=self.dtype, copy=copy) def _open_scipy_netcdf(filename, mode, mmap, version): import scipy.io - if isinstance(filename, bytes) and filename.startswith(b'CDF'): - # it's a NetCDF3 bytestring - filename = BytesIO(filename) + try: + ds = scipy.io.netcdf_file(filename, mode=mode, mmap=mmap, + version=version) + except TypeError as e: + # TODO: gzipped loading only works with NetCDF3 files. + if 'is not a valid NetCDF 3 file' in str(e): + raise ValueError('gzipped file loading only supports ' + 'NetCDF 3 files.') + else: + raise - return scipy.io.netcdf_file(filename, mode=mode, mmap=mmap, - version=version) + return ds class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin): @@ -74,9 +82,11 @@ class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin): It only supports the NetCDF3 file-format. """ def __init__(self, filename_or_obj, mode='r', format=None, group=None, - writer=None, mmap=None): + writer=None, mmap=None, autoclose=False): import scipy import scipy.io + import gzip + if mode != 'r' and scipy.__version__ < '0.13': # pragma: no cover warnings.warn('scipy %s detected; ' 'the minimal recommended version is 0.13. 
' @@ -96,28 +106,44 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, raise ValueError('invalid format for scipy.io.netcdf backend: %r' % format) + # if the string ends with .gz, then gunzip and open as netcdf file + if type(filename_or_obj) is str and filename_or_obj.endswith('.gz'): + filename_or_obj = gzip.open(filename_or_obj) + + if isinstance(filename_or_obj, bytes) and filename_or_obj.startswith(b'CDF'): + # it's a NetCDF3 bytestring + filename_or_obj = BytesIO(filename_or_obj) + # cannot reopen bytestring after it is closed + autoclose = False + opener = functools.partial(_open_scipy_netcdf, filename=filename_or_obj, mode=mode, mmap=mmap, version=version) self.ds = opener() + self._autoclose = autoclose + self._isopen = True self._opener = opener self._mode = mode super(ScipyDataStore, self).__init__(writer) def open_store_variable(self, name, var): - return Variable(var.dimensions, ScipyArrayWrapper(name, self), - _decode_attrs(var._attributes)) + with self.ensure_open(autoclose=False): + return Variable(var.dimensions, ScipyArrayWrapper(name, self), + _decode_attrs(var._attributes)) def get_variables(self): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in iteritems(self.ds.variables)) + with self.ensure_open(autoclose=False): + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in iteritems(self.ds.variables)) def get_attrs(self): - return Frozen(_decode_attrs(self.ds._attributes)) + with self.ensure_open(autoclose=True): + return Frozen(_decode_attrs(self.ds._attributes)) def get_dimensions(self): - return Frozen(self.ds.dimensions) + with self.ensure_open(autoclose=True): + return Frozen(self.ds.dimensions) def get_encoding(self): encoding = {} @@ -126,19 +152,21 @@ def get_encoding(self): return encoding def set_dimension(self, name, length): - if name in self.dimensions: - raise ValueError('%s does not support modifying dimensions' - % type(self).__name__) - self.ds.createDimension(name, length) + with self.ensure_open(autoclose=False): + if name in self.dimensions: + raise ValueError('%s does not support modifying dimensions' + % type(self).__name__) + self.ds.createDimension(name, length) def _validate_attr_key(self, key): if not is_valid_nc3_name(key): raise ValueError("Not a valid attribute name") def set_attribute(self, key, value): - self._validate_attr_key(key) - value = encode_nc3_attr_value(value) - setattr(self.ds, key, value) + with self.ensure_open(autoclose=False): + self._validate_attr_key(key) + value = encode_nc3_attr_value(value) + setattr(self.ds, key, value) def prepare_variable(self, name, variable, check_encoding=False, unlimited_dims=None): @@ -163,11 +191,13 @@ def prepare_variable(self, name, variable, check_encoding=False, return scipy_var, data def sync(self): - super(ScipyDataStore, self).sync() - self.ds.flush() + with self.ensure_open(autoclose=True): + super(ScipyDataStore, self).sync() + self.ds.flush() def close(self): self.ds.close() + self._isopen = False def __exit__(self, type, value, tb): self.close() @@ -179,3 +209,4 @@ def __setstate__(self, state): # seek to the start of the file so scipy can read it filename.seek(0) super(ScipyDataStore, self).__setstate__(state) + self._isopen = True diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index 671f38c7df7..63f87b0d70f 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -87,3 +87,139 @@ def __exit__(self, exctype, excinst, exctb): # # See http://bugs.python.org/issue12029 for more details 
return exctype is not None and issubclass(exctype, self._exceptions) +try: + from contextlib import ExitStack +except ImportError: + # backport from Python 3.5: + from collections import deque + + # Inspired by discussions on http://bugs.python.org/issue13585 + class ExitStack(object): + """Context manager for dynamic management of a stack of exit callbacks + + For example: + + with ExitStack() as stack: + files = [stack.enter_context(open(fname)) for fname in filenames] + # All opened files will automatically be closed at the end of + # the with statement, even if attempts to open files later + # in the list raise an exception + + """ + def __init__(self): + self._exit_callbacks = deque() + + def pop_all(self): + """Preserve the context stack by transferring it to a new instance""" + new_stack = type(self)() + new_stack._exit_callbacks = self._exit_callbacks + self._exit_callbacks = deque() + return new_stack + + def _push_cm_exit(self, cm, cm_exit): + """Helper to correctly register callbacks to __exit__ methods""" + def _exit_wrapper(*exc_details): + return cm_exit(cm, *exc_details) + _exit_wrapper.__self__ = cm + self.push(_exit_wrapper) + + def push(self, exit): + """Registers a callback with the standard __exit__ method signature + + Can suppress exceptions the same way __exit__ methods can. + + Also accepts any object with an __exit__ method (registering a call + to the method instead of the object itself) + """ + # We use an unbound method rather than a bound method to follow + # the standard lookup behaviour for special methods + _cb_type = type(exit) + try: + exit_method = _cb_type.__exit__ + except AttributeError: + # Not a context manager, so assume its a callable + self._exit_callbacks.append(exit) + else: + self._push_cm_exit(exit, exit_method) + return exit # Allow use as a decorator + + def callback(self, callback, *args, **kwds): + """Registers an arbitrary callback and arguments. + + Cannot suppress exceptions. + """ + def _exit_wrapper(exc_type, exc, tb): + callback(*args, **kwds) + # We changed the signature, so using @wraps is not appropriate, but + # setting __wrapped__ may still help with introspection + _exit_wrapper.__wrapped__ = callback + self.push(_exit_wrapper) + return callback # Allow use as a decorator + + def enter_context(self, cm): + """Enters the supplied context manager + + If successful, also pushes its __exit__ method as a callback and + returns the result of the __enter__ method. 
+ """ + # We look up the special methods on the type to match the with statement + _cm_type = type(cm) + _exit = _cm_type.__exit__ + result = _cm_type.__enter__(cm) + self._push_cm_exit(cm, _exit) + return result + + def close(self): + """Immediately unwind the context stack""" + self.__exit__(None, None, None) + + def __enter__(self): + return self + + def __exit__(self, *exc_details): + received_exc = exc_details[0] is not None + + # We manipulate the exception state so it behaves as though + # we were actually nesting multiple with statements + frame_exc = sys.exc_info()[1] + def _fix_exception_context(new_exc, old_exc): + # Context may not be correct, so find the end of the chain + while 1: + exc_context = new_exc.__context__ + if exc_context is old_exc: + # Context is already set correctly (see issue 20317) + return + if exc_context is None or exc_context is frame_exc: + break + new_exc = exc_context + # Change the end of the chain to point to the exception + # we expect it to reference + new_exc.__context__ = old_exc + + # Callbacks are invoked in LIFO order to match the behaviour of + # nested context managers + suppressed_exc = False + pending_raise = False + while self._exit_callbacks: + cb = self._exit_callbacks.pop() + try: + if cb(*exc_details): + suppressed_exc = True + pending_raise = False + exc_details = (None, None, None) + except: + new_exc_details = sys.exc_info() + # simulate the stack of exceptions by setting the context + _fix_exception_context(new_exc_details[1], exc_details[1]) + pending_raise = True + exc_details = new_exc_details + if pending_raise: + try: + # bare "raise exc_details[1]" replaces our carefully + # set-up context + fixed_ctx = exc_details[1].__context__ + raise exc_details[1] + except BaseException: + exc_details[1].__context__ = fixed_ctx + raise + return received_exc and suppressed_exc diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 73dc0614fd8..6fd997f5b2b 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -398,7 +398,12 @@ def __len__(self): raise TypeError('len() of unsized object') -class NDArrayMixin(NdimSizeLenMixin): +class DunderArrayMixin(object): + def __array__(self, dtype=None): + return np.asarray(self[...], dtype=dtype) + + +class NDArrayMixin(NdimSizeLenMixin, DunderArrayMixin): """Mixin class for making wrappers of N-dimensional arrays that conform to the ndarray interface required for the data argument to Variable objects. 
@@ -413,9 +418,6 @@ def dtype(self): def shape(self): return self.array.shape - def __array__(self, dtype=None): - return np.asarray(self[...], dtype=dtype) - def __getitem__(self, key): return self.array[key] diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index e64ad5842a0..843b8bfdfbd 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -7,6 +7,7 @@ import numpy as np from numpy.testing import assert_array_equal +from xarray.core.ops import allclose_or_equiv import pytest from xarray.core import utils @@ -146,6 +147,9 @@ def assertArrayEqual(self, a1, a2): def assertEqual(self, a1, a2): assert a1 == a2 or (a1 != a1 and a2 != a2) + def assertAllClose(self, a1, a2, rtol=1e-05, atol=1e-8): + assert allclose_or_equiv(a1, a2, rtol=rtol, atol=atol) + def assertDatasetEqual(self, d1, d2): assert_equal(d1, d2) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 1b4be911bbc..8e26da70426 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -22,7 +22,7 @@ from xarray.backends.common import robust_getitem from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding from xarray.core import indexing -from xarray.core.pycompat import iteritems, PY2, PY3 +from xarray.core.pycompat import iteritems, PY2, PY3, ExitStack from . import (TestCase, requires_scipy, requires_netCDF4, requires_pydap, requires_scipy_or_netCDF4, requires_dask, requires_h5netcdf, @@ -98,6 +98,7 @@ class Only32BitTypes(object): class DatasetIOTestCases(object): + autoclose = False def create_store(self): raise NotImplementedError @@ -332,8 +333,10 @@ def test_roundtrip_strings_with_fill_value(self): # variables expected['x'][-1] = '' elif (isinstance(self, (NetCDF3ViaNetCDF4DataTest, - NetCDF4ClassicViaNetCDF4DataTest)) or - (has_netCDF4 and type(self) is GenericNetCDFDataTest)): + NetCDF4ClassicViaNetCDF4DataTest)) + or (has_netCDF4 and + (type(self) is GenericNetCDFDataTest or + type(self) is GenericNetCDFDataTestAutocloseTrue))): # netCDF4 can't keep track of an empty _FillValue for nc3, either: # https://github.com/Unidata/netcdf4-python/issues/273 expected['x'][-1] = np.string_('') @@ -492,6 +495,14 @@ def create_tmp_file(suffix='.nc', allow_cleanup_failure=False): if not allow_cleanup_failure: raise +@contextlib.contextmanager +def create_tmp_files(nfiles, suffix='.nc', allow_cleanup_failure=False): + with ExitStack() as stack: + files = [stack.enter_context(create_tmp_file(suffix, allow_cleanup_failure)) + for apath in np.arange(nfiles)] + yield files + + @requires_netCDF4 class BaseNetCDF4Test(CFEncodedDataTest): def test_open_group(self): @@ -693,6 +704,7 @@ def test_variable_len_strings(self): @requires_netCDF4 class NetCDF4DataTest(BaseNetCDF4Test, TestCase): + autoclose = False @contextlib.contextmanager def create_store(self): @@ -706,7 +718,7 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}, with create_tmp_file( allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, **save_kwargs) - with open_dataset(tmp_file, **open_kwargs) as ds: + with open_dataset(tmp_file, autoclose=self.autoclose, **open_kwargs) as ds: yield ds def test_variable_order(self): @@ -737,6 +749,9 @@ def test_unsorted_index_raises(self): except IndexError as err: self.assertIn('first by calling .load', str(err)) +class NetCDF4DataStoreAutocloseTrue(NetCDF4DataTest): + autoclose = True + @requires_netCDF4 @requires_dask @@ -759,6 +774,10 @@ def test_dataset_caching(self): pass +class 
NetCDF4ViaDaskDataTestAutocloseTrue(NetCDF4ViaDaskDataTest): + autoclose = True + + @requires_scipy class ScipyInMemoryDataTest(CFEncodedDataTest, Only32BitTypes, TestCase): @contextlib.contextmanager @@ -770,18 +789,23 @@ def create_store(self): def roundtrip(self, data, save_kwargs={}, open_kwargs={}, allow_cleanup_failure=False): serialized = data.to_netcdf(**save_kwargs) - with open_dataset(serialized, engine='scipy', **open_kwargs) as ds: + with open_dataset(serialized, engine='scipy', + autoclose=self.autoclose, **open_kwargs) as ds: yield ds @pytest.mark.skipif(PY2, reason='cannot pickle BytesIO on Python 2') def test_bytesio_pickle(self): data = Dataset({'foo': ('x', [1, 2, 3])}) fobj = BytesIO(data.to_netcdf()) - with open_dataset(fobj) as ds: + with open_dataset(fobj, autoclose=self.autoclose) as ds: unpickled = pickle.loads(pickle.dumps(ds)) self.assertDatasetIdentical(unpickled, data) +class ScipyInMemoryDataTestAutocloseTrue(ScipyInMemoryDataTest): + autoclose = True + + @requires_scipy class ScipyOnDiskDataTest(CFEncodedDataTest, Only32BitTypes, TestCase): @contextlib.contextmanager @@ -796,7 +820,8 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}, with create_tmp_file( allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, engine='scipy', **save_kwargs) - with open_dataset(tmp_file, engine='scipy', **open_kwargs) as ds: + with open_dataset(tmp_file, engine='scipy', + autoclose=self.autoclose, **open_kwargs) as ds: yield ds def test_array_attrs(self): @@ -821,6 +846,9 @@ def test_netcdf3_endianness(self): for var in expected.values(): self.assertTrue(var.dtype.isnative) +class ScipyOnDiskDataTestAutocloseTrue(ScipyOnDiskDataTest): + autoclose = True + @requires_netCDF4 class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes, TestCase): @@ -838,9 +866,13 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}, allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, format='NETCDF3_CLASSIC', engine='netcdf4', **save_kwargs) - with open_dataset(tmp_file, engine='netcdf4', **open_kwargs) as ds: + with open_dataset(tmp_file, engine='netcdf4', + autoclose=self.autoclose, **open_kwargs) as ds: yield ds +class NetCDF3ViaNetCDF4DataTestAutocloseTrue(NetCDF3ViaNetCDF4DataTest): + autoclose = True + @requires_netCDF4 class NetCDF4ClassicViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes, @@ -859,10 +891,15 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}, allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, format='NETCDF4_CLASSIC', engine='netcdf4', **save_kwargs) - with open_dataset(tmp_file, engine='netcdf4', **open_kwargs) as ds: + with open_dataset(tmp_file, engine='netcdf4', + autoclose=self.autoclose, **open_kwargs) as ds: yield ds +class NetCDF4ClassicViaNetCDF4DataTestAutocloseTrue(NetCDF4ClassicViaNetCDF4DataTest): + autoclose = True + + @requires_scipy_or_netCDF4 class GenericNetCDFDataTest(CFEncodedDataTest, Only32BitTypes, TestCase): # verify that we can read and write netCDF3 files as long as we have scipy @@ -878,7 +915,7 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}, with create_tmp_file( allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, format='netcdf3_64bit', **save_kwargs) - with open_dataset(tmp_file, **open_kwargs) as ds: + with open_dataset(tmp_file, autoclose=self.autoclose, **open_kwargs) as ds: yield ds def test_engine(self): @@ -931,6 +968,8 @@ def test_encoding_unlimited_dims(self): 
self.assertEqual(actual.encoding['unlimited_dims'], set('y')) self.assertDatasetEqual(ds, actual) +class GenericNetCDFDataTestAutocloseTrue(GenericNetCDFDataTest): + autoclose = True @requires_h5netcdf @requires_netCDF4 @@ -946,7 +985,8 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}, with create_tmp_file( allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, engine='h5netcdf', **save_kwargs) - with open_dataset(tmp_file, engine='h5netcdf', **open_kwargs) as ds: + with open_dataset(tmp_file, engine='h5netcdf', + autoclose=self.autoclose, **open_kwargs) as ds: yield ds def test_orthogonal_indexing(self): @@ -992,6 +1032,80 @@ def test_encoding_unlimited_dims(self): with pytest.warns(UserWarning): ds.to_netcdf(tmp_file, engine='h5netcdf', unlimited_dims=['y']) +# tests pending h5netcdf fix +#class H5NetCDFDataTestAutocloseTrue(H5NetCDFDataTest): +# autoclose = True + +class OpenMFDatasetTest(TestCase): + autoclose = True + def validate_open_mfdataset_autoclose(self, engine, nfiles=10): + randdata = np.random.randn(nfiles) + original = Dataset({'foo': ('x', randdata)}) + # test standard open_mfdataset approach with too many files + with create_tmp_files(nfiles) as tmpfiles: + for readengine in engine: + writeengine = readengine if readengine != 'pynio' else 'netcdf4' + # split into multiple sets of temp files + for ii in original.x.values: + ( + original.isel(x=slice(ii, ii+1)) + ).to_netcdf(tmpfiles[ii], engine=writeengine) + + # check that calculation on opened datasets works properly + ds = open_mfdataset(tmpfiles, engine=readengine, + autoclose=self.autoclose) + self.assertAllClose(ds.x.sum().values, (nfiles*(nfiles-1))/2) + self.assertAllClose(ds.foo.sum().values, np.sum(randdata)) + self.assertAllClose(ds.sum().foo.values, np.sum(randdata)) + ds.close() + + def validate_open_mfdataset_large_num_files(self, engine): + self.validate_open_mfdataset_autoclose(engine, nfiles=2000) + + @requires_dask + @requires_netCDF4 + def test_1_autoclose_netcdf4(self): + self.validate_open_mfdataset_autoclose(engine=['netcdf4']) + + @requires_dask + @requires_scipy + def test_2_autoclose_scipy(self): + self.validate_open_mfdataset_autoclose(engine=['scipy']) + + @requires_dask + @requires_pynio + def test_3_autoclose_pynio(self): + self.validate_open_mfdataset_autoclose(engine=['pynio']) + + # use of autoclose=True with h5netcdf broken because of + # probable h5netcdf error, uncomment when fixed to test + #@requires_dask + #@requires_h5netcdf + #def test_4_autoclose_h5netcdf(self): + # self.validate_open_mfdataset_autoclose(engine=['h5netcdf']) + + @requires_dask + @requires_netCDF4 + def test_1_open_large_num_files_netcdf4(self): + self.validate_open_mfdataset_large_num_files(engine=['netcdf4']) + + @requires_dask + @requires_scipy + def test_2_open_large_num_files_scipy(self): + self.validate_open_mfdataset_large_num_files(engine=['scipy']) + + @requires_dask + @requires_pynio + def test_3_open_large_num_files_pynio(self): + self.validate_open_mfdataset_large_num_files(engine=['pynio']) + + # use of autoclose=True with h5netcdf broken because of + # probable h5netcdf error, uncomment when fixed to test + #@requires_dask + #@requires_h5netcdf + #def test_4_open_large_num_files_h5netcdf(self): + # self.validate_open_mfdataset_large_num_files(engine=['h5netcdf']) + @requires_dask @requires_scipy @@ -1030,17 +1144,18 @@ def test_open_mfdataset(self): with create_tmp_file() as tmp2: original.isel(x=slice(5)).to_netcdf(tmp1) original.isel(x=slice(5, 
10)).to_netcdf(tmp2) - with open_mfdataset([tmp1, tmp2]) as actual: + with open_mfdataset([tmp1, tmp2], autoclose=self.autoclose) as actual: self.assertIsInstance(actual.foo.variable.data, da.Array) self.assertEqual(actual.foo.variable.data.chunks, ((5, 5),)) self.assertDatasetAllClose(original, actual) - with open_mfdataset([tmp1, tmp2], chunks={'x': 3}) as actual: + with open_mfdataset([tmp1, tmp2], chunks={'x': 3}, + autoclose=self.autoclose) as actual: self.assertEqual(actual.foo.variable.data.chunks, ((3, 2, 3, 2),)) with self.assertRaisesRegexp(IOError, 'no files to open'): - open_mfdataset('foo-bar-baz-*.nc') + open_mfdataset('foo-bar-baz-*.nc', autoclose=self.autoclose) def test_preprocess_mfdataset(self): original = Dataset({'foo': ('x', np.random.randn(10))}) @@ -1051,7 +1166,8 @@ def preprocess(ds): return ds.assign_coords(z=0) expected = preprocess(original) - with open_mfdataset(tmp, preprocess=preprocess) as actual: + with open_mfdataset(tmp, preprocess=preprocess, + autoclose=self.autoclose) as actual: self.assertDatasetIdentical(expected, actual) def test_save_mfdataset_roundtrip(self): @@ -1061,7 +1177,8 @@ def test_save_mfdataset_roundtrip(self): with create_tmp_file() as tmp1: with create_tmp_file() as tmp2: save_mfdataset(datasets, [tmp1, tmp2]) - with open_mfdataset([tmp1, tmp2]) as actual: + with open_mfdataset([tmp1, tmp2], + autoclose=self.autoclose) as actual: self.assertDatasetIdentical(actual, original) def test_save_mfdataset_invalid(self): @@ -1075,7 +1192,7 @@ def test_open_and_do_math(self): original = Dataset({'foo': ('x', np.random.randn(10))}) with create_tmp_file() as tmp: original.to_netcdf(tmp) - with open_mfdataset(tmp) as ds: + with open_mfdataset(tmp, autoclose=self.autoclose) as ds: actual = 1.0 * ds self.assertDatasetAllClose(original, actual) @@ -1085,7 +1202,8 @@ def test_open_mfdataset_concat_dim_none(self): data = Dataset({'x': 0}) data.to_netcdf(tmp1) Dataset({'x': np.nan}).to_netcdf(tmp2) - with open_mfdataset([tmp1, tmp2], concat_dim=None) as actual: + with open_mfdataset([tmp1, tmp2], concat_dim=None, + autoclose=self.autoclose) as actual: self.assertDatasetIdentical(data, actual) def test_open_dataset(self): @@ -1118,10 +1236,10 @@ def test_deterministic_names(self): with create_tmp_file() as tmp: data = create_test_data() data.to_netcdf(tmp) - with open_mfdataset(tmp) as ds: + with open_mfdataset(tmp, autoclose=self.autoclose) as ds: original_names = dict((k, v.data.name) for k, v in ds.data_vars.items()) - with open_mfdataset(tmp) as ds: + with open_mfdataset(tmp, autoclose=self.autoclose) as ds: repeat_names = dict((k, v.data.name) for k, v in ds.data_vars.items()) for var_name, dask_name in original_names.items(): @@ -1139,6 +1257,8 @@ def test_dataarray_compute(self): self.assertTrue(computed._in_memory) self.assertDataArrayAllClose(actual, computed) +class DaskTestAutocloseTrue(DaskTest): + autoclose=True @requires_scipy_or_netCDF4 @requires_pydap @@ -1197,7 +1317,8 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}, with create_tmp_file( allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, engine='scipy', **save_kwargs) - with open_dataset(tmp_file, engine='pynio', **open_kwargs) as ds: + with open_dataset(tmp_file, engine='pynio', + autoclose=self.autoclose, **open_kwargs) as ds: yield ds def test_weakrefs(self): @@ -1211,6 +1332,9 @@ def test_weakrefs(self): del on_disk # trigger garbage collection self.assertDatasetIdentical(actual, expected) +class TestPyNioAutocloseTrue(TestPyNio): + 
autoclose=True + class TestEncodingInvalid(TestCase):
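As an illustrative usage sketch of the `autoclose` option added by this patch (the example itself is not part of the diff; the file names, file count, and data are hypothetical), opening more files than the per-process limit allows can raise "OSError: Too many open files" unless `autoclose=True` is passed to open_mfdataset:

    import xarray as xr  # open_mfdataset also requires dask

    # Write many small single-variable files; raising nfiles above the
    # `ulimit -n` open-file limit reproduces the OSError when
    # autoclose=False (the default).
    nfiles = 2000
    paths = ['data_%04d.nc' % i for i in range(nfiles)]
    for i, path in enumerate(paths):
        xr.Dataset({'foo': ('x', [float(i)])}).to_netcdf(path)

    # autoclose=True closes each file after it is read, so the open-file
    # limit is never exhausted; autoclose=False keeps files open for speed,
    # matching previous xarray behavior.
    with xr.open_mfdataset(paths, autoclose=True) as ds:
        print(ds.foo.sum().values)  # sums the values written above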