From 10495be789d7362adb68a2133b440ffc6ca25e6e Mon Sep 17 00:00:00 2001
From: Matthew Rocklin
Date: Tue, 7 Nov 2017 13:31:44 -0500
Subject: [PATCH] Support Dask interface (#1674)

* add dask interface to variable

* redirect compute and visualize methods to dask

* add dask interface to DataArray

* add dask interface to Dataset

Also test distributed computing

* remove visualize method

* support backwards compatibility

* cleanup

* style edits

* change versions in tests to trigger on dask dev versions

* support dask arrays in DataArray coordinates

* remove commented assertion

* whats new

* elaborate on what's new
---
 doc/whats-new.rst                |   7 ++
 xarray/core/dataarray.py         |  29 +++++++
 xarray/core/dataset.py           |  71 ++++++++++++++++
 xarray/core/variable.py          |  35 ++++++++
 xarray/tests/test_dask.py        | 137 ++++++++++++++++++++++++++-----
 xarray/tests/test_distributed.py |  34 +++++++-
 6 files changed, 290 insertions(+), 23 deletions(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 99610640fe7..6eebf30815b 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -15,9 +15,16 @@ What's New
 
 .. _whats-new.0.10.0:
 
+
+v0.10.0 (unreleased)
+--------------------
+
 Changes since v0.10.0 rc1 (Unreleased)
 --------------------------------------
 
+- Experimental support for the Dask collection interface (:issue:`1674`).
+  By `Matthew Rocklin <https://github.com/mrocklin>`_.
+
 Bug fixes
 ~~~~~~~~~
 
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index e8330ef6c77..1dac72335d2 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -576,6 +576,35 @@ def reset_coords(self, names=None, drop=False, inplace=False):
         dataset[self.name] = self.variable
         return dataset
 
+    def __dask_graph__(self):
+        return self._to_temp_dataset().__dask_graph__()
+
+    def __dask_keys__(self):
+        return self._to_temp_dataset().__dask_keys__()
+
+    @property
+    def __dask_optimize__(self):
+        return self._to_temp_dataset().__dask_optimize__
+
+    @property
+    def __dask_scheduler__(self):
+        return self._to_temp_dataset().__dask_scheduler__
+
+    def __dask_postcompute__(self):
+        func, args = self._to_temp_dataset().__dask_postcompute__()
+        return self._dask_finalize, (func, args, self.name)
+
+    def __dask_postpersist__(self):
+        func, args = self._to_temp_dataset().__dask_postpersist__()
+        return self._dask_finalize, (func, args, self.name)
+
+    @staticmethod
+    def _dask_finalize(results, func, args, name):
+        ds = func(results, *args)
+        variable = ds._variables.pop(_THIS_ARRAY)
+        coords = ds._variables
+        return DataArray(variable, coords, name=name, fastpath=True)
+
     def load(self, **kwargs):
         """Manually trigger loading of this array's data from disk or a
         remote source into memory and return this array.
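For review context, here is a minimal sketch (not part of the patch) of what the protocol methods above enable once dask 0.16+ is installed: `dask.compute` discovers the graph through `__dask_graph__` and rebuilds a `DataArray` through `__dask_postcompute__`. The array values are illustrative only.

```python
# Sketch of the new behaviour, assuming dask >= 0.16 and this patch applied.
import dask
import numpy as np
import xarray as xr

arr = xr.DataArray(np.arange(12.0), dims='x').chunk({'x': 4})
assert dask.is_dask_collection(arr)          # detected via __dask_graph__

(result,) = dask.compute(arr + 1)            # __dask_postcompute__ runs here
assert not dask.is_dask_collection(result)   # plain in-memory DataArray
assert isinstance(result.data, np.ndarray)
```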
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index abe32055f97..56c9df0af93 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -493,6 +493,77 @@ def load(self, **kwargs):
 
         return self
 
+    def __dask_graph__(self):
+        graphs = {k: v.__dask_graph__() for k, v in self.variables.items()}
+        graphs = {k: v for k, v in graphs.items() if v is not None}
+        if not graphs:
+            return None
+        else:
+            from dask import sharedict
+            return sharedict.merge(*graphs.values())
+
+    def __dask_keys__(self):
+        import dask
+        return [v.__dask_keys__() for v in self.variables.values()
+                if dask.is_dask_collection(v)]
+
+    @property
+    def __dask_optimize__(self):
+        import dask.array as da
+        return da.Array.__dask_optimize__
+
+    @property
+    def __dask_scheduler__(self):
+        import dask.array as da
+        return da.Array.__dask_scheduler__
+
+    def __dask_postcompute__(self):
+        import dask
+        info = [(True, k, v.__dask_postcompute__())
+                if dask.is_dask_collection(v) else
+                (False, k, v) for k, v in self._variables.items()]
+        return self._dask_postcompute, (info, self._coord_names, self._dims,
+                                        self._attrs, self._file_obj,
+                                        self._encoding)
+
+    def __dask_postpersist__(self):
+        import dask
+        info = [(True, k, v.__dask_postpersist__())
+                if dask.is_dask_collection(v) else
+                (False, k, v) for k, v in self._variables.items()]
+        return self._dask_postpersist, (info, self._coord_names, self._dims,
+                                        self._attrs, self._file_obj,
+                                        self._encoding)
+
+    @staticmethod
+    def _dask_postcompute(results, info, *args):
+        variables = OrderedDict()
+        results2 = list(results[::-1])
+        for is_dask, k, v in info:
+            if is_dask:
+                func, args2 = v
+                r = results2.pop()
+                result = func(r, *args2)
+            else:
+                result = v
+            variables[k] = result
+
+        final = Dataset._construct_direct(variables, *args)
+        return final
+
+    @staticmethod
+    def _dask_postpersist(dsk, info, *args):
+        variables = OrderedDict()
+        for is_dask, k, v in info:
+            if is_dask:
+                func, args2 = v
+                result = func(dsk, *args2)
+            else:
+                result = v
+            variables[k] = result
+
+        return Dataset._construct_direct(variables, *args)
+
     def compute(self, **kwargs):
         """Manually trigger loading of this dataset's data from disk or a
         remote source into memory and return a new dataset.
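To illustrate the `info` bookkeeping in `_dask_postcompute` above: dask computes only the dask-backed variables (in `__dask_keys__` order), while the `(False, k, v)` entries splice eager variables back in unchanged. A hedged sketch, with made-up variable names:

```python
# Sketch, assuming dask >= 0.16: a Dataset mixing one dask-backed and one
# eager (numpy-backed) variable round-trips through dask.compute().
import dask
import numpy as np
import xarray as xr

ds = xr.Dataset({'lazy': ('x', np.arange(8.0)),
                 'eager': ('x', np.arange(8.0))})
ds['lazy'] = ds['lazy'].chunk({'x': 4})       # only 'lazy' contributes a graph

(ds2,) = dask.compute(ds)                     # computes the 'lazy' chunks only
assert isinstance(ds2['lazy'].data, np.ndarray)
assert isinstance(ds2['eager'].data, np.ndarray)  # passed through as-is
```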
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index 14e2770879c..07f5d1b7da0 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -355,6 +355,41 @@ def compute(self, **kwargs):
         new = self.copy(deep=False)
         return new.load(**kwargs)
 
+    def __dask_graph__(self):
+        if isinstance(self._data, dask_array_type):
+            return self._data.__dask_graph__()
+        else:
+            return None
+
+    def __dask_keys__(self):
+        return self._data.__dask_keys__()
+
+    @property
+    def __dask_optimize__(self):
+        return self._data.__dask_optimize__
+
+    @property
+    def __dask_scheduler__(self):
+        return self._data.__dask_scheduler__
+
+    def __dask_postcompute__(self):
+        array_func, array_args = self._data.__dask_postcompute__()
+        return self._dask_finalize, (array_func, array_args, self._dims,
+                                     self._attrs, self._encoding)
+
+    def __dask_postpersist__(self):
+        array_func, array_args = self._data.__dask_postpersist__()
+        return self._dask_finalize, (array_func, array_args, self._dims,
+                                     self._attrs, self._encoding)
+
+    @staticmethod
+    def _dask_finalize(results, array_func, array_args, dims, attrs, encoding):
+        if isinstance(results, dict):  # persist case
+            name = array_args[0]
+            results = {k: v for k, v in results.items() if k[0] == name}  # cull
+        data = array_func(results, *array_args)
+        return Variable(dims, data, attrs=attrs, encoding=encoding)
+
     @property
     def values(self):
         """The variable's data as a numpy.ndarray"""
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index 6122aaf4e11..4e9b0250a6a 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -206,6 +206,34 @@ def test_bivariate_ufunc(self):
         self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(v, 0))
         self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(0, v))
 
+    @pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
+                        reason='Need dask 0.16 for new interface')
+    def test_compute(self):
+        u = self.eager_var
+        v = self.lazy_var
+
+        assert dask.is_dask_collection(v)
+        (v2,) = dask.compute(v + 1)
+        assert not dask.is_dask_collection(v2)
+
+        assert ((u + 1).data == v2.data).all()
+
+    @pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
+                        reason='Need dask 0.16 for new interface')
+    def test_persist(self):
+        u = self.eager_var
+        v = self.lazy_var + 1
+
+        (v2,) = dask.persist(v)
+        assert v is not v2
+        assert len(v2.__dask_graph__()) < len(v.__dask_graph__())
+        assert v2.__dask_keys__() == v.__dask_keys__()
+        assert dask.is_dask_collection(v)
+        assert dask.is_dask_collection(v2)
+
+        self.assertLazyAndAllClose(u + 1, v)
+        self.assertLazyAndAllClose(u + 1, v2)
+
 
 class TestDataArrayAndDataset(DaskTestCase):
     def assertLazyAndIdentical(self, expected, actual):
@@ -251,6 +279,34 @@ def test_lazy_array(self):
         actual = xr.concat([v[:2], v[2:]], 'x')
         self.assertLazyAndAllClose(u, actual)
 
+    @pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
+                        reason='Need dask 0.16 for new interface')
+    def test_compute(self):
+        u = self.eager_array
+        v = self.lazy_array
+
+        assert dask.is_dask_collection(v)
+        (v2,) = dask.compute(v + 1)
+        assert not dask.is_dask_collection(v2)
+
+        assert ((u + 1).data == v2.data).all()
+
+    @pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
+                        reason='Need dask 0.16 for new interface')
+    def test_persist(self):
+        u = self.eager_array
+        v = self.lazy_array + 1
+
+        (v2,) = dask.persist(v)
+        assert v is not v2
+        assert len(v2.__dask_graph__()) < len(v.__dask_graph__())
+        assert v2.__dask_keys__() == v.__dask_keys__()
+        assert dask.is_dask_collection(v)
+        assert dask.is_dask_collection(v2)
+
+        self.assertLazyAndAllClose(u + 1, v)
+        self.assertLazyAndAllClose(u + 1, v2)
+
     def test_concat_loads_variables(self):
         # Test that concat() computes not-in-memory variables at most once
         # and loads them in the output, while leaving the input unaltered.
@@ -402,28 +458,6 @@ def counting_get(*args, **kwargs):
             ds.load()
         self.assertEqual(count[0], 1)
 
-    def test_persist_Dataset(self):
-        ds = Dataset({'foo': ('x', range(5)),
-                      'bar': ('x', range(5))}).chunk()
-        ds = ds + 1
-        n = len(ds.foo.data.dask)
-
-        ds2 = ds.persist()
-
-        assert len(ds2.foo.data.dask) == 1
-        assert len(ds.foo.data.dask) == n  # doesn't mutate in place
-
-    def test_persist_DataArray(self):
-        x = da.arange(10, chunks=(5,))
-        y = DataArray(x)
-        z = y + 1
-        n = len(z.data.dask)
-
-        zz = z.persist()
-
-        assert len(z.data.dask) == n
-        assert len(zz.data.dask) == zz.data.npartitions
-
     def test_stack(self):
         data = da.random.normal(size=(2, 3, 4), chunks=(1, 3, 4))
         arr = DataArray(data, dims=('w', 'x', 'y'))
@@ -737,3 +771,62 @@ def build_dask_array(name):
     return dask.array.Array(
         dask={(name, 0): (kernel, name)}, name=name, chunks=((1,),),
         dtype=np.int64)
+
+
+# test both the persist method and the dask.persist function
+# the dask.persist function requires a new version of dask
+@pytest.mark.parametrize('persist', [
+    lambda x: x.persist(),
+    pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
+                       lambda x: dask.persist(x)[0],
+                       reason='Need Dask 0.16')
+])
+def test_persist_Dataset(persist):
+    ds = Dataset({'foo': ('x', range(5)),
+                  'bar': ('x', range(5))}).chunk()
+    ds = ds + 1
+    n = len(ds.foo.data.dask)
+
+    ds2 = persist(ds)
+
+    assert len(ds2.foo.data.dask) == 1
+    assert len(ds.foo.data.dask) == n  # doesn't mutate in place
+
+
+@pytest.mark.parametrize('persist', [
+    lambda x: x.persist(),
+    pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
+                       lambda x: dask.persist(x)[0],
+                       reason='Need Dask 0.16')
+])
+def test_persist_DataArray(persist):
+    x = da.arange(10, chunks=(5,))
+    y = DataArray(x)
+    z = y + 1
+    n = len(z.data.dask)
+
+    zz = persist(z)
+
+    assert len(z.data.dask) == n
+    assert len(zz.data.dask) == zz.data.npartitions
+
+
+@pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
+                    reason='Need dask 0.16 for new interface')
+def test_dataarray_with_dask_coords():
+    import toolz
+    x = xr.Variable('x', da.arange(8, chunks=(4,)))
+    y = xr.Variable('y', da.arange(8, chunks=(4,)) * 2)
+    data = da.random.random((8, 8), chunks=(4, 4)) + 1
+    array = xr.DataArray(data, dims=['x', 'y'])
+    array.coords['xx'] = x
+    array.coords['yy'] = y
+
+    assert dict(array.__dask_graph__()) == toolz.merge(data.__dask_graph__(),
+                                                       x.__dask_graph__(),
+                                                       y.__dask_graph__())
+
+    (array2,) = dask.compute(array)
+    assert not dask.is_dask_collection(array2)
+
+    assert all(isinstance(v._variable.data, np.ndarray) for v in array2.coords.values())
diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py
index 1868486b01f..9999ed9a669 100644
--- a/xarray/tests/test_distributed.py
+++ b/xarray/tests/test_distributed.py
@@ -4,7 +4,9 @@
 distributed = pytest.importorskip('distributed')
 da = pytest.importorskip('dask.array')
 
-from distributed.utils_test import cluster, loop
+import dask
+from distributed.utils_test import cluster, loop, gen_cluster
+from distributed.client import futures_of, wait
 
 from xarray.tests.test_backends import create_tmp_file, ON_WINDOWS
 from xarray.tests.test_dataset import create_test_data
@@ -32,3 +34,33 @@ def test_dask_distributed_integration_test(loop, engine):
                 assert isinstance(restored.var1.data, da.Array)
                 computed = restored.compute()
                 assert_allclose(original, computed)
+
+
+@pytest.mark.skipif(distributed.__version__ <= '1.19.3',
+                    reason='Need recent distributed version to clean up get')
+@gen_cluster(client=True, timeout=None)
+def test_async(c, s, a, b):
+    x = create_test_data()
+    assert not dask.is_dask_collection(x)
+    y = x.chunk({'dim2': 4}) + 10
+    assert dask.is_dask_collection(y)
+    assert dask.is_dask_collection(y.var1)
+    assert dask.is_dask_collection(y.var2)
+
+    z = y.persist()
+    assert str(z)
+
+    assert dask.is_dask_collection(z)
+    assert dask.is_dask_collection(z.var1)
+    assert dask.is_dask_collection(z.var2)
+    assert len(y.__dask_graph__()) > len(z.__dask_graph__())
+
+    assert not futures_of(y)
+    assert futures_of(z)
+
+    future = c.compute(z)
+    w = yield future
+    assert not dask.is_dask_collection(w)
+    assert_allclose(x + 10, w)
+
+    assert s.task_state
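Beyond the test above, the practical payoff is that `dask.persist` and `distributed` clients now treat xarray objects as first-class collections. A rough usage sketch (assumes dask >= 0.16, distributed > 1.19.3, and that starting a local cluster is acceptable):

```python
# Sketch: persisting an xarray Dataset on a dask.distributed cluster.
import dask
import xarray as xr
from dask.distributed import Client

client = Client()                        # spins up a local cluster
ds = xr.Dataset({'foo': ('x', list(range(100)))}).chunk({'x': 10}) + 1

ds = ds.persist()                        # work starts on the workers
assert dask.is_dask_collection(ds)       # still a lazy wrapper over futures
print(ds.compute())                      # gathers finished results

client.close()
```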