Support Dask interface (#1674)
* add dask interface to variable

* redirect compute and visualize methods to dask

* add dask interface to DataArray

* add dask interface to Dataset

Also test distributed computing

* remove visualize method

* support backwards compatibility

* cleanup

* style edits

* change versions in tests to trigger on dask dev versions

* support dask arrays in DataArray coordinates

* remove commented assertion

* whats new

* elaborate on what's new
mrocklin authored and shoyer committed Nov 7, 2017
1 parent 2a1d392 commit 10495be
Showing 6 changed files with 290 additions and 23 deletions.
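
For orientation, a minimal sketch of what the new collection interface enables: xarray objects can be handed directly to dask's top-level functions. (The variable names below are illustrative, not taken from this commit.)

import dask
import numpy as np
import xarray as xr

ds = xr.Dataset({'foo': ('x', np.arange(8))}).chunk({'x': 4})

assert dask.is_dask_collection(ds)       # Dataset now implements the dask protocol
(computed,) = dask.compute(ds + 1)       # evaluate the graph, get an in-memory Dataset
assert not dask.is_dask_collection(computed)

(persisted,) = dask.persist(ds + 1)      # evaluate the graph, keep the result dask-backed
assert dask.is_dask_collection(persisted)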
7 changes: 7 additions & 0 deletions doc/whats-new.rst
@@ -15,9 +15,16 @@ What's New
.. _whats-new.0.10.0:


v0.10.0 (unreleased)
--------------------

Changes since v0.10.0 rc1 (Unreleased)
--------------------------------------

- Experimental support for the Dask collection interface (:issue:`1674`).
  By `Matthew Rocklin <https://github.com/mrocklin>`_.

Bug fixes
~~~~~~~~~

29 changes: 29 additions & 0 deletions xarray/core/dataarray.py
@@ -576,6 +576,35 @@ def reset_coords(self, names=None, drop=False, inplace=False):
        dataset[self.name] = self.variable
        return dataset

    def __dask_graph__(self):
        return self._to_temp_dataset().__dask_graph__()

    def __dask_keys__(self):
        return self._to_temp_dataset().__dask_keys__()

    @property
    def __dask_optimize__(self):
        return self._to_temp_dataset().__dask_optimize__

    @property
    def __dask_scheduler__(self):
        return self._to_temp_dataset().__dask_scheduler__

    def __dask_postcompute__(self):
        func, args = self._to_temp_dataset().__dask_postcompute__()
        return self._dask_finalize, (func, args, self.name)

    def __dask_postpersist__(self):
        func, args = self._to_temp_dataset().__dask_postpersist__()
        return self._dask_finalize, (func, args, self.name)

    @staticmethod
    def _dask_finalize(results, func, args, name):
        ds = func(results, *args)
        variable = ds._variables.pop(_THIS_ARRAY)
        coords = ds._variables
        return DataArray(variable, coords, name=name, fastpath=True)

    def load(self, **kwargs):
        """Manually trigger loading of this array's data from disk or a
        remote source into memory and return this array.
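
The DataArray implementation above simply round-trips through a temporary single-variable Dataset, and `_dask_finalize` pops the array back out. A hedged sketch of how dask consumes these hooks — run the graph, then rebuild the collection via the `(func, args)` pair from `__dask_postcompute__`. The names `arr`, `finalize`, and `extra` are my own, and `dask.threaded.get` stands in for whichever scheduler dask would actually pick:

import dask.threaded
import numpy as np
import xarray as xr

arr = xr.DataArray(np.arange(6).reshape(2, 3), dims=('x', 'y')).chunk({'x': 1})

# Roughly what dask.compute(arr) does internally:
results = dask.threaded.get(dict(arr.__dask_graph__()), arr.__dask_keys__())
finalize, extra = arr.__dask_postcompute__()
rebuilt = finalize(results, *extra)      # an in-memory DataArray

assert isinstance(rebuilt.data, np.ndarray)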
71 changes: 71 additions & 0 deletions xarray/core/dataset.py
@@ -493,6 +493,77 @@ def load(self, **kwargs):

        return self

    def __dask_graph__(self):
        graphs = {k: v.__dask_graph__() for k, v in self.variables.items()}
        graphs = {k: v for k, v in graphs.items() if v is not None}
        if not graphs:
            return None
        else:
            from dask import sharedict
            return sharedict.merge(*graphs.values())

    def __dask_keys__(self):
        import dask
        return [v.__dask_keys__() for v in self.variables.values()
                if dask.is_dask_collection(v)]

    @property
    def __dask_optimize__(self):
        import dask.array as da
        return da.Array.__dask_optimize__

    @property
    def __dask_scheduler__(self):
        import dask.array as da
        return da.Array.__dask_scheduler__

    def __dask_postcompute__(self):
        import dask
        info = [(True, k, v.__dask_postcompute__())
                if dask.is_dask_collection(v) else
                (False, k, v) for k, v in self._variables.items()]
        return self._dask_postcompute, (info, self._coord_names, self._dims,
                                        self._attrs, self._file_obj,
                                        self._encoding)

    def __dask_postpersist__(self):
        import dask
        info = [(True, k, v.__dask_postpersist__())
                if dask.is_dask_collection(v) else
                (False, k, v) for k, v in self._variables.items()]
        return self._dask_postpersist, (info, self._coord_names, self._dims,
                                        self._attrs, self._file_obj,
                                        self._encoding)

    @staticmethod
    def _dask_postcompute(results, info, *args):
        variables = OrderedDict()
        # results arrive in __dask_keys__ order; reverse the list so that
        # pop() yields them front-to-back as we walk the dask entries in info
        results2 = list(results[::-1])
        for is_dask, k, v in info:
            if is_dask:
                func, args2 = v
                r = results2.pop()
                result = func(r, *args2)
            else:
                result = v
            variables[k] = result

        final = Dataset._construct_direct(variables, *args)
        return final

    @staticmethod
    def _dask_postpersist(dsk, info, *args):
        variables = OrderedDict()
        for is_dask, k, v in info:
            if is_dask:
                func, args2 = v
                result = func(dsk, *args2)
            else:
                result = v
            variables[k] = result

        return Dataset._construct_direct(variables, *args)

    def compute(self, **kwargs):
        """Manually trigger loading of this dataset's data from disk or a
        remote source into memory and return a new dataset. The original is
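
Note how the Dataset hooks above split the work: `__dask_graph__` merges one graph per dask-backed variable, while the `info` list records, for every variable, either a rebuild recipe or the eager variable itself, so `_dask_postcompute` can stitch both kinds back together. A small sketch of that behavior (names are illustrative):

import dask
import dask.array as da
import numpy as np
import xarray as xr

ds = xr.Dataset({'lazy': ('x', da.arange(4, chunks=2)),
                 'eager': ('x', np.ones(4))})

assert len(ds.__dask_keys__()) == 1   # only the dask-backed variable contributes keys
(out,) = dask.compute(ds)             # 'eager' passes through _dask_postcompute unchanged
assert not dask.is_dask_collection(out)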
35 changes: 35 additions & 0 deletions xarray/core/variable.py
@@ -355,6 +355,41 @@ def compute(self, **kwargs):
        new = self.copy(deep=False)
        return new.load(**kwargs)

    def __dask_graph__(self):
        if isinstance(self._data, dask_array_type):
            return self._data.__dask_graph__()
        else:
            return None

    def __dask_keys__(self):
        return self._data.__dask_keys__()

    @property
    def __dask_optimize__(self):
        return self._data.__dask_optimize__

    @property
    def __dask_scheduler__(self):
        return self._data.__dask_scheduler__

    def __dask_postcompute__(self):
        array_func, array_args = self._data.__dask_postcompute__()
        return self._dask_finalize, (array_func, array_args, self._dims,
                                     self._attrs, self._encoding)

    def __dask_postpersist__(self):
        array_func, array_args = self._data.__dask_postpersist__()
        return self._dask_finalize, (array_func, array_args, self._dims,
                                     self._attrs, self._encoding)

    @staticmethod
    def _dask_finalize(results, array_func, array_args, dims, attrs, encoding):
        if isinstance(results, dict):  # persist case
            name = array_args[0]
            results = {k: v for k, v in results.items() if k[0] == name}  # cull
        data = array_func(results, *array_args)
        return Variable(dims, data, attrs=attrs, encoding=encoding)

    @property
    def values(self):
        """The variable's data as a numpy.ndarray"""
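
One subtlety in `Variable._dask_finalize` above: after `dask.persist`, each variable is handed the entire merged graph back (a dict), so it culls that graph down to its own keys before rebuilding; after `dask.compute`, it instead receives a nested list of computed blocks. A brief sketch at the Variable level (names are illustrative):

import dask
import dask.array as da
import xarray as xr

v = xr.Variable(('x',), da.arange(6, chunks=3)) + 1

(v2,) = dask.compute(v)   # results are concrete blocks -> numpy-backed Variable
assert not dask.is_dask_collection(v2)

(v3,) = dask.persist(v)   # results are a graph -> culled, still dask-backed
assert dask.is_dask_collection(v3)
assert len(v3.__dask_graph__()) <= len(v.__dask_graph__())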
137 changes: 115 additions & 22 deletions xarray/tests/test_dask.py
@@ -206,6 +206,34 @@ def test_bivariate_ufunc(self):
        self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(v, 0))
        self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(0, v))

    @pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
                        reason='Need dask 0.16 for new interface')
    def test_compute(self):
        u = self.eager_var
        v = self.lazy_var

        assert dask.is_dask_collection(v)
        (v2,) = dask.compute(v + 1)
        assert not dask.is_dask_collection(v2)

        assert ((u + 1).data == v2.data).all()

    @pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
                        reason='Need dask 0.16 for new interface')
    def test_persist(self):
        u = self.eager_var
        v = self.lazy_var + 1

        (v2,) = dask.persist(v)
        assert v is not v2
        assert len(v2.__dask_graph__()) < len(v.__dask_graph__())
        assert v2.__dask_keys__() == v.__dask_keys__()
        assert dask.is_dask_collection(v)
        assert dask.is_dask_collection(v2)

        self.assertLazyAndAllClose(u + 1, v)
        self.assertLazyAndAllClose(u + 1, v2)


class TestDataArrayAndDataset(DaskTestCase):
    def assertLazyAndIdentical(self, expected, actual):
@@ -251,6 +279,34 @@ def test_lazy_array(self):
        actual = xr.concat([v[:2], v[2:]], 'x')
        self.assertLazyAndAllClose(u, actual)

    @pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
                        reason='Need dask 0.16 for new interface')
    def test_compute(self):
        u = self.eager_array
        v = self.lazy_array

        assert dask.is_dask_collection(v)
        (v2,) = dask.compute(v + 1)
        assert not dask.is_dask_collection(v2)

        assert ((u + 1).data == v2.data).all()

    @pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
                        reason='Need dask 0.16 for new interface')
    def test_persist(self):
        u = self.eager_array
        v = self.lazy_array + 1

        (v2,) = dask.persist(v)
        assert v is not v2
        assert len(v2.__dask_graph__()) < len(v.__dask_graph__())
        assert v2.__dask_keys__() == v.__dask_keys__()
        assert dask.is_dask_collection(v)
        assert dask.is_dask_collection(v2)

        self.assertLazyAndAllClose(u + 1, v)
        self.assertLazyAndAllClose(u + 1, v2)

    def test_concat_loads_variables(self):
        # Test that concat() computes not-in-memory variables at most once
        # and loads them in the output, while leaving the input unaltered.
@@ -402,28 +458,6 @@ def counting_get(*args, **kwargs):
        ds.load()
        self.assertEqual(count[0], 1)

    def test_persist_Dataset(self):
        ds = Dataset({'foo': ('x', range(5)),
                      'bar': ('x', range(5))}).chunk()
        ds = ds + 1
        n = len(ds.foo.data.dask)

        ds2 = ds.persist()

        assert len(ds2.foo.data.dask) == 1
        assert len(ds.foo.data.dask) == n  # doesn't mutate in place

    def test_persist_DataArray(self):
        x = da.arange(10, chunks=(5,))
        y = DataArray(x)
        z = y + 1
        n = len(z.data.dask)

        zz = z.persist()

        assert len(z.data.dask) == n
        assert len(zz.data.dask) == zz.data.npartitions

    def test_stack(self):
        data = da.random.normal(size=(2, 3, 4), chunks=(1, 3, 4))
        arr = DataArray(data, dims=('w', 'x', 'y'))
@@ -737,3 +771,62 @@ def build_dask_array(name):
    return dask.array.Array(
        dask={(name, 0): (kernel, name)}, name=name,
        chunks=((1,),), dtype=np.int64)


# Test both the .persist() method and the dask.persist function;
# the dask.persist function requires a newer version of dask.
@pytest.mark.parametrize('persist', [
    lambda x: x.persist(),
    pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
                       lambda x: dask.persist(x)[0],
                       reason='Need Dask 0.16')
])
def test_persist_Dataset(persist):
    ds = Dataset({'foo': ('x', range(5)),
                  'bar': ('x', range(5))}).chunk()
    ds = ds + 1
    n = len(ds.foo.data.dask)

    ds2 = persist(ds)

    assert len(ds2.foo.data.dask) == 1
    assert len(ds.foo.data.dask) == n  # doesn't mutate in place


@pytest.mark.parametrize('persist', [
    lambda x: x.persist(),
    pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
                       lambda x: dask.persist(x)[0],
                       reason='Need Dask 0.16')
])
def test_persist_DataArray(persist):
    x = da.arange(10, chunks=(5,))
    y = DataArray(x)
    z = y + 1
    n = len(z.data.dask)

    zz = persist(z)

    assert len(z.data.dask) == n
    assert len(zz.data.dask) == zz.data.npartitions


@pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
                    reason='Need dask 0.16 for new interface')
def test_dataarray_with_dask_coords():
    import toolz
    x = xr.Variable('x', da.arange(8, chunks=(4,)))
    y = xr.Variable('y', da.arange(8, chunks=(4,)) * 2)
    data = da.random.random((8, 8), chunks=(4, 4)) + 1
    array = xr.DataArray(data, dims=['x', 'y'])
    array.coords['xx'] = x
    array.coords['yy'] = y

    assert dict(array.__dask_graph__()) == toolz.merge(data.__dask_graph__(),
                                                       x.__dask_graph__(),
                                                       y.__dask_graph__())

    (array2,) = dask.compute(array)
    assert not dask.is_dask_collection(array2)

    assert all(isinstance(v._variable.data, np.ndarray)
               for v in array2.coords.values())
34 changes: 33 additions & 1 deletion xarray/tests/test_distributed.py
@@ -4,7 +4,9 @@

distributed = pytest.importorskip('distributed')
da = pytest.importorskip('dask.array')
from distributed.utils_test import cluster, loop
import dask
from distributed.utils_test import cluster, loop, gen_cluster
from distributed.client import futures_of, wait

from xarray.tests.test_backends import create_tmp_file, ON_WINDOWS
from xarray.tests.test_dataset import create_test_data
@@ -32,3 +34,33 @@ def test_dask_distributed_integration_test(loop, engine):
            assert isinstance(restored.var1.data, da.Array)
            computed = restored.compute()
            assert_allclose(original, computed)


@pytest.mark.skipif(distributed.__version__ <= '1.19.3',
                    reason='Need recent distributed version to clean up get')
@gen_cluster(client=True, timeout=None)
def test_async(c, s, a, b):
    x = create_test_data()
    assert not dask.is_dask_collection(x)
    y = x.chunk({'dim2': 4}) + 10
    assert dask.is_dask_collection(y)
    assert dask.is_dask_collection(y.var1)
    assert dask.is_dask_collection(y.var2)

    z = y.persist()
    assert str(z)

    assert dask.is_dask_collection(z)
    assert dask.is_dask_collection(z.var1)
    assert dask.is_dask_collection(z.var2)
    assert len(y.__dask_graph__()) > len(z.__dask_graph__())

    assert not futures_of(y)
    assert futures_of(z)

    future = c.compute(z)
    w = yield future
    assert not dask.is_dask_collection(w)
    assert_allclose(x + 10, w)

    assert s.task_state

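For readers unfamiliar with `gen_cluster`: it runs the coroutine-style test above against an in-process scheduler (`s`), workers (`a`, `b`), and client (`c`). A rough synchronous equivalent, as a hypothetical standalone script rather than anything from this commit:

import numpy as np
import xarray as xr
from distributed import Client
from distributed.client import futures_of

client = Client()                     # spins up a local cluster
ds = xr.Dataset({'a': ('x', np.arange(100))}).chunk({'x': 10}) + 10
ds = ds.persist()                     # chunks now live on the workers as futures
assert futures_of(ds)                 # the graph points at remote data
print(ds.compute())                   # gather the result back locally
client.close()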