From 91edf8ba0d7a9be9bb6cc73b7f9da0f5fe851c35 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 1 Nov 2017 12:36:56 -0400 Subject: [PATCH] support backwards compatibility --- xarray/core/dataarray.py | 10 ++++------ xarray/core/dataset.py | 5 ++--- xarray/core/variable.py | 5 ++--- xarray/tests/test_distributed.py | 2 +- 4 files changed, 9 insertions(+), 13 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 588d3fe4c0d..1d81d4e4ba6 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -646,9 +646,8 @@ def compute(self, **kwargs): -------- dask.array.compute """ - import dask - (result,) = dask.compute(self, **kwargs) - return result + new = self.copy(deep=False) + return new.load(**kwargs) def persist(self, **kwargs): """ Trigger computation in constituent dask arrays @@ -666,9 +665,8 @@ def persist(self, **kwargs): -------- dask.persist """ - import dask - (result,) = dask.persist(self, **kwargs) - return result + ds = self._to_temp_dataset().persist(**kwargs) + return self._from_temp_dataset(ds) def copy(self, deep=True): """Returns a copy of this array. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 42c7058f45d..6f76da14255 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -618,9 +618,8 @@ def persist(self, **kwargs): -------- dask.persist """ - import dask - (result,) = dask.persist(self, **kwargs) - return result + new = self.copy(deep=False) + return new._persist_inplace(**kwargs) @classmethod def _construct_direct(cls, variables, coord_names, dims=None, attrs=None, diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 750f5dc352e..dee17750281 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -360,9 +360,8 @@ def compute(self, **kwargs): -------- dask.array.compute """ - import dask - (result,) = dask.compute(self, **kwargs) - return result + new = self.copy(deep=False) + return new.load(**kwargs) def __dask_graph__(self): if isinstance(self._data, dask_array_type): diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 127a87f176b..2c17b86e81e 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -36,6 +36,7 @@ def test_dask_distributed_integration_test(loop, engine): assert_allclose(original, computed) +@pytest.mark.skipif(distributed.__version__ <= '1.19.3') @gen_cluster(client=True, timeout=None) def test_async(c, s, a, b): x = create_test_data() @@ -63,5 +64,4 @@ def test_async(c, s, a, b): assert not dask.is_dask_collection(w) assert_allclose(x + 10, w) - assert s.task_state