Fix optimize for chunked DataArray (#4432)
Previously we generated an invalid Dask task graph, because the lines
removed here dropped keys that were referenced elsewhere in the task
graph. The original implementation had a comment indicating that this
filtering was there to cull the graph:
https://github.com/pydata/xarray/blame/502a988ad5b87b9f3aeec3033bf55c71272e1053/xarray/core/variable.py#L384

Just spot-checking things, I think we're OK here though. Something like
`dask.visualize(arr[[0]], optimize_graph=True)` indicates that we're OK.

Closes #3698

Co-authored-by: Maximilian Roos <[email protected]>
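
To illustrate the failure mode described above, here is a minimal sketch that uses a plain dict as a stand-in for a task graph. It is a toy model, not Dask's actual graph machinery: the key names (`"ones"`, `"rechunk"`) and the helpers `find_missing_dependencies` and `evaluate` are hypothetical. Only the dict comprehension mirrors the filter removed from `_dask_finalize`.

```python
# Toy model only: a plain dict standing in for a task graph.
# Keys are (name, index) tuples. A task is either a literal value or a
# tuple of the form (callable, *argument_keys). This is NOT Dask's API.


def add(a, b):
    return a + b


def find_missing_dependencies(graph):
    """Return the keys that tasks refer to but the graph does not define."""
    missing = set()
    for task in graph.values():
        if isinstance(task, tuple) and task and callable(task[0]):
            for arg in task[1:]:
                if isinstance(arg, tuple) and arg not in graph:
                    missing.add(arg)
    return missing


def evaluate(graph, key):
    """Recursively compute the value stored under `key`."""
    task = graph[key]
    if isinstance(task, tuple) and task and callable(task[0]):
        func, *arg_keys = task
        return func(*(evaluate(graph, k) for k in arg_keys))
    return task


# The output task ("rechunk", 0) depends on two tasks with a different name.
graph = {
    ("ones", 0): 1,
    ("ones", 1): 1,
    ("rechunk", 0): (add, ("ones", 0), ("ones", 1)),
}

# Same shape as the removed filter: keep only keys whose name matches.
name = "rechunk"
filtered = {k: v for k, v in graph.items() if k[0] == name}

print(find_missing_dependencies(graph))     # nothing is missing
print(find_missing_dependencies(filtered))  # both ("ones", i) keys are gone
```

In this toy model, `evaluate(graph, ("rechunk", 0))` succeeds, while the same call on `filtered` raises `KeyError` because the dependencies were dropped along with their keys.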
TomAugspurger and max-sixty authored Sep 17, 2020
1 parent b0d8d93 commit 9a8a62b
Showing 3 changed files with 9 additions and 3 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
@@ -84,11 +84,13 @@ Bug fixes
 - Fix `KeyError` when doing linear interpolation to an nd `DataArray`
   that contains NaNs (:pull:`4233`).
   By `Jens Svensmark <https://github.com/jenssss>`_
+- Fix ``dask.optimize`` on ``DataArray`` producing an invalid Dask task graph (:issue:`3698`)
 - Fix incorrect legend labels for :py:meth:`Dataset.plot.scatter` (:issue:`4126`).
   By `Peter Hausamann <https://github.com/phausamann>`_.
 - Fix indexing with datetime64 scalars with pandas 1.1 (:issue:`4283`).
   By `Stephan Hoyer <https://github.com/shoyer>`_ and
   `Justus Magin <https://github.com/keewis>`_.
 
+
 Documentation
 ~~~~~~~~~~~~~
3 changes: 0 additions & 3 deletions xarray/core/variable.py
@@ -501,9 +501,6 @@ def __dask_postpersist__(self):
 
     @staticmethod
     def _dask_finalize(results, array_func, array_args, dims, attrs, encoding):
-        if isinstance(results, dict):  # persist case
-            name = array_args[0]
-            results = {k: v for k, v in results.items() if k[0] == name}
         data = array_func(results, *array_args)
         return Variable(dims, data, attrs=attrs, encoding=encoding)
 
7 changes: 7 additions & 0 deletions xarray/tests/test_dask.py
@@ -1607,3 +1607,10 @@ def test_more_transforms_pass_lazy_array_equiv(map_da, map_ds):
     assert_equal(map_da._from_temp_dataset(map_da._to_temp_dataset()), map_da)
     assert_equal(map_da.astype(map_da.dtype), map_da)
     assert_equal(map_da.transpose("y", "x", transpose_coords=False).cxy, map_da.cxy)
+
+
+def test_optimize():
+    a = dask.array.ones((10, 5), chunks=(1, 3))
+    arr = xr.DataArray(a).chunk(5)
+    (arr2,) = dask.optimize(arr)
+    arr2.compute()
