Skip to content

Commit

Permalink
support backwards compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
mrocklin committed Nov 1, 2017
1 parent ffb0ca1 commit 56ec487
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 15 deletions.
10 changes: 4 additions & 6 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,9 +646,8 @@ def compute(self, **kwargs):
--------
dask.array.compute
"""
import dask
(result,) = dask.compute(self, **kwargs)
return result
new = self.copy(deep=False)
return new.load(**kwargs)

def persist(self, **kwargs):
""" Trigger computation in constituent dask arrays
Expand All @@ -666,9 +665,8 @@ def persist(self, **kwargs):
--------
dask.persist
"""
import dask
(result,) = dask.persist(self, **kwargs)
return result
ds = self._to_temp_dataset().persist(**kwargs)
return self._from_temp_dataset(ds)

def copy(self, deep=True):
"""Returns a copy of this array.
Expand Down
5 changes: 2 additions & 3 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,9 +618,8 @@ def persist(self, **kwargs):
--------
dask.persist
"""
import dask
(result,) = dask.persist(self, **kwargs)
return result
new = self.copy(deep=False)
return new._persist_inplace(**kwargs)

@classmethod
def _construct_direct(cls, variables, coord_names, dims=None, attrs=None,
Expand Down
5 changes: 2 additions & 3 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,8 @@ def compute(self, **kwargs):
--------
dask.array.compute
"""
import dask
(result,) = dask.compute(self, **kwargs)
return result
new = self.copy(deep=False)
return new.load(**kwargs)

def __dask_graph__(self):
if isinstance(self._data, dask_array_type):
Expand Down
18 changes: 16 additions & 2 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ def test_bivariate_ufunc(self):
self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(v, 0))
self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(0, v))

@pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16',
reason='Need dask 0.16+ for new interface')
def test_compute(self):
u = self.eager_var
v = self.lazy_var
Expand All @@ -216,6 +218,8 @@ def test_compute(self):

assert ((u + 1).data == v2.data).all()

@pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16',
reason='Need dask 0.16+ for new interface')
def test_persist(self):
u = self.eager_var
v = self.lazy_var + 1
Expand Down Expand Up @@ -275,6 +279,8 @@ def test_lazy_array(self):
actual = xr.concat([v[:2], v[2:]], 'x')
self.assertLazyAndAllClose(u, actual)

@pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16',
reason='Need dask 0.16+ for new interface')
def test_compute(self):
u = self.eager_array
v = self.lazy_array
Expand All @@ -285,6 +291,8 @@ def test_compute(self):

assert ((u + 1).data == v2.data).all()

@pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16',
reason='Need dask 0.16+ for new interface')
def test_persist(self):
u = self.eager_array
v = self.lazy_array + 1
Expand Down Expand Up @@ -730,6 +738,8 @@ def build_dask_array(name):
chunks=((1,),), dtype=np.int64)


@pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16',
reason='Need dask 0.16+ for new interface')
@pytest.mark.parametrize('persist', [lambda x: x.persist(),
lambda x: dask.persist(x)[0]])
def test_persist_Dataset(persist):
Expand All @@ -743,8 +753,12 @@ def test_persist_Dataset(persist):
assert len(ds2.foo.data.dask) == 1
assert len(ds.foo.data.dask) == n # doesn't mutate in place

@pytest.mark.parametrize('persist', [lambda x: x.persist(),
lambda x: dask.persist(x)[0]])
@pytest.mark.parametrize('persist', [
lambda x: x.persist(),
pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16',
lambda x: dask.persist(x)[0],
reason='Need Dask 0.16+')
])
def test_persist_DataArray(persist):
x = da.arange(10, chunks=(5,))
y = DataArray(x)
Expand Down
3 changes: 2 additions & 1 deletion xarray/tests/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def test_dask_distributed_integration_test(loop, engine):
assert_allclose(original, computed)


@pytest.mark.skipif(distributed.__version__ <= '1.19.3',
reason='Need recent distributed version to clean up get')
@gen_cluster(client=True, timeout=None)
def test_async(c, s, a, b):
x = create_test_data()
Expand Down Expand Up @@ -63,5 +65,4 @@ def test_async(c, s, a, b):
assert not dask.is_dask_collection(w)
assert_allclose(x + 10, w)


assert s.task_state

0 comments on commit 56ec487

Please sign in to comment.