Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Dask interface #1674

Merged
merged 15 commits into from
Nov 7, 2017
39 changes: 37 additions & 2 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,43 @@ def compute(self, **kwargs):
--------
dask.array.compute
"""
new = self.copy(deep=False)
return new.load(**kwargs)
import dask
(result,) = dask.compute(self, **kwargs)
return result

def visualize(self, **kwargs):
    """Forward this object to ``dask.visualize``.

    Every keyword argument is passed through to ``dask.visualize``
    unchanged, and whatever that call returns is returned as-is.
    """
    # dask is imported lazily, inside the method, exactly as before.
    from dask import visualize as dask_visualize

    return dask_visualize(self, **kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My inclination would be to leave this out and require using dask.visualize(). My concern is that it could be easily confused with .plot().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed


def __dask_graph__(self):
    """Return the dask graph of the wrapped data, if there is one.

    Returns ``None`` when ``self._data`` is not an instance of
    ``dask_array_type``; otherwise the result of the data's own
    ``__dask_graph__()`` is returned.
    """
    wrapped = self._data
    if not isinstance(wrapped, dask_array_type):
        return None
    return wrapped.__dask_graph__()

def __dask_keys__(self):
    """Return the dask keys reported by the wrapped data.

    This delegates directly to ``self._data.__dask_keys__()`` without any
    type check, so an ``AttributeError`` propagates if the wrapped data does
    not provide that method.
    """
    wrapped = self._data
    return wrapped.__dask_keys__()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it OK if these methods raise an AttributeError when self._data is not a dask array?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we always check whether the object is a dask collection first by calling __dask_graph__.


@property
def __dask_optimize__(self):
    """Expose the ``__dask_optimize__`` attribute of the wrapped data.

    The attribute is looked up on ``self._data`` and returned unchanged.
    """
    wrapped = self._data
    return wrapped.__dask_optimize__

@property
def __dask_scheduler__(self):
    """Expose the ``__dask_scheduler__`` attribute of the wrapped data.

    The attribute is looked up on ``self._data`` and returned unchanged.
    """
    wrapped = self._data
    return wrapped.__dask_scheduler__

def __dask_postcompute__(self):
    """Return the finalizer and its extra arguments for a computed result.

    The wrapped data's own ``__dask_postcompute__()`` supplies a callable
    and its arguments. Both are bundled with this object's dims, attrs and
    encoding so that ``_dask_finalize`` can rebuild a variable around the
    finalized array.

    Returns
    -------
    tuple
        ``(self._dask_finalize, extra_args)`` where ``extra_args`` is
        ``(array_func, array_args, dims, attrs, encoding)``.
    """
    finalize_array, finalize_args = self._data.__dask_postcompute__()
    extra_args = (
        finalize_array,
        finalize_args,
        self._dims,
        self._attrs,
        self._encoding,
    )
    return self._dask_finalize, extra_args

def __dask_postpersist__(self):
    """Return the finalizer and its extra arguments for a persisted result.

    The wrapped data's own ``__dask_postpersist__()`` supplies a callable
    and its arguments. Both are bundled with this object's dims, attrs and
    encoding so that ``_dask_finalize`` can rebuild a variable around the
    rebuilt array.

    Returns
    -------
    tuple
        ``(self._dask_finalize, extra_args)`` where ``extra_args`` is
        ``(array_func, array_args, dims, attrs, encoding)``.
    """
    rebuild_array, rebuild_args = self._data.__dask_postpersist__()
    extra_args = (
        rebuild_array,
        rebuild_args,
        self._dims,
        self._attrs,
        self._encoding,
    )
    return self._dask_finalize, extra_args

@staticmethod
def _dask_finalize(results, array_func, array_args, dims, attrs, encoding):
    """Rebuild a ``Variable`` from dask results.

    Parameters
    ----------
    results
        Passed as the first argument to ``array_func``.
    array_func, array_args
        Callable and extra positional arguments used to turn ``results``
        into the variable's data.
    dims, attrs, encoding
        Forwarded to the ``Variable`` constructor.

    Returns
    -------
    Variable
        A new variable wrapping ``array_func(results, *array_args)``.
    """
    rebuilt_data = array_func(results, *array_args)
    return Variable(dims, rebuilt_data, attrs=attrs, encoding=encoding)

@property
def values(self):
Expand Down
24 changes: 24 additions & 0 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,30 @@ def test_bivariate_ufunc(self):
self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(v, 0))
self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(0, v))

def test_compute(self):
    """``dask.compute`` on a lazy variable yields a non-dask variable with equal data."""
    eager = self.eager_var
    lazy = self.lazy_var

    # The lazy fixture must be recognised as a dask collection up front.
    assert dask.is_dask_collection(lazy)

    [computed] = dask.compute(lazy + 1)
    assert not dask.is_dask_collection(computed)

    # Compare against the same arithmetic done on the eager variable.
    expected = eager + 1
    assert (expected.data == computed.data).all()

def test_persist(self):
    """``dask.persist`` returns a new, still-lazy variable with a smaller graph."""
    eager = self.eager_var
    lazy = self.lazy_var + 1

    [persisted] = dask.persist(lazy)

    # Persisting must hand back a distinct object with the same keys
    # but strictly fewer graph entries.
    assert lazy is not persisted
    assert len(persisted.__dask_graph__()) < len(lazy.__dask_graph__())
    assert persisted.__dask_keys__() == lazy.__dask_keys__()

    # Both the input and the persisted result remain dask collections.
    assert dask.is_dask_collection(lazy)
    assert dask.is_dask_collection(persisted)

    self.assertLazyAndAllClose(eager + 1, lazy)
    self.assertLazyAndAllClose(eager + 1, persisted)


class TestDataArrayAndDataset(DaskTestCase):
def assertLazyAndIdentical(self, expected, actual):
Expand Down