Skip to content

Commit

Permalink
Fix optimize for chunked DataArray
Browse files Browse the repository at this point in the history
Previously we generated in invalidate Dask task graph, becuase the lines
removed here dropped keys that were referenced elsewhere in the task
graph. The original implementation had a
comment indicating that this was to cull:
https://github.com/pydata/xarray/blame/502a988ad5b87b9f3aeec3033bf55c71272e1053/xarray/core/variable.py#L384

Just spot-checking things, I think we're OK here though. Something like
`dask.visualize(arr[[0]], optimize_graph=True)` indicates that we're OK.

Closes pydata#3698
  • Loading branch information
TomAugspurger committed Sep 17, 2020
1 parent bb4c7b4 commit 20195ca
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
1 change: 1 addition & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ Bug fixes
- Fix `KeyError` when doing linear interpolation to an nd `DataArray`
that contains NaNs (:pull:`4233`).
By `Jens Svensmark <https://github.com/jenssss>`_
- Fix ``dask.optimize`` on ``DataArray`` producing an invalid Dask task graph (:issue:`3698`)

Documentation
~~~~~~~~~~~~~
Expand Down
3 changes: 0 additions & 3 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,9 +501,6 @@ def __dask_postpersist__(self):

@staticmethod
def _dask_finalize(results, array_func, array_args, dims, attrs, encoding):
if isinstance(results, dict): # persist case
name = array_args[0]
results = {k: v for k, v in results.items() if k[0] == name}
data = array_func(results, *array_args)
return Variable(dims, data, attrs=attrs, encoding=encoding)

Expand Down
7 changes: 7 additions & 0 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1607,3 +1607,10 @@ def test_more_transforms_pass_lazy_array_equiv(map_da, map_ds):
assert_equal(map_da._from_temp_dataset(map_da._to_temp_dataset()), map_da)
assert_equal(map_da.astype(map_da.dtype), map_da)
assert_equal(map_da.transpose("y", "x", transpose_coords=False).cxy, map_da.cxy)


def test_optimize():
a = dask.array.ones((10, 5), chunks=(1, 3))
arr = xr.DataArray(a).chunk(5)
(arr2,) = dask.optimize(arr)
arr2.compute()

0 comments on commit 20195ca

Please sign in to comment.