Disabled auto-caching on dask; new .compute() method #1018

Closed · wants to merge 6 commits
16 changes: 10 additions & 6 deletions doc/whats-new.rst
@@ -25,6 +25,10 @@ Breaking changes
merges will now succeed in cases that previously raised
``xarray.MergeError``. Set ``compat='broadcast_equals'`` to restore the
previous default.
- Pickling an xarray object backed by dask, or reading its
:py:attr:`values` property, no longer automatically converts the array from
dask to numpy in the original object.
By `Guido Imperiale <https://github.com/crusaderky>`_.
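
For example (a minimal sketch of the new behaviour, assuming dask is installed;
the variable names are illustrative):

import pickle
import numpy as np
import xarray as xr

ds = xr.Dataset({'foo': ('x', np.arange(4))}).chunk({'x': 2})
_ = ds.foo.values     # computed on the fly; ds.foo still wraps a dask array
_ = pickle.dumps(ds)  # pickling likewise no longer converts dask to numpy in ds
ds.load()             # an explicit load() is now needed to cache the data in memory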

Deprecations
~~~~~~~~~~~~
@@ -45,33 +49,33 @@ Deprecations
Enhancements
~~~~~~~~~~~~
- Add checking of ``attr`` names and values when saving to netCDF, raising useful
error messages if they are invalid. (:issue:`911`).
By `Robin Wilson <https://github.com/robintw>`_.

- Added ability to save ``DataArray`` objects directly to netCDF files using
:py:meth:`~xarray.DataArray.to_netcdf`, and to load directly from netCDF files
using :py:func:`~xarray.open_dataarray` (:issue:`915`). These remove the need
to convert a ``DataArray`` to a ``Dataset`` before saving as a netCDF file,
and handles names to ensure a perfect 'roundtrip' capability.
By `Robin Wilson <https://github.com/robintw>`_.

- Multi-index levels are now accessible as "virtual" coordinate variables,
e.g., ``ds['time']`` can pull out the ``'time'`` level of a multi-index
(see :ref:`coordinates`). ``sel`` also accepts providing multi-index levels
as keyword arguments, e.g., ``ds.sel(time='2000-01')``
(see :ref:`multi-level indexing`).
By `Benoit Bovy <https://github.com/benbovy>`_.

- Added the ``compat`` option ``'no_conflicts'`` to ``merge``, allowing the
combination of xarray objects with disjoint (:issue:`742`) or
overlapping (:issue:`835`) coordinates as long as all present data agrees.
By `Johnnie Gray <https://github.com/jcmgray>`_. See
:ref:`combining.no_conflicts` for more details.

- It is now possible to set ``concat_dim=None`` explicitly in
:py:func:`~xarray.open_mfdataset` to disable inferring a dimension along
which to concatenate.
By `Stephan Hoyer <https://github.com/shoyer>`_.
- Added methods :py:meth:`DataArray.compute`, :py:meth:`Dataset.compute`, and
:py:meth:`Variable.compute` as a non-mutating alternative to
:py:meth:`~DataArray.load`.
By `Guido Imperiale <https://github.com/crusaderky>`_.
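
A minimal sketch of the difference between the two (assuming dask is installed):

import numpy as np
import xarray as xr

ds = xr.Dataset({'foo': ('x', np.arange(4))}).chunk({'x': 2})
computed = ds.compute()  # returns a new, fully loaded dataset; ds itself stays lazy
ds.load()                # loads in place and returns the same, now in-memory, dataset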

Bug fixes
~~~~~~~~~
13 changes: 13 additions & 0 deletions xarray/core/dataarray.py
@@ -561,6 +561,19 @@ def load(self):
self._coords = new._coords
return self

def compute(self):
"""Manually trigger loading of this array's data from disk or a
remote source into memory and return a new array. The original is
left unaltered.

Normally, it should not be necessary to call this method in user code,
because all xarray functions should either work on deferred data or
load data automatically. However, this method can be necessary when
working with many file objects on disk.
"""
new = self.copy(deep=False)
return new.load()

def copy(self, deep=True):
"""Returns a copy of this array.

32 changes: 23 additions & 9 deletions xarray/core/dataset.py
@@ -255,8 +255,11 @@ def load_store(cls, store, decoder=None):
return obj

def __getstate__(self):
"""Always load data in-memory before pickling"""
self.load()
"""Load data in-memory before pickling (except for Dask data)"""
for v in self.variables.values():
if not isinstance(v.data, dask_array_type):
v.load()

# self.__dict__ is the default pickle object, we don't need to
# implement our own __setstate__ method to make pickle work
state = self.__dict__.copy()
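
With this change, a dask-backed dataset can be pickled without triggering any
computation; a minimal sketch (assuming dask is installed):

import pickle
import numpy as np
import xarray as xr

ds = xr.Dataset({'foo': ('x', np.arange(4))}).chunk({'x': 2})
restored = pickle.loads(pickle.dumps(ds))        # no data is loaded on either side
assert dict(restored.chunks) == dict(ds.chunks)  # the dask chunking survives the round trip
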
@@ -319,6 +322,19 @@ def load(self):

return self

def compute(self):
"""Manually trigger loading of this dataset's data from disk or a
remote source into memory and return a new dataset. The original is
left unaltered.

Normally, it should not be necessary to call this method in user code,
because all xarray functions should either work on deferred data or
load data automatically. However, this method can be necessary when
working with many file objects on disk.
"""
new = self.copy(deep=False)
return new.load()

@classmethod
def _construct_direct(cls, variables, coord_names, dims=None, attrs=None,
file_obj=None):
@@ -401,14 +417,12 @@ def copy(self, deep=False):
"""Returns a copy of this dataset.

If `deep=True`, a deep copy is made of each of the component variables.
-        Otherwise, a shallow copy is made, so each variable in the new dataset
-        is also a variable in the original dataset.
+        Otherwise, a shallow copy of each of the component variables is made, so
+        that the underlying memory region of the new dataset is the same as in
+        the original dataset.
        """
-        if deep:
-            variables = OrderedDict((k, v.copy(deep=True))
-                                    for k, v in iteritems(self._variables))
-        else:
-            variables = self._variables.copy()
+        variables = OrderedDict((k, v.copy(deep=deep))
+                                for k, v in iteritems(self._variables))
# skip __init__ to avoid costly validation
return self._construct_direct(variables, self._coord_names.copy(),
self._dims.copy(), self._attrs_copy())
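
A minimal sketch of the distinction described in the docstring above (plain
numpy-backed data; names are illustrative):

import numpy as np
import xarray as xr

ds = xr.Dataset({'foo': ('x', np.arange(4.0))})
shallow = ds.copy()        # deep=False: new Variable objects, same underlying memory
deep = ds.copy(deep=True)  # independent copy of the data

ds['foo'][0] = -1
print(shallow['foo'].values[0])  # -1.0: the change is visible through the shallow copy
print(deep['foo'].values[0])     # 0.0: the deep copy is unaffected
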
45 changes: 35 additions & 10 deletions xarray/core/variable.py
@@ -274,10 +274,21 @@ def data(self, data):
"replacement data must match the Variable's shape")
self._data = data

def _data_cast(self):
if isinstance(self._data, (np.ndarray, PandasIndexAdapter)):
return self._data
else:
return np.asarray(self._data)

def _data_cached(self):
-        if not isinstance(self._data, (np.ndarray, PandasIndexAdapter)):
-            self._data = np.asarray(self._data)
-        return self._data
+        """Load data into memory and return it.
+        Do not cache dask arrays automatically; that should
+        require an explicit load() call.
+        """
+        new_data = self._data_cast()
+        if not isinstance(self._data, dask_array_type):
+            self._data = new_data
+        return new_data

@property
def _indexable_data(self):
@@ -290,12 +301,25 @@ def load(self):
Normally, it should not be necessary to call this method in user code,
because all xarray functions should either work on deferred data or
load data automatically.
"""
self._data_cached()
"""
self._data = self._data_cast()
return self

def compute(self):
"""Manually trigger loading of this variable's data from disk or a
remote source into memory and return a new variable. The original is
left unaltered.

Normally, it should not be necessary to call this method in user code,
because all xarray functions should either work on deferred data or
load data automatically.
"""
new = self.copy(deep=False)
return new.load()

def __getstate__(self):
"""Always cache data as an in-memory array before pickling"""
"""Always cache data as an in-memory array before pickling
(with the exception of dask backend)"""
self._data_cached()
# self.__dict__ is the default pickle object, we don't need to
# implement our own __setstate__ method to make pickle work
@@ -1093,10 +1117,11 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
raise ValueError('%s objects must be 1-dimensional' %
type(self).__name__)

-    def _data_cached(self):
-        if not isinstance(self._data, PandasIndexAdapter):
-            self._data = PandasIndexAdapter(self._data)
-        return self._data
+    def _data_cast(self):
+        if isinstance(self._data, PandasIndexAdapter):
+            return self._data
+        else:
+            return PandasIndexAdapter(self._data)

def __getitem__(self, key):
key = self._item_key_to_tuple(key)
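
The user-visible effect of the _data_cast()/_data_cached() split above, sketched
at the Variable level (_in_memory is the internal flag also used by the tests
below; assumes dask is installed):

import dask.array as da
import xarray as xr

var = xr.Variable('x', da.arange(4, chunks=2))
_ = var.values         # converted to numpy on the fly; the result is not cached
print(var._in_memory)  # False: var still wraps the dask array
var.load()             # only an explicit load() (or compute() on a copy) caches numpy data
print(var._in_memory)  # True
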
62 changes: 48 additions & 14 deletions xarray/test/test_backends.py
@@ -149,6 +149,24 @@ def assert_loads(vars=None):
actual = ds.load()
self.assertDatasetAllClose(expected, actual)

def test_dataset_compute(self):
expected = create_test_data()

with self.roundtrip(expected) as actual:
# Test Dataset.compute()
for v in actual.variables.values():
self.assertFalse(v._in_memory)

computed = actual.compute()

for v in actual.data_vars.values():
self.assertFalse(v._in_memory)
for v in computed.variables.values():
self.assertTrue(v._in_memory)

self.assertDatasetAllClose(expected, actual)
self.assertDatasetAllClose(expected, computed)

def test_roundtrip_None_variable(self):
expected = Dataset({None: (('x', 'y'), [[0, 1], [2, 3]])})
with self.roundtrip(expected) as actual:
@@ -230,18 +248,6 @@ def test_roundtrip_coordinates(self):
with self.roundtrip(expected) as actual:
self.assertDatasetIdentical(expected, actual)

-        expected = original.copy()
-        expected.attrs['coordinates'] = 'something random'
-        with self.assertRaisesRegexp(ValueError, 'cannot serialize'):
-            with self.roundtrip(expected):
-                pass
-
-        expected = original.copy(deep=True)
-        expected['foo'].attrs['coordinates'] = 'something random'
-        with self.assertRaisesRegexp(ValueError, 'cannot serialize'):
-            with self.roundtrip(expected):
-                pass

def test_roundtrip_boolean_dtype(self):
original = create_boolean_data()
self.assertEqual(original['x'].dtype, 'bool')
@@ -872,7 +878,26 @@ def test_read_byte_attrs_as_unicode(self):
@requires_dask
@requires_scipy
@requires_netCDF4
-class DaskTest(TestCase):
+class DaskTest(TestCase, DatasetIOTestCases):
@contextlib.contextmanager
def create_store(self):
yield Dataset()

@contextlib.contextmanager
def roundtrip(self, data, save_kwargs={}, open_kwargs={}):
yield data.chunk()
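
For context: chunk() with no arguments wraps every variable in a single-chunk dask
array, so this "roundtrip" runs the shared backend tests against lazy dask data
without any file I/O. A minimal sketch (assuming dask is installed):

import numpy as np
import xarray as xr

ds = xr.Dataset({'foo': ('x', np.arange(4))})
chunked = ds.chunk()
print(type(chunked['foo'].data))  # <class 'dask.array.core.Array'>
print(chunked['foo']._in_memory)  # False until load() or compute() is called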

def test_roundtrip_datetime_data(self):
# Override method in DatasetIOTestCases - remove not applicable save_kwds
times = pd.to_datetime(['2000-01-01', '2000-01-02', 'NaT'])
expected = Dataset({'t': ('t', times), 't0': times[0]})
with self.roundtrip(expected) as actual:
self.assertDatasetIdentical(expected, actual)

def test_write_store(self):
# Override method in DatasetIOTestCases - not applicable to dask
pass

def test_open_mfdataset(self):
original = Dataset({'foo': ('x', np.random.randn(10))})
with create_tmp_file() as tmp1:
@@ -992,7 +1017,16 @@ def test_deterministic_names(self):
self.assertIn(tmp, dask_name)
self.assertEqual(original_names, repeat_names)


def test_dataarray_compute(self):
Review comment (Member): These methods would belong better in test_dask.py

Reply (Contributor Author): I don't think we should move test_dataarray_compute() to a different module as there's test_dataset_compute() that applies to all backends.
Moving the pickle and values tests.

# Test DataArray.compute() on dask backend.
# The test for Dataset.compute() is already in DatasetIOTestCases;
# however dask is the only tested backend which supports DataArrays
actual = DataArray([1,2]).chunk()
computed = actual.compute()
self.assertFalse(actual._in_memory)
self.assertTrue(computed._in_memory)
self.assertDataArrayAllClose(actual, computed)

@requires_scipy_or_netCDF4
@requires_pydap
class PydapTest(TestCase):