map_blocks: Allow passing dask-backed objects in args (#3818)
* MVP for dask collections in args

* Add tests.

* Use list comprehension

* map_blocks: preserve attrs of dimension coordinates in input

Switch to use IndexVariables instead of Indexes so that attrs are preserved.

* Check that chunk sizes are compatible.

* Align all xarray objects

* Add some type hints.

* fix rebase

* move _wrapper out

* Fixes

* avoid index dataarrays for simplicity.

need a solution to preserve index attrs

* Propagate attributes for index variables.

* Propagate encoding for index variables.

* Fix bug with reductions when template is provided.

indexes should just have indexes for the output variables. When template was
provided, I was initializing indexes to contain all input indexes.
It should just have the indexes from template. Otherwise indexes for
any indexed dimensions removed by func will still be propagated (see the
sketch just after this list).

* more minimal fix.

* minimize diff

* Update docs.

* Address Joe's comments.

* docstring updates.

* minor docstring change

* minor.

* remove useless check_shapes variable.

* fix docstring
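The reduction/template fix described above ("Fix bug with reductions when template is provided") can be illustrated with a minimal, hedged sketch. It is not part of the commit; the dataset, variable names, and the lambda are illustrative, and it assumes xarray >= 0.16, where this change landed. When ``func`` drops an indexed dimension and ``template`` is supplied, only the indexes present in ``template`` should end up on the result.

import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"a": (("x", "y"), np.arange(12.0).reshape(4, 3))},
    coords={"x": [10, 20, 30, 40], "y": [1, 2, 3]},
).chunk({"x": 2})

# func removes the indexed dimension "y"; template describes the reduced output.
template = ds.mean("y")
result = ds.map_blocks(lambda block: block.mean("y"), template=template)

# With the fix, the dropped "y" index is not propagated to the result.
assert "y" not in result.indexes
result.compute()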
dcherian authored Jun 7, 2020
1 parent c07160d commit 2a288f6
Showing 5 changed files with 324 additions and 114 deletions.
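Two of the commit bullets above, "Align all xarray objects" and "Check that chunk sizes are compatible", translate into runtime checks on objects passed through ``args``. The following is a hedged sketch (variable names are illustrative and the exact error message is an assumption, not quoted from the diff): an ``args`` object whose chunk sizes differ from ``obj`` along a shared dimension is rejected.

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(8.0), dims="x").chunk({"x": 4})   # chunks (4, 4)
bad = xr.DataArray(np.arange(8.0), dims="x").chunk({"x": 3})   # chunks (3, 3, 2)

try:
    da.map_blocks(lambda a, b: a + b, args=[bad]).compute()
except ValueError as err:
    # Expected to fail with an error about unequal chunk sizes along "x".
    print("rejected:", err)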
2 changes: 2 additions & 0 deletions doc/whats-new.rst
@@ -94,6 +94,8 @@ New Features
- :py:meth:`map_blocks` now accepts a ``template`` kwarg. This allows use cases
where the result of a computation could not be inferred automatically.
By `Deepak Cherian <https://github.com/dcherian>`_
- :py:meth:`map_blocks` can now handle dask-backed xarray objects in ``args``. (:pull:`3818`)
By `Deepak Cherian <https://github.com/dcherian>`_

- Add keyword ``decode_timedelta`` to :py:func:`xarray.open_dataset`,
(:py:func:`xarray.open_dataarray`, :py:func:`xarray.open_dataarray`,
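A minimal, hedged sketch of the new whats-new entry above (``map_blocks`` accepting dask-backed xarray objects in ``args``). The names ``da``, ``scale``, and ``multiply`` are illustrative, and the sketch assumes xarray >= 0.16: the second, dask-backed DataArray is aligned with ``obj`` and subset block-by-block alongside it.

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(12.0), dims="x").chunk({"x": 4})
scale = xr.DataArray(np.linspace(0.0, 1.0, 12), dims="x").chunk({"x": 4})

def multiply(block, other):
    # ``other`` arrives as the block of ``scale`` corresponding to ``block``.
    return block * other

result = da.map_blocks(multiply, args=[scale])
result.compute()

Previously this was rejected: the old docstring text stated that passing dask collections in ``args`` was not allowed. Dask collections in ``kwargs`` are still not allowed.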
84 changes: 65 additions & 19 deletions xarray/core/dataarray.py
@@ -3262,45 +3262,91 @@ def map_blocks(
----------
func: callable
User-provided function that accepts a DataArray as its first
parameter. The function will receive a subset, i.e. one block, of this DataArray
(see below), corresponding to one chunk along each chunked dimension. ``func`` will be
executed as ``func(block_subset, *args, **kwargs)``.
parameter. The function will receive a subset or 'block' of this DataArray (see below),
corresponding to one chunk along each chunked dimension. ``func`` will be
executed as ``func(subset_dataarray, *subset_args, **kwargs)``.
This function must return either a single DataArray or a single Dataset.
This function cannot add a new chunked dimension.
obj: DataArray, Dataset
Passed to the function as its first argument, one block at a time.
args: Sequence
Passed verbatim to func after unpacking, after the sliced DataArray. xarray
objects, if any, will not be split by chunks. Passing dask collections is
not allowed.
Passed to func after unpacking and subsetting any xarray objects by blocks.
xarray objects in args must be aligned with obj, otherwise an error is raised.
kwargs: Mapping
Passed verbatim to func after unpacking. xarray objects, if any, will not be
split by chunks. Passing dask collections is not allowed.
subset to blocks. Passing dask collections in kwargs is not allowed.
template: (optional) DataArray, Dataset
xarray object representing the final result after compute is called. If not provided,
the function will be first run on mocked-up data, that looks like 'obj' but
the function will be first run on mocked-up data, that looks like ``obj`` but
has sizes 0, to determine properties of the returned object such as dtype,
variable names, new dimensions and new indexes (if any).
'template' must be provided if the function changes the size of existing dimensions.
variable names, attributes, new dimensions and new indexes (if any).
``template`` must be provided if the function changes the size of existing dimensions.
When provided, ``attrs`` on variables in ``template`` are copied over to the result. Any
``attrs`` set by ``func`` will be ignored.
Returns
-------
A single DataArray or Dataset with dask backend, reassembled from the outputs of
the function.
A single DataArray or Dataset with dask backend, reassembled from the outputs of the
function.
Notes
-----
This method is designed for when one needs to manipulate a whole xarray object
within each chunk. In the more common case where one can work on numpy arrays,
it is recommended to use apply_ufunc.
This function is designed for when ``func`` needs to manipulate a whole xarray object
subset to each block. In the more common case where ``func`` can work on numpy arrays, it is
recommended to use ``apply_ufunc``.
If none of the variables in this DataArray is backed by dask, calling this
method is equivalent to calling ``func(self, *args, **kwargs)``.
If none of the variables in ``obj`` is backed by dask arrays, calling this function is
equivalent to calling ``func(obj, *args, **kwargs)``.
See Also
--------
dask.array.map_blocks, xarray.apply_ufunc, xarray.map_blocks,
xarray.Dataset.map_blocks
dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks,
xarray.DataArray.map_blocks
Examples
--------
Calculate an anomaly from climatology using ``.groupby()``. Using
``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``,
its indices, and its methods like ``.groupby()``.
>>> def calculate_anomaly(da, groupby_type="time.month"):
... gb = da.groupby(groupby_type)
... clim = gb.mean(dim="time")
... return gb - clim
>>> time = xr.cftime_range("1990-01", "1992-01", freq="M")
>>> np.random.seed(123)
>>> array = xr.DataArray(
... np.random.rand(len(time)), dims="time", coords=[time]
... ).chunk()
>>> array.map_blocks(calculate_anomaly, template=array).compute()
<xarray.DataArray (time: 24)>
array([ 0.12894847, 0.11323072, -0.0855964 , -0.09334032, 0.26848862,
0.12382735, 0.22460641, 0.07650108, -0.07673453, -0.22865714,
-0.19063865, 0.0590131 , -0.12894847, -0.11323072, 0.0855964 ,
0.09334032, -0.26848862, -0.12382735, -0.22460641, -0.07650108,
0.07673453, 0.22865714, 0.19063865, -0.0590131 ])
Coordinates:
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments
to the function being applied in ``xr.map_blocks()``:
>>> array.map_blocks(
... calculate_anomaly, kwargs={"groupby_type": "time.year"}, template=array,
... )
<xarray.DataArray (time: 24)>
array([ 0.15361741, -0.25671244, -0.31600032, 0.008463 , 0.1766172 ,
-0.11974531, 0.43791243, 0.14197797, -0.06191987, -0.15073425,
-0.19967375, 0.18619794, -0.05100474, -0.42989909, -0.09153273,
0.24841842, -0.30708526, -0.31412523, 0.04197439, 0.0422506 ,
0.14482397, 0.35985481, 0.23487834, 0.12144652])
Coordinates:
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
"""
from .parallel import map_blocks

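The revised docstring above states that, when ``template`` is provided, ``attrs`` on ``template`` are copied to the result and ``attrs`` set inside ``func`` are ignored. A hedged sketch of that behaviour (the attribute name, values, and function name are illustrative):

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(6.0), dims="x").chunk({"x": 3})

def double_and_tag(block):
    out = block * 2
    out.attrs["source"] = "set inside func"   # ignored, per the docstring
    return out

template = da * 2                              # lazy, same chunks as ``da``
template.attrs["source"] = "set on template"   # copied to the result

result = da.map_blocks(double_and_tag, template=template).compute()
print(result.attrs)                            # expected: {'source': 'set on template'}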
83 changes: 65 additions & 18 deletions xarray/core/dataset.py
@@ -5733,45 +5733,92 @@ def map_blocks(
----------
func: callable
User-provided function that accepts a Dataset as its first
parameter. The function will receive a subset, i.e. one block, of this Dataset
(see below), corresponding to one chunk along each chunked dimension. ``func`` will be
executed as ``func(block_subset, *args, **kwargs)``.
parameter. The function will receive a subset or 'block' of this Dataset (see below),
corresponding to one chunk along each chunked dimension. ``func`` will be
executed as ``func(subset_dataset, *subset_args, **kwargs)``.
This function must return either a single DataArray or a single Dataset.
This function cannot add a new chunked dimension.
obj: DataArray, Dataset
Passed to the function as its first argument, one block at a time.
args: Sequence
Passed verbatim to func after unpacking, after the sliced DataArray. xarray
objects, if any, will not be split by chunks. Passing dask collections is
not allowed.
Passed to func after unpacking and subsetting any xarray objects by blocks.
xarray objects in args must be aligned with obj, otherwise an error is raised.
kwargs: Mapping
Passed verbatim to func after unpacking. xarray objects, if any, will not be
split by chunks. Passing dask collections is not allowed.
subset to blocks. Passing dask collections in kwargs is not allowed.
template: (optional) DataArray, Dataset
xarray object representing the final result after compute is called. If not provided,
the function will be first run on mocked-up data, that looks like 'obj' but
the function will be first run on mocked-up data, that looks like ``obj`` but
has sizes 0, to determine properties of the returned object such as dtype,
variable names, new dimensions and new indexes (if any).
'template' must be provided if the function changes the size of existing dimensions.
variable names, attributes, new dimensions and new indexes (if any).
``template`` must be provided if the function changes the size of existing dimensions.
When provided, ``attrs`` on variables in ``template`` are copied over to the result. Any
``attrs`` set by ``func`` will be ignored.
Returns
-------
A single DataArray or Dataset with dask backend, reassembled from the outputs of
the function.
A single DataArray or Dataset with dask backend, reassembled from the outputs of the
function.
Notes
-----
This method is designed for when one needs to manipulate a whole xarray object
within each chunk. In the more common case where one can work on numpy arrays,
it is recommended to use apply_ufunc.
This function is designed for when ``func`` needs to manipulate a whole xarray object
subset to each block. In the more common case where ``func`` can work on numpy arrays, it is
recommended to use ``apply_ufunc``.
If none of the variables in this Dataset is backed by dask, calling this method
is equivalent to calling ``func(self, *args, **kwargs)``.
If none of the variables in ``obj`` is backed by dask arrays, calling this function is
equivalent to calling ``func(obj, *args, **kwargs)``.
See Also
--------
dask.array.map_blocks, xarray.apply_ufunc, xarray.map_blocks,
dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks,
xarray.DataArray.map_blocks
Examples
--------
Calculate an anomaly from climatology using ``.groupby()``. Using
``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``,
its indices, and its methods like ``.groupby()``.
>>> def calculate_anomaly(da, groupby_type="time.month"):
... gb = da.groupby(groupby_type)
... clim = gb.mean(dim="time")
... return gb - clim
>>> time = xr.cftime_range("1990-01", "1992-01", freq="M")
>>> np.random.seed(123)
>>> array = xr.DataArray(
... np.random.rand(len(time)), dims="time", coords=[time]
... ).chunk()
>>> ds = xr.Dataset({"a": array})
>>> ds.map_blocks(calculate_anomaly, template=ds).compute()
<xarray.Dataset>
Dimensions:  (time: 24)
Coordinates:
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
Data variables:
a (time) float64 0.1289 0.1132 -0.0856 ... 0.2287 0.1906 -0.05901
Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments
to the function being applied in ``xr.map_blocks()``:
>>> ds.map_blocks(
... calculate_anomaly, kwargs={"groupby_type": "time.year"}, template=ds,
... )
<xarray.Dataset>
Dimensions:  (time: 24)
Coordinates:
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
Data variables:
a (time) float64 0.1536 -0.2567 -0.316 ... 0.3599 0.2349 0.1214
"""
from .parallel import map_blocks

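The docstring also notes that ``template`` must be provided when ``func`` changes the size of existing dimensions, because the size-0 mock run cannot infer the output shape. A hedged sketch (names are illustrative; each input block is assumed to map to exactly one output chunk):

import numpy as np
import xarray as xr

ds = xr.Dataset({"a": ("x", np.arange(10.0))}).chunk({"x": 5})

def first_of_each_block(block):
    return block.isel(x=[0])      # shrinks "x" inside every block

template = ds.isel(x=[0, 5])      # one row per input block: the output's shape and chunks
result = ds.map_blocks(first_of_each_block, template=template)
result.compute()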
(Diffs for the remaining two changed files are not shown here.)
