diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 075d79043b2..8f4ff684b08 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -501,9 +501,6 @@ def __dask_postpersist__(self): @staticmethod def _dask_finalize(results, array_func, array_args, dims, attrs, encoding): - if isinstance(results, dict): # persist case - name = array_args[0] - results = {k: v for k, v in results.items() if k[0] == name} data = array_func(results, *array_args) return Variable(dims, data, attrs=attrs, encoding=encoding) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 46685a29a47..489bf09fa3c 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1607,3 +1607,10 @@ def test_more_transforms_pass_lazy_array_equiv(map_da, map_ds): assert_equal(map_da._from_temp_dataset(map_da._to_temp_dataset()), map_da) assert_equal(map_da.astype(map_da.dtype), map_da) assert_equal(map_da.transpose("y", "x", transpose_coords=False).cxy, map_da.cxy) + + +def test_optimize(): + a = dask.array.ones((10, 5), chunks=(1, 3)) + arr = xr.DataArray(a).chunk(5) + (arr2,) = dask.optimize(arr) + arr2.compute()