Skip to content

Commit

Permalink
support dask arrays in DataArray coordinates
Browse files Browse the repository at this point in the history
  • Loading branch information
mrocklin committed Nov 3, 2017
1 parent 9df0af7 commit 3ea0dc1
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 14 deletions.
26 changes: 13 additions & 13 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,33 +577,33 @@ def reset_coords(self, names=None, drop=False, inplace=False):
return dataset

def __dask_graph__(self):
    """Return the dask graph for this array.

    The graph is obtained from the temporary dataset built by
    ``self._to_temp_dataset()`` rather than from ``self._variable`` alone,
    so that dask-backed coordinates contribute their tasks as well.
    """
    return self._to_temp_dataset().__dask_graph__()

def __dask_keys__(self):
    """Return the dask output keys for this array.

    The keys come from the temporary dataset built by
    ``self._to_temp_dataset()``, keeping them consistent with the graph,
    which is produced from the same dataset.
    """
    return self._to_temp_dataset().__dask_keys__()

@property
def __dask_optimize__(self):
    """Graph-optimization hook, taken from the temporary dataset.

    Returns whatever ``__dask_optimize__`` the dataset built by
    ``self._to_temp_dataset()`` exposes.
    """
    return self._to_temp_dataset().__dask_optimize__

@property
def __dask_scheduler__(self):
    """Default scheduler hook, taken from the temporary dataset.

    Returns the ``__dask_scheduler__`` attribute of the dataset built by
    ``self._to_temp_dataset()``.
    """
    # Must return the *scheduler* attribute; returning ``__dask_optimize__``
    # here would hand dask an optimizer where it expects a scheduler.
    return self._to_temp_dataset().__dask_scheduler__

def __dask_postcompute__(self):
    """Return the finalizer and its extra arguments for ``compute``.

    Returns
    -------
    tuple
        ``(self._dask_finalize, (func, args, self.name))`` where ``func``
        and ``args`` are the post-compute callable and arguments reported
        by the temporary dataset built by ``self._to_temp_dataset()``.
    """
    func, args = self._to_temp_dataset().__dask_postcompute__()
    return self._dask_finalize, (func, args, self.name)

def __dask_postpersist__(self):
    """Return the finalizer and its extra arguments for ``persist``.

    Returns
    -------
    tuple
        ``(self._dask_finalize, (func, args, self.name))`` where ``func``
        and ``args`` are the post-persist callable and arguments reported
        by the temporary dataset built by ``self._to_temp_dataset()``.
    """
    func, args = self._to_temp_dataset().__dask_postpersist__()
    return self._dask_finalize, (func, args, self.name)

@staticmethod
def _dask_finalize(results, func, args, name):
    """Rebuild a DataArray from dask results.

    Parameters
    ----------
    results
        Results handed over by dask; passed straight through to ``func``.
    func, args
        Callable and extra positional arguments; ``func(results, *args)``
        is expected to return a dataset-like object with a ``_variables``
        mapping.
    name
        Name given to the rebuilt DataArray.

    Returns
    -------
    DataArray
        Built from the variable stored under ``_THIS_ARRAY``, with the
        remaining variables of the dataset used as coordinates.
    """
    ds = func(results, *args)
    # Popping removes the array's own variable from the mapping, so what
    # is left in ``ds._variables`` afterwards is used as the coordinates.
    variable = ds._variables.pop(_THIS_ARRAY)
    coords = ds._variables
    return DataArray(variable, coords, name=name, fastpath=True)

def load(self, **kwargs):
"""Manually trigger loading of this array's data from disk or a
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def __dask_postpersist__(self):
@staticmethod
def _dask_postcompute(results, info, *args):
variables = OrderedDict()
results2 = results[::-1]
results2 = list(results[::-1])
for is_dask, k, v in info:
if is_dask:
func, args2 = v
Expand Down
19 changes: 19 additions & 0 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,3 +771,22 @@ def test_persist_DataArray(persist):

assert len(z.data.dask) == n
assert len(zz.data.dask) == zz.data.npartitions


def test_dataarray_with_dask_coords():
    """Dask-backed coordinates take part in a DataArray's graph and compute.

    Builds an 8x8 dask-backed DataArray, attaches two dask-backed
    coordinates ('xx' and 'yy'), then checks that the array's graph is the
    merge of the data and coordinate graphs and that ``dask.compute``
    returns an array whose coordinates hold NumPy data.
    """
    import toolz

    x_var = xr.Variable('x', da.arange(8, chunks=(4,)))
    y_var = xr.Variable('y', da.arange(8, chunks=(4,)) * 2)
    values = da.random.random((8, 8), chunks=(4, 4)) + 1

    array = xr.DataArray(values, dims=['x', 'y'])
    array.coords['xx'] = x_var
    array.coords['yy'] = y_var

    # The graph must contain the tasks of the data and of both coordinates.
    expected_graph = toolz.merge(values.__dask_graph__(),
                                 x_var.__dask_graph__(),
                                 y_var.__dask_graph__())
    assert dict(array.__dask_graph__()) == expected_graph

    (array2,) = dask.compute(array)
    assert not dask.is_dask_collection(array2)

    # Every coordinate of the computed array must be materialized.
    for coord in array2.coords.values():
        assert isinstance(coord._variable.data, np.ndarray)

0 comments on commit 3ea0dc1

Please sign in to comment.