Skip to content
forked from pydata/xarray

Commit

Permalink
Fixed dask.optimize on datasets (pydata#4438)
Browse files Browse the repository at this point in the history
* Fixed dask.optimize on datasets

Another attempt to fix pydata#3698. The issue with my fix in is that we hit
`Variable._dask_finalize` in both `dask.optimize` and `dask.persist`. We
want to do the culling of unnecessary tasks (`test_persist_Dataset`) but
only in the persist case, not optimize (`test_optimize`).

* Update whats-new.rst

* Update doc/whats-new.rst

Co-authored-by: Deepak Cherian <[email protected]>
Co-authored-by: Maximilian Roos <[email protected]>
  • Loading branch information
3 people authored Sep 20, 2020
1 parent 0c26211 commit 13c09dc
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 5 deletions.
3 changes: 2 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ Bug fixes
By `Jens Svensmark <https://github.com/jenssss>`_
- Fix incorrect legend labels for :py:meth:`Dataset.plot.scatter` (:issue:`4126`).
By `Peter Hausamann <https://github.com/phausamann>`_.
- Fix ``dask.optimize`` on ``DataArray`` producing an invalid Dask task graph (:issue:`3698`)
By `Tom Augspurger <https://github.com/TomAugspurger>`_
- Fix ``pip install .`` when no ``.git`` directory exists; namely when the xarray source
directory has been rsync'ed by PyCharm Professional for a remote deployment over SSH.
By `Guido Imperiale <https://github.com/crusaderky>`_
Expand All @@ -109,7 +111,6 @@ Bug fixes
- Avoid relying on :py:class:`set` objects for the ordering of the coordinates (:pull:`4409`)
By `Justus Magin <https://github.com/keewis>`_.


Documentation
~~~~~~~~~~~~~

Expand Down
11 changes: 10 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,10 +777,19 @@ def _dask_postcompute(results, info, *args):
@staticmethod
def _dask_postpersist(dsk, info, *args):
variables = {}
# postpersist is called in both dask.optimize and dask.persist
# When persisting, we want to filter out unrelated keys for
# each Variable's task graph.
is_persist = len(dsk) == len(info)
for is_dask, k, v in info:
if is_dask:
func, args2 = v
result = func(dsk, *args2)
if is_persist:
name = args2[1][0]
dsk2 = {k: v for k, v in dsk.items() if k[0] == name}
else:
dsk2 = dsk
result = func(dsk2, *args2)
else:
result = v
variables[k] = result
Expand Down
3 changes: 0 additions & 3 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,9 +501,6 @@ def __dask_postpersist__(self):

@staticmethod
def _dask_finalize(results, array_func, array_args, dims, attrs, encoding):
if isinstance(results, dict): # persist case
name = array_args[0]
results = {k: v for k, v in results.items() if k[0] == name}
data = array_func(results, *array_args)
return Variable(dims, data, attrs=attrs, encoding=encoding)

Expand Down
8 changes: 8 additions & 0 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1607,3 +1607,11 @@ def test_more_transforms_pass_lazy_array_equiv(map_da, map_ds):
assert_equal(map_da._from_temp_dataset(map_da._to_temp_dataset()), map_da)
assert_equal(map_da.astype(map_da.dtype), map_da)
assert_equal(map_da.transpose("y", "x", transpose_coords=False).cxy, map_da.cxy)


def test_optimize():
# https://github.com/pydata/xarray/issues/3698
a = dask.array.ones((10, 4), chunks=(5, 2))
arr = xr.DataArray(a).chunk(5)
(arr2,) = dask.optimize(arr)
arr2.compute()

0 comments on commit 13c09dc

Please sign in to comment.