
dask.optimize on xarray objects #3698

Closed
dcherian opened this issue Jan 15, 2020 · 5 comments · Fixed by #4432 or #4438

@dcherian (Contributor)
I am trying to call dask.optimize on an xarray object before the graph gets too big, but I get weird errors. Simple examples below. All examples work if I remove the dask.optimize step.

cc @mrocklin @shoyer

This works with dask arrays:

a = dask.array.ones((10,5), chunks=(1,3))
a = dask.optimize(a)[0]
a.compute()

It also works when a DataArray is constructed from a dask array:

da = xr.DataArray(a)
da = dask.optimize(da)[0]
da.compute()

but it fails when a DataArray is created from a numpy array and then chunked

🤷‍♂️

da = xr.DataArray(a.compute()).chunk({"dim_0": 5})
da = dask.optimize(da)[0]
da.compute()

It fails with this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-50-1f16efa19800> in <module>
      1 da = xr.DataArray(a.compute()).chunk({"dim_0": 5})
      2 da = dask.optimize(da)[0]
----> 3 da.compute()

~/python/xarray/xarray/core/dataarray.py in compute(self, **kwargs)
    838         """
    839         new = self.copy(deep=False)
--> 840         return new.load(**kwargs)
    841 
    842     def persist(self, **kwargs) -> "DataArray":

~/python/xarray/xarray/core/dataarray.py in load(self, **kwargs)
    812         dask.array.compute
    813         """
--> 814         ds = self._to_temp_dataset().load(**kwargs)
    815         new = self._from_temp_dataset(ds)
    816         self._variable = new._variable

~/python/xarray/xarray/core/dataset.py in load(self, **kwargs)
    659 
    660             # evaluate all the dask arrays simultaneously
--> 661             evaluated_data = da.compute(*lazy_data.values(), **kwargs)
    662 
    663             for k, data in zip(lazy_data, evaluated_data):

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
    434     keys = [x.__dask_keys__() for x in collections]
    435     postcomputes = [x.__dask_postcompute__() for x in collections]
--> 436     results = schedule(dsk, keys, **kwargs)
    437     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    438 

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/threaded.py in get(dsk, result, cache, num_workers, pool, **kwargs)
     79         get_id=_thread_get_id,
     80         pack_exception=pack_exception,
---> 81         **kwargs
     82     )
     83 

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/local.py in get_async(apply_async, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, **kwargs)
    484                         _execute_task(task, data)  # Re-execute locally
    485                     else:
--> 486                         raise_exception(exc, tb)
    487                 res, worker_id = loads(res_info)
    488                 state["cache"][key] = res

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/local.py in reraise(exc, tb)
    314     if exc.__traceback__ is not tb:
    315         raise exc.with_traceback(tb)
--> 316     raise exc
    317 
    318 

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    220     try:
    221         task, data = loads(task_info)
--> 222         result = _execute_task(task, data)
    223         id = get_id()
    224         result = dumps((result, id))

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    117         func, args = arg[0], arg[1:]
    118         args2 = [_execute_task(a, cache) for a in args]
--> 119         return func(*args2)
    120     elif not ishashable(arg):
    121         return arg

TypeError: string indices must be integers

And a different error occurs when rechunking a dask-backed DataArray:

da = xr.DataArray(a).chunk({"dim_0": 5})
da = dask.optimize(da)[0]
da.compute()
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-55-d978bbb9e38d> in <module>
      1 da = xr.DataArray(a).chunk({"dim_0": 5})
      2 da = dask.optimize(da)[0]
----> 3 da.compute()

~/python/xarray/xarray/core/dataarray.py in compute(self, **kwargs)
    838         """
    839         new = self.copy(deep=False)
--> 840         return new.load(**kwargs)
    841 
    842     def persist(self, **kwargs) -> "DataArray":

~/python/xarray/xarray/core/dataarray.py in load(self, **kwargs)
    812         dask.array.compute
    813         """
--> 814         ds = self._to_temp_dataset().load(**kwargs)
    815         new = self._from_temp_dataset(ds)
    816         self._variable = new._variable

~/python/xarray/xarray/core/dataset.py in load(self, **kwargs)
    659 
    660             # evaluate all the dask arrays simultaneously
--> 661             evaluated_data = da.compute(*lazy_data.values(), **kwargs)
    662 
    663             for k, data in zip(lazy_data, evaluated_data):

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
    434     keys = [x.__dask_keys__() for x in collections]
    435     postcomputes = [x.__dask_postcompute__() for x in collections]
--> 436     results = schedule(dsk, keys, **kwargs)
    437     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    438 

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/threaded.py in get(dsk, result, cache, num_workers, pool, **kwargs)
     79         get_id=_thread_get_id,
     80         pack_exception=pack_exception,
---> 81         **kwargs
     82     )
     83 

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/local.py in get_async(apply_async, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, **kwargs)
    484                         _execute_task(task, data)  # Re-execute locally
    485                     else:
--> 486                         raise_exception(exc, tb)
    487                 res, worker_id = loads(res_info)
    488                 state["cache"][key] = res

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/local.py in reraise(exc, tb)
    314     if exc.__traceback__ is not tb:
    315         raise exc.with_traceback(tb)
--> 316     raise exc
    317 
    318 

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    220     try:
    221         task, data = loads(task_info)
--> 222         result = _execute_task(task, data)
    223         id = get_id()
    224         result = dumps((result, id))

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    117         func, args = arg[0], arg[1:]
    118         args2 = [_execute_task(a, cache) for a in args]
--> 119         return func(*args2)
    120     elif not ishashable(arg):
    121         return arg

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/array/core.py in concatenate3(arrays)
   4305     if not ndim:
   4306         return arrays
-> 4307     chunks = chunks_from_arrays(arrays)
   4308     shape = tuple(map(sum, chunks))
   4309 

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/array/core.py in chunks_from_arrays(arrays)
   4085 
   4086     while isinstance(arrays, (list, tuple)):
-> 4087         result.append(tuple([shape(deepfirst(a))[dim] for a in arrays]))
   4088         arrays = arrays[0]
   4089         dim += 1

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/array/core.py in <listcomp>(.0)
   4085 
   4086     while isinstance(arrays, (list, tuple)):
-> 4087         result.append(tuple([shape(deepfirst(a))[dim] for a in arrays]))
   4088         arrays = arrays[0]
   4089         dim += 1

IndexError: tuple index out of range
@TomAugspurger (Contributor)
It looks like xarray is getting a bad task graph after the optimize.

In [1]: import xarray as xr

In [2]: import dask

In [3]: a = dask.array.ones((10,5), chunks=(1,3))
   ...: a = dask.optimize(a)[0]

In [4]: da = xr.DataArray(a.compute()).chunk({"dim_0": 5})
   ...: da = dask.optimize(da)[0]

In [5]: dict(da.__dask_graph__())
Out[5]:
{('xarray-<this-array>-e2865aa10d476e027154771611541f99',
  1,
  0): (<function _operator.getitem(a, b, /)>, 'xarray-<this-array>-e2865aa10d476e027154771611541f99', (slice(5, 10, None),
   slice(0, 5, None))),
 ('xarray-<this-array>-e2865aa10d476e027154771611541f99',
  0,
  0): (<function _operator.getitem(a, b, /)>, 'xarray-<this-array>-e2865aa10d476e027154771611541f99', (slice(0, 5, None),
   slice(0, 5, None)))}

Notice that there are references to xarray-<this-array>-e2865aa10d476e027154771611541f99 (just the string, not a tuple representing a chunk), but that key isn't in the graph.
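This dangling key is also what produces the `TypeError: string indices must be integers` in the first traceback: when a referenced key is missing from the graph, the bare string is treated as a literal, so `getitem` ends up indexing the key name itself with a tuple of slices. A minimal sketch of that failure mode (the key name is shortened here for illustration):

```python
from operator import getitem

# Sketch: with the key missing from the graph, the task argument stays a
# bare string, and getitem is applied to the string itself with a tuple
# of slices -- raising the TypeError seen in the traceback above.
key = "xarray-<this-array>-e2865"  # shortened stand-in for the real key
try:
    getitem(key, (slice(5, 10, None), slice(0, 5, None)))
except TypeError as e:
    print(e)  # e.g. "string indices must be integers"
```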

If we manually insert that key, things work:

In [9]: dsk['xarray-<this-array>-e2865aa10d476e027154771611541f99'] = da._to_temp_dataset()[xr.core.dataarray._THIS_ARRAY]

In [11]: dask.get(dsk, keys=[('xarray-<this-array>-e2865aa10d476e027154771611541f99', 1, 0)])
Out[11]:
(<xarray.DataArray <this-array> (dim_0: 5, dim_1: 5)>
 dask.array<getitem, shape=(5, 5), dtype=float64, chunksize=(5, 5), chunktype=numpy.ndarray>
 Dimensions without coordinates: dim_0, dim_1,)

@TomAugspurger (Contributor)
FYI @dcherian, your recent PR to dask fixed this example. Playing around with chunk sizes, it seems to be fixed even when the chunk size exceeds dask.config['array']['chunk-size'].

@dcherian (Contributor, Author) commented Sep 9, 2020

I guess I can see that. Thanks Tom.

> ...even when the chunk size exceeds dask.config['array']['chunk-size']

FYI the slicing behaviour is independent of chunk-size (Matt's recommendation).

@dcherian dcherian closed this as completed Sep 9, 2020
@dcherian (Contributor, Author)
The numpy example is fixed but the dask rechunked example is still broken.

a = dask.array.ones((10,5), chunks=(1,3))
dask.optimize(xr.DataArray(a))[0].compute() # works
dask.optimize(xr.DataArray(a).chunk(5))[0].compute()  # error
IndexError                                Traceback (most recent call last)
<ipython-input-8-5663bc8bc82a> in <module>
----> 1 dask.optimize(xr.DataArray(a).chunk(5))[0].compute()

~/miniconda3/envs/dcpy/lib/python3.8/site-packages/xarray/core/dataarray.py in compute(self, **kwargs)
    838         """
    839         new = self.copy(deep=False)
--> 840         return new.load(**kwargs)
    841 
    842     def persist(self, **kwargs) -> "DataArray":

~/miniconda3/envs/dcpy/lib/python3.8/site-packages/xarray/core/dataarray.py in load(self, **kwargs)
    812         dask.array.compute
    813         """
--> 814         ds = self._to_temp_dataset().load(**kwargs)
    815         new = self._from_temp_dataset(ds)
    816         self._variable = new._variable

~/miniconda3/envs/dcpy/lib/python3.8/site-packages/xarray/core/dataset.py in load(self, **kwargs)
    656 
    657             # evaluate all the dask arrays simultaneously
--> 658             evaluated_data = da.compute(*lazy_data.values(), **kwargs)
    659 
    660             for k, data in zip(lazy_data, evaluated_data):

~/miniconda3/envs/dcpy/lib/python3.8/site-packages/dask/base.py in compute(*args, **kwargs)
    445         postcomputes.append(x.__dask_postcompute__())
    446 
--> 447     results = schedule(dsk, keys, **kwargs)
    448     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    449 

~/miniconda3/envs/dcpy/lib/python3.8/site-packages/dask/threaded.py in get(dsk, result, cache, num_workers, pool, **kwargs)
     74                 pools[thread][num_workers] = pool
     75 
---> 76     results = get_async(
     77         pool.apply_async,
     78         len(pool._pool),

~/miniconda3/envs/dcpy/lib/python3.8/site-packages/dask/local.py in get_async(apply_async, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, **kwargs)
    484                         _execute_task(task, data)  # Re-execute locally
    485                     else:
--> 486                         raise_exception(exc, tb)
    487                 res, worker_id = loads(res_info)
    488                 state["cache"][key] = res

~/miniconda3/envs/dcpy/lib/python3.8/site-packages/dask/local.py in reraise(exc, tb)
    314     if exc.__traceback__ is not tb:
    315         raise exc.with_traceback(tb)
--> 316     raise exc
    317 
    318 

~/miniconda3/envs/dcpy/lib/python3.8/site-packages/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    220     try:
    221         task, data = loads(task_info)
--> 222         result = _execute_task(task, data)
    223         id = get_id()
    224         result = dumps((result, id))

~/miniconda3/envs/dcpy/lib/python3.8/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    119         # temporaries by their reference count and can execute certain
    120         # operations in-place.
--> 121         return func(*(_execute_task(a, cache) for a in args))
    122     elif not ishashable(arg):
    123         return arg

~/miniconda3/envs/dcpy/lib/python3.8/site-packages/dask/array/core.py in concatenate3(arrays)
   4407     if not ndim:
   4408         return arrays
-> 4409     chunks = chunks_from_arrays(arrays)
   4410     shape = tuple(map(sum, chunks))
   4411 

~/miniconda3/envs/dcpy/lib/python3.8/site-packages/dask/array/core.py in chunks_from_arrays(arrays)
   4178 
   4179     while isinstance(arrays, (list, tuple)):
-> 4180         result.append(tuple([shape(deepfirst(a))[dim] for a in arrays]))
   4181         arrays = arrays[0]
   4182         dim += 1

~/miniconda3/envs/dcpy/lib/python3.8/site-packages/dask/array/core.py in <listcomp>(.0)
   4178 
   4179     while isinstance(arrays, (list, tuple)):
-> 4180         result.append(tuple([shape(deepfirst(a))[dim] for a in arrays]))
   4181         arrays = arrays[0]
   4182         dim += 1

IndexError: tuple index out of range

@dcherian dcherian reopened this Sep 10, 2020
@TomAugspurger (Contributor)
TomAugspurger commented Sep 10, 2020 via email

TomAugspurger added a commit to TomAugspurger/xarray that referenced this issue Sep 17, 2020
Previously we generated an invalid Dask task graph, because the lines
removed here dropped keys that were referenced elsewhere in the task
graph. The original implementation had a
comment indicating that this was to cull:
https://github.com/pydata/xarray/blame/502a988ad5b87b9f3aeec3033bf55c71272e1053/xarray/core/variable.py#L384

Just spot-checking things, I think we're OK here though. Something like
`dask.visualize(arr[[0]], optimize_graph=True)` indicates that we're OK.

Closes pydata#3698
max-sixty added a commit that referenced this issue Sep 17, 2020
Previously we generated an invalid Dask task graph, because the lines
removed here dropped keys that were referenced elsewhere in the task
graph. The original implementation had a
comment indicating that this was to cull:
https://github.com/pydata/xarray/blame/502a988ad5b87b9f3aeec3033bf55c71272e1053/xarray/core/variable.py#L384

Just spot-checking things, I think we're OK here though. Something like
`dask.visualize(arr[[0]], optimize_graph=True)` indicates that we're OK.

Closes #3698

Co-authored-by: Maximilian Roos <[email protected]>
TomAugspurger added a commit to TomAugspurger/xarray that referenced this issue Sep 18, 2020
Another attempt to fix pydata#3698. The issue with my fix is that we hit
`Variable._dask_finalize` in both `dask.optimize` and `dask.persist`. We
want to do the culling of unnecessary tasks (`test_persist_Dataset`) but
only in the persist case, not optimize (`test_optimize`).
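The culling mentioned in the commit message means keeping only the tasks reachable from the requested output keys, so unreferenced tasks are dropped. An illustrative pure-Python sketch of the idea (this is not dask's or xarray's actual implementation; the graph and key names are made up):

```python
# Sketch of task-graph culling: walk dependencies from the requested
# output keys and keep only the tasks that are actually reachable.
def cull(dsk, keys):
    seen, stack = set(), list(keys)
    while stack:
        k = stack.pop()
        if k in seen:
            continue
        seen.add(k)
        task = dsk[k]
        # Treat any hashable task argument that is also a graph key
        # as a dependency (a simplification of dask's real rules).
        for part in (task[1:] if isinstance(task, tuple) else ()):
            if isinstance(part, (str, tuple)) and part in dsk:
                stack.append(part)
    return {k: dsk[k] for k in seen}

dsk = {
    "a": (lambda: 1,),
    "b": (lambda x: x + 1, "a"),
    "unused": (lambda: 99,),  # not reachable from "b"
}
culled = cull(dsk, ["b"])
assert set(culled) == {"a", "b"}
```

The bug described above came from doing this pruning in a code path shared by optimize and persist, where optimize-produced graphs still referenced the pruned keys.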
@dcherian dcherian reopened this Sep 18, 2020
max-sixty added a commit that referenced this issue Sep 20, 2020
* Fixed dask.optimize on datasets

Another attempt to fix #3698. The issue with my fix is that we hit
`Variable._dask_finalize` in both `dask.optimize` and `dask.persist`. We
want to do the culling of unnecessary tasks (`test_persist_Dataset`) but
only in the persist case, not optimize (`test_optimize`).

* Update whats-new.rst

* Update doc/whats-new.rst

Co-authored-by: Deepak Cherian <[email protected]>
Co-authored-by: Maximilian Roos <[email protected]>