Disable automatic cache with dask #1024

Merged
merged 16 commits into from
Nov 14, 2016
Changes from 11 commits
16 changes: 11 additions & 5 deletions doc/whats-new.rst
@@ -25,6 +25,13 @@ Breaking changes
merges will now succeed in cases that previously raised
``xarray.MergeError``. Set ``compat='broadcast_equals'`` to restore the
previous default.
- Pickling an xarray object based on the dask backend, or reading its
:py:meth:`values` property, won't automatically convert the array from dask
to numpy in the original object anymore.
If a dask object is used as a coord of a :py:class:`~xarray.DataArray` or
:py:class:`~xarray.Dataset`, its values won't be automatically cached, likely
causing performance degradation.
By `Guido Imperiale <https://github.com/crusaderky>`_.

Deprecations
~~~~~~~~~~~~
@@ -47,32 +54,31 @@ Enhancements
- Add checking of ``attr`` names and values when saving to netCDF, raising useful
error messages if they are invalid. (:issue:`911`).
By `Robin Wilson <https://github.com/robintw>`_.

- Added ability to save ``DataArray`` objects directly to netCDF files using
:py:meth:`~xarray.DataArray.to_netcdf`, and to load directly from netCDF files
using :py:func:`~xarray.open_dataarray` (:issue:`915`). These remove the need
to convert a ``DataArray`` to a ``Dataset`` before saving as a netCDF file,
and deals with names to ensure a perfect 'roundtrip' capability.
By `Robin Wilson <https://github.com/robintw>`_.

- Multi-index levels are now accessible as "virtual" coordinate variables,
e.g., ``ds['time']`` can pull out the ``'time'`` level of a multi-index
(see :ref:`coordinates`). ``sel`` also accepts providing multi-index levels
as keyword arguments, e.g., ``ds.sel(time='2000-01')``
(see :ref:`multi-level indexing`).
By `Benoit Bovy <https://github.com/benbovy>`_.

- Added the ``compat`` option ``'no_conflicts'`` to ``merge``, allowing the
combination of xarray objects with disjoint (:issue:`742`) or
overlapping (:issue:`835`) coordinates as long as all present data agrees.
By `Johnnie Gray <https://github.com/jcmgray>`_. See
:ref:`combining.no_conflicts` for more details.

- It is now possible to set ``concat_dim=None`` explicitly in
:py:func:`~xarray.open_mfdataset` to disable inferring a dimension along
which to concatenate.
By `Stephan Hoyer <https://github.com/shoyer>`_.

- Added methods :py:meth:`DataArray.compute`, :py:meth:`Dataset.compute`, and
:py:meth:`Variable.compute` as a non-mutating alternative to
:py:meth:`~DataArray.load`.
By `Guido Imperiale <https://github.com/crusaderky>`_.
- Adds DataArray and Dataset methods :py:meth:`~xarray.DataArray.cumsum` and
:py:meth:`~xarray.DataArray.cumprod`. By `Phillip J. Wolfram
<https://github.com/pwolfram>`_.
13 changes: 13 additions & 0 deletions xarray/core/dataarray.py
@@ -561,6 +561,19 @@ def load(self):
self._coords = new._coords
return self

def compute(self):
"""Manually trigger loading of this array's data from disk or a
remote source into memory and return a new array. The original is
left unaltered.

Normally, it should not be necessary to call this method in user code,
because all xarray functions should either work on deferred data or
load data automatically. However, this method can be necessary when
working with many file objects on disk.
"""
new = self.copy(deep=False)
return new.load()
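The relationship between `load()` and `compute()` described in the docstring above can be sketched with a toy model. This is only an illustration using the standard library: `LazyData` and `ToyArray` are hypothetical names, not xarray classes, and the real methods operate on variables and coordinates rather than plain lists.

```python
class LazyData:
    """Hypothetical stand-in for deferred data (on disk, remote, or dask)."""

    def __init__(self, values):
        self._values = list(values)

    def materialize(self):
        # Pretend this is the expensive read into memory.
        return list(self._values)


class ToyArray:
    """Toy container for illustration only; not xarray's DataArray."""

    def __init__(self, data):
        self.data = data

    @property
    def in_memory(self):
        return isinstance(self.data, list)

    def copy(self):
        # Shallow copy: a new wrapper around the same underlying data object.
        return ToyArray(self.data)

    def load(self):
        # Mutating: replace deferred data in place and return self.
        if not self.in_memory:
            self.data = self.data.materialize()
        return self

    def compute(self):
        # Non-mutating: load a shallow copy, leave the original deferred.
        return self.copy().load()


original = ToyArray(LazyData([1, 2, 3]))
computed = original.compute()
print(original.in_memory, computed.in_memory)  # False True

loaded = original.load()
print(loaded is original, original.in_memory)  # True True
```

The design choice mirrored here is that `compute()` is expressed entirely in terms of `copy(deep=False)` and `load()`, so it only stays non-mutating if the shallow copy produces a separate wrapper object.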

def copy(self, deep=True):
"""Returns a copy of this array.

43 changes: 33 additions & 10 deletions xarray/core/dataset.py
@@ -255,8 +255,11 @@ def load_store(cls, store, decoder=None):
return obj

def __getstate__(self):
"""Always load data in-memory before pickling"""
self.load()
"""Load data in-memory before pickling (except for Dask data)"""
for v in self.variables.values():
if not isinstance(v.data, dask_array_type):
v.load()

# self.__dict__ is the default pickle object, we don't need to
# implement our own __setstate__ method to make pickle work
state = self.__dict__.copy()
@@ -319,6 +322,19 @@ def load(self):

return self

def compute(self):
"""Manually trigger loading of this dataset's data from disk or a
remote source into memory and return a new dataset. The original is
left unaltered.

Normally, it should not be necessary to call this method in user code,
because all xarray functions should either work on deferred data or
load data automatically. However, this method can be necessary when
working with many file objects on disk.
"""
new = self.copy(deep=False)
return new.load()

@classmethod
def _construct_direct(cls, variables, coord_names, dims=None, attrs=None,
file_obj=None):
@@ -401,14 +417,12 @@ def copy(self, deep=False):
"""Returns a copy of this dataset.

If `deep=True`, a deep copy is made of each of the component variables.
Otherwise, a shallow copy is made, so each variable in the new dataset
is also a variable in the original dataset.
Otherwise, a shallow copy of each of the component variable is made, so
that the underlying memory region of the new dataset is the same as in
the original dataset.
"""
if deep:
variables = OrderedDict((k, v.copy(deep=True))
for k, v in iteritems(self._variables))
else:
variables = self._variables.copy()
variables = OrderedDict((k, v.copy(deep=deep))
for k, v in iteritems(self._variables))
# skip __init__ to avoid costly validation
return self._construct_direct(variables, self._coord_names.copy(),
self._dims.copy(), self._attrs_copy())
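Why the shallow copy now copies each variable, instead of only the mapping, can be sketched as follows. This is a toy model under stated assumptions: `ToyVariable`, `copy_sharing_variables` and `copy_each_variable` are hypothetical names, and a plain `dict` stands in for the dataset's variable mapping.

```python
class ToyVariable:
    """Toy wrapper for illustration only; not xarray's Variable."""

    def __init__(self, data, loaded=False):
        self.data = data
        self.loaded = loaded

    def copy(self):
        # New wrapper object, same underlying data (a shallow copy).
        return ToyVariable(self.data, self.loaded)

    def load(self):
        self.loaded = True
        return self


def copy_sharing_variables(variables):
    # Copies only the mapping, so both mappings hold the same wrapper objects.
    return dict(variables)


def copy_each_variable(variables):
    # Also shallow-copies every wrapper, as the new code in the diff does.
    return {k: v.copy() for k, v in variables.items()}


original = {"x": ToyVariable([1, 2])}
shared = copy_sharing_variables(original)
shared["x"].load()
print(original["x"].loaded)  # True: loading the copy also changed the original

original = {"x": ToyVariable([1, 2])}
separate = copy_each_variable(original)
separate["x"].load()
print(original["x"].loaded)  # False: the original is untouched
print(separate["x"].data is original["x"].data)  # True: data is still shared
```

With only the mapping copied, a `compute()` built on `copy(deep=False).load()` would silently load the original as well, which is the behaviour the per-variable copy avoids.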
@@ -792,13 +806,19 @@ def chunks(self):
array.
"""
chunks = {}
for v in self.variables.values():
for v in self.data_vars.values():
Member

@shoyer Oct 12, 2016

I am concerned about skipping non-data_vars here. Coordinates could still be chunked, e.g., if they were loaded from a file, or created directly from dask arrays.

if v.chunks is not None:
new_chunks = list(zip(v.dims, v.chunks))
if any(chunk != chunks[d] for d, chunk in new_chunks
if d in chunks):
raise ValueError('inconsistent chunks')
chunks.update(new_chunks)
if chunks:
Member

@shoyer Oct 3, 2016

Why should this need chunks to not be empty already? That seems strange (maybe backwards) to me.

I might simply make this:

for dim, size in self.dims.items():
    if dim not in chunks:
        chunks[dim] = (size,)

Contributor Author

If none of the data_vars use the dask backend, then you want chunks to return None.

Member

I guess this method is inconsistent with Variable.chunks, but it currently always returns a dict.

I would either skip this change or use something like my version.

# Add dims that are defined in the coords but are not in data_vars
for v in self.coords.values():
for dim in v.dims:
if dim not in chunks:
chunks[dim] = (v.size,)
return Frozen(SortedKeysDict(chunks))
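The chunk-collection logic under discussion can be sketched as a standalone function. This is a minimal sketch, not the code in the diff: `collect_chunks` is a hypothetical helper, variables are modelled as `(dims, chunks)` pairs, and the fill-in step follows the loop suggested in the review (one full-size chunk per remaining dimension) rather than the `if chunks:` guard.

```python
def collect_chunks(var_chunks, dim_sizes):
    """Merge per-variable chunks into one mapping from dimension to chunks.

    var_chunks: iterable of (dims, chunks) pairs; chunks is None for a
        variable that is not chunked, else one tuple of sizes per dimension.
    dim_sizes: mapping from dimension name to its total size.
    """
    chunks = {}
    for dims, var_chunk in var_chunks:
        if var_chunk is None:
            continue
        for dim, chunk in zip(dims, var_chunk):
            # Two variables sharing a dimension must agree on its chunks.
            if dim in chunks and chunks[dim] != chunk:
                raise ValueError('inconsistent chunks')
            chunks[dim] = chunk
    # Reviewer's suggestion: unchunked dimensions become a single chunk.
    for dim, size in dim_sizes.items():
        if dim not in chunks:
            chunks[dim] = (size,)
    return chunks


result = collect_chunks(
    [(("x", "y"), ((2, 2), (3,))), (("x",), ((2, 2),)), (("z",), None)],
    {"x": 4, "y": 3, "z": 5},
)
print(result)  # {'x': (2, 2), 'y': (3,), 'z': (5,)}
```

Note the consequence debated in the thread: with this variant, a dataset without any chunked variable yields one full-size chunk per dimension instead of an empty mapping.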

def chunk(self, chunks=None, name_prefix='xarray-', token=None,
@@ -851,6 +871,9 @@ def selkeys(dict_, keys):
return dict((d, dict_[d]) for d in keys if d in dict_)

def maybe_chunk(name, var, chunks):
if name not in self.data_vars:
Member

I see your point about performance, but I think that mostly holds true for indexes. So I would be inclined to adjust this to only skip variables in self.dims (aka indexes used for alignment).

I am still concerned about skipping coords if they are already dask arrays. If they are already dask arrays, then .chunk() should probably adjust their chunks anyways.

return var

chunks = selkeys(chunks, var.dims)
if not chunks:
chunks = None
46 changes: 36 additions & 10 deletions xarray/core/variable.py
@@ -274,10 +274,21 @@ def data(self, data):
"replacement data must match the Variable's shape")
self._data = data

def _data_cast(self):
if isinstance(self._data, (np.ndarray, PandasIndexAdapter)):

Should this branch not also apply to dask_array_type?

In fact, if you manually create a Variable with a dask array you'll get a LazilyIndexedArray at this point. Should this not also be kept unchanged?

return self._data
else:
return np.asarray(self._data)

def _data_cached(self):
if not isinstance(self._data, (np.ndarray, PandasIndexAdapter)):
self._data = np.asarray(self._data)
return self._data
"""Load data into memory and return it.
Do not cache dask arrays automatically; that should
require an explicit load() call.
"""
new_data = self._data_cast()
if not isinstance(self._data, dask_array_type):
self._data = new_data
return new_data
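The caching policy introduced by `_data_cast` and `_data_cached` can be illustrated with a toy model. Everything here is an assumption made for the sketch: `FakeFileArray` and `FakeDaskArray` are hypothetical stand-ins, a plain `list` plays the role of an in-memory array, and the `values` property is simplified.

```python
class _FakeLazy:
    """Hypothetical deferred array that counts how often it is evaluated."""

    def __init__(self, values):
        self._values = list(values)
        self.evaluations = 0

    def to_list(self):
        self.evaluations += 1
        return list(self._values)


class FakeFileArray(_FakeLazy):
    """Stand-in for lazily read on-disk data."""


class FakeDaskArray(_FakeLazy):
    """Stand-in for a dask array."""


class ToyVariable:
    """Toy wrapper for illustration only; not xarray's Variable."""

    def __init__(self, data):
        self._data = data

    def _data_cast(self):
        # Convert to in-memory data without touching self._data.
        if isinstance(self._data, list):
            return self._data
        return self._data.to_list()

    def _data_cached(self):
        # Cache the converted data, except when it is dask-like.
        new_data = self._data_cast()
        if not isinstance(self._data, FakeDaskArray):
            self._data = new_data
        return new_data

    @property
    def values(self):
        return self._data_cached()

    def load(self):
        # Explicit load always replaces the data, dask-like or not.
        self._data = self._data_cast()
        return self


dask_source = FakeDaskArray([1, 2])
dask_var = ToyVariable(dask_source)
dask_var.values
dask_var.values
print(dask_source.evaluations)  # 2: evaluated on every access, never cached

file_source = FakeFileArray([3, 4])
file_var = ToyVariable(file_source)
file_var.values
file_var.values
print(file_source.evaluations)  # 1: cached after the first access

dask_var.load()
print(isinstance(dask_var._data, list))  # True: explicit load converts it
```

The repeated evaluation in the dask case corresponds to the performance cost that the release note warns about for dask-backed coordinates.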

@property
def _indexable_data(self):
@@ -291,12 +302,26 @@ def load(self):
because all xarray functions should either work on deferred data or
load data automatically.
"""
self._data_cached()
self._data = self._data_cast()
return self

def compute(self):
"""Manually trigger loading of this variable's data from disk or a
remote source into memory and return a new variable. The original is
left unaltered.

Normally, it should not be necessary to call this method in user code,
because all xarray functions should either work on deferred data or
load data automatically.
"""
new = self.copy(deep=False)
return new.load()

def __getstate__(self):
"""Always cache data as an in-memory array before pickling"""
self._data_cached()
"""Always cache data as an in-memory array before pickling
(with the exception of dask backend)"""
if not isinstance(self._data, dask_array_type):
self._data_cached()
# self.__dict__ is the default pickle object, we don't need to
# implement our own __setstate__ method to make pickle work
return self.__dict__
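The pickling rule in `__getstate__` above can be sketched in isolation. This is a toy model with hypothetical classes (`FakeFileArray`, `FakeDaskArray`, `ToyVariable`); since `pickle` obtains an object's state through `__getstate__`, the sketch calls that method directly and inspects the state it would hand over.

```python
class FakeFileArray:
    """Hypothetical stand-in for lazily read on-disk data."""

    def __init__(self, values):
        self._values = list(values)

    def read(self):
        return list(self._values)


class FakeDaskArray:
    """Hypothetical stand-in for a dask array, assumed picklable as is."""

    def __init__(self, values):
        self.values = list(values)


class ToyVariable:
    """Toy wrapper for illustration only; not xarray's Variable."""

    def __init__(self, data):
        self._data = data

    def __getstate__(self):
        # Load file-backed data into memory first; leave dask-like data lazy.
        if isinstance(self._data, FakeFileArray):
            self._data = self._data.read()
        # The instance dict is the state handed to pickle.
        return self.__dict__


file_state = ToyVariable(FakeFileArray([1, 2])).__getstate__()
dask_state = ToyVariable(FakeDaskArray([3, 4])).__getstate__()
print(type(file_state["_data"]).__name__)  # list
print(type(dask_state["_data"]).__name__)  # FakeDaskArray
```

Keeping the dask-like object unevaluated means the pickled payload carries the lazy description of the data rather than its computed contents.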
@@ -1102,10 +1127,11 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
raise ValueError('%s objects must be 1-dimensional' %
type(self).__name__)

def _data_cached(self):
if not isinstance(self._data, PandasIndexAdapter):
self._data = PandasIndexAdapter(self._data)
return self._data
def _data_cast(self):
if isinstance(self._data, PandasIndexAdapter):
return self._data
else:
return PandasIndexAdapter(self._data)

def __getitem__(self, key):
key = self._item_key_to_tuple(key)
62 changes: 48 additions & 14 deletions xarray/test/test_backends.py
@@ -125,7 +125,7 @@ def assert_loads(vars=None):
if vars is None:
vars = expected
with self.roundtrip(expected) as actual:
for v in actual.variables.values():
for v in actual.data_vars.values():
self.assertFalse(v._in_memory)
yield actual
for k, v in actual.variables.items():
@@ -149,6 +149,24 @@ def assert_loads(vars=None):
actual = ds.load()
self.assertDatasetAllClose(expected, actual)

def test_dataset_compute(self):
expected = create_test_data()

with self.roundtrip(expected) as actual:
# Test Dataset.compute()
for v in actual.data_vars.values():
self.assertFalse(v._in_memory)

computed = actual.compute()

for v in actual.data_vars.values():
self.assertFalse(v._in_memory)
for v in computed.variables.values():
self.assertTrue(v._in_memory)

self.assertDatasetAllClose(expected, actual)
self.assertDatasetAllClose(expected, computed)

def test_roundtrip_None_variable(self):
expected = Dataset({None: (('x', 'y'), [[0, 1], [2, 3]])})
with self.roundtrip(expected) as actual:
@@ -230,18 +248,6 @@ def test_roundtrip_coordinates(self):
with self.roundtrip(expected) as actual:
self.assertDatasetIdentical(expected, actual)

expected = original.copy()
expected.attrs['coordinates'] = 'something random'
with self.assertRaisesRegexp(ValueError, 'cannot serialize'):
with self.roundtrip(expected):
pass

expected = original.copy(deep=True)
expected['foo'].attrs['coordinates'] = 'something random'
with self.assertRaisesRegexp(ValueError, 'cannot serialize'):
with self.roundtrip(expected):
pass

def test_roundtrip_boolean_dtype(self):
original = create_boolean_data()
self.assertEqual(original['x'].dtype, 'bool')
@@ -872,7 +878,26 @@ def test_read_byte_attrs_as_unicode(self):
@requires_dask
@requires_scipy
@requires_netCDF4
class DaskTest(TestCase):
class DaskTest(TestCase, DatasetIOTestCases):
@contextlib.contextmanager
def create_store(self):
yield Dataset()

@contextlib.contextmanager
def roundtrip(self, data, save_kwargs={}, open_kwargs={}):
yield data.chunk()

def test_roundtrip_datetime_data(self):
# Override method in DatasetIOTestCases - remove not applicable save_kwds
times = pd.to_datetime(['2000-01-01', '2000-01-02', 'NaT'])
expected = Dataset({'t': ('t', times), 't0': times[0]})
with self.roundtrip(expected) as actual:
self.assertDatasetIdentical(expected, actual)

def test_write_store(self):
# Override method in DatasetIOTestCases - not applicable to dask
pass

def test_open_mfdataset(self):
original = Dataset({'foo': ('x', np.random.randn(10))})
with create_tmp_file() as tmp1:
@@ -992,6 +1017,15 @@ def test_deterministic_names(self):
self.assertIn(tmp, dask_name)
self.assertEqual(original_names, repeat_names)

def test_dataarray_compute(self):
# Test DataArray.compute() on dask backend.
# The test for Dataset.compute() is already in DatasetIOTestCases;
# however dask is the only tested backend which supports DataArrays
actual = DataArray([1,2]).chunk()
computed = actual.compute()
self.assertFalse(actual._in_memory)
self.assertTrue(computed._in_memory)
self.assertDataArrayAllClose(actual, computed)

@requires_scipy_or_netCDF4
@requires_pydap