From 20195ca2112b68280bb200be6cdc8a357dccb1d8 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 17 Sep 2020 14:58:01 -0500 Subject: [PATCH] Fix optimize for chunked DataArray Previously we generated an invalid Dask task graph, because the lines removed here dropped keys that were referenced elsewhere in the task graph. The original implementation had a comment indicating that this was to cull: https://github.com/pydata/xarray/blame/502a988ad5b87b9f3aeec3033bf55c71272e1053/xarray/core/variable.py#L384 Just spot-checking things, I think we're OK here though. Something like `dask.visualize(arr[[0]], optimize_graph=True)` indicates that we're OK. Closes https://github.com/pydata/xarray/issues/3698 --- doc/whats-new.rst | 1 + xarray/core/variable.py | 3 --- xarray/tests/test_dask.py | 7 +++++++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f367286a244..023f4c1220d 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -84,6 +84,7 @@ Bug fixes - Fix `KeyError` when doing linear interpolation to an nd `DataArray` that contains NaNs (:pull:`4233`). 
By `Jens Svensmark `_ +- Fix ``dask.optimize`` on ``DataArray`` producing an invalid Dask task graph (:issue:`3698`) Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 075d79043b2..8f4ff684b08 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -501,9 +501,6 @@ def __dask_postpersist__(self): @staticmethod def _dask_finalize(results, array_func, array_args, dims, attrs, encoding): - if isinstance(results, dict): # persist case - name = array_args[0] - results = {k: v for k, v in results.items() if k[0] == name} data = array_func(results, *array_args) return Variable(dims, data, attrs=attrs, encoding=encoding) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 46685a29a47..489bf09fa3c 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1607,3 +1607,10 @@ def test_more_transforms_pass_lazy_array_equiv(map_da, map_ds): assert_equal(map_da._from_temp_dataset(map_da._to_temp_dataset()), map_da) assert_equal(map_da.astype(map_da.dtype), map_da) assert_equal(map_da.transpose("y", "x", transpose_coords=False).cxy, map_da.cxy) + + +def test_optimize(): + a = dask.array.ones((10, 5), chunks=(1, 3)) + arr = xr.DataArray(a).chunk(5) + (arr2,) = dask.optimize(arr) + arr2.compute()