Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: use dask.array.apply_gufunc in xr.apply_ufunc #4060

Merged
merged 45 commits into from
Aug 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
c18655f
ENH: use `dask.array.apply_gufunc` in `xr.apply_ufunc` for multiple o…
kmuehlbauer May 14, 2020
3bf1d75
DOC: Update docstring and whats-new.rst
kmuehlbauer May 14, 2020
fff7660
WIP: apply_gufunc
kmuehlbauer May 20, 2020
d8bcb15
WIP: apply_gufunc -> reinstate dask='allowed' as per @mathause, adapt…
kmuehlbauer May 20, 2020
a17ca32
WIP: apply_gufunc -> add test for GH #4015, fix test for sparse meta …
kmuehlbauer May 20, 2020
5f3f847
WIP: apply_gufunc -> remove unused `input_dims`
kmuehlbauer May 20, 2020
c9fd698
Update xarray/core/computation.py
kmuehlbauer May 27, 2020
fc30ee8
Update xarray/core/computation.py
kmuehlbauer May 27, 2020
820a4cd
Update xarray/core/computation.py
kmuehlbauer May 27, 2020
9a266dc
Merge remote-tracking branch 'origin/master' into fix-1815
kmuehlbauer May 27, 2020
76375b4
WIP: use dask_gufunc_kwargs, keep vectorize first but only for non-da…
kmuehlbauer May 28, 2020
6cff763
DOC: add reference to internal changes in whats-new.rst
kmuehlbauer May 28, 2020
eb953d2
FIX: mypy
kmuehlbauer May 28, 2020
4ee835d
FIX: vectorize inside `apply_variable_ufunc`
kmuehlbauer May 28, 2020
ddeb1ea
TST: add tests from #4022 from @mathause
kmuehlbauer May 28, 2020
f645bbd
FIX: address black issue
kmuehlbauer May 28, 2020
c9e30af
FIX: xfail test for dask < 2.3
kmuehlbauer May 28, 2020
bdfdd74
Merge remote-tracking branch 'origin/master' into wip-apply-gufunc
kmuehlbauer Jun 9, 2020
f9cb53c
WIP: apply changes in response to @mathause's review comments
kmuehlbauer Jun 9, 2020
3b652fb
WIP: remove line
kmuehlbauer Jun 9, 2020
13f5c1d
WIP: catch different chunksize error and allow_rechunk, docstring fixes
kmuehlbauer Jun 9, 2020
512d55d
WIP: remove comment
kmuehlbauer Jun 9, 2020
5ad5063
WIP: style issues
kmuehlbauer Jun 9, 2020
9f51c77
WIP: revert catch, revert test, add tests without output_dtypes
kmuehlbauer Jun 9, 2020
41feeb3
Merge remote-tracking branch 'origin/master' into fix-1815
kmuehlbauer Jun 29, 2020
ada9cf0
WIP: fix signature in apply_ufunc->apply_gufunc, handle output_sizes,…
kmuehlbauer Jun 29, 2020
1bdf2bb
WIP: fix tuple
kmuehlbauer Jun 29, 2020
1401551
Merge remote-tracking branch 'origin/master' into fix-1815
kmuehlbauer Jun 30, 2020
fb64ed9
WIP: add dims_map to _UFuncSignature, adapt output_sizes to fit for a…
kmuehlbauer Jun 30, 2020
b0b5e2e
WIP: black
kmuehlbauer Jun 30, 2020
d0b4fb2
WIP: raise ValueError if output_sizes dimension mismatch
kmuehlbauer Jun 30, 2020
1bca9de
WIP: raise ValueError if output_sizes is missing for given output_cor…
kmuehlbauer Jun 30, 2020
1d5a8bb
WIP: simplify if/else
kmuehlbauer Jun 30, 2020
91a9d5a
FIX: resolve conflicts prior merge with master
kmuehlbauer Aug 17, 2020
035aa17
Merge remote-tracking branch 'origin/master' into fix-1815
kmuehlbauer Aug 17, 2020
e116fd0
FIX: combine if's as per review
kmuehlbauer Aug 17, 2020
4db4444
FIX: pass `vectorize` and `output_dtypes` kwargs explicitely into `ap…
kmuehlbauer Aug 17, 2020
9897c51
FIX: pass `vectorize` and `output_dtypes` kwargs explicitely into `da…
kmuehlbauer Aug 17, 2020
d4902a6
FIX: address review comments of @keewis and @mathause
kmuehlbauer Aug 17, 2020
5a1f15e
FIX: black
kmuehlbauer Aug 17, 2020
a05bd18
FIX: `vectorize` not needed in if-clause
kmuehlbauer Aug 17, 2020
35ae2a9
Merge remote-tracking branch 'origin/master' into fix-1815
kmuehlbauer Aug 18, 2020
2fc6272
FIX: set DeprecationWarning and stacklevel=2
kmuehlbauer Aug 18, 2020
4cb059e
FIX: use FutureWarning for user visibility
kmuehlbauer Aug 18, 2020
4a48acd
FIX: remove comment as suggested
kmuehlbauer Aug 19, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ Breaking changes

New Features
~~~~~~~~~~~~
- Support multiple outputs in :py:func:`xarray.apply_ufunc` when using ``dask='parallelized'``. (:issue:`1815`, :pull:`4060`)
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
- :py:meth:`~xarray.DataArray.rolling` and :py:meth:`~xarray.Dataset.rolling`
now accept more than 1 dimension.(:pull:`4219`)
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
Expand Down Expand Up @@ -75,6 +77,8 @@ Documentation

Internal Changes
~~~~~~~~~~~~~~~~
- Use :py:func:`dask.array.apply_gufunc` instead of :py:func:`dask.array.blockwise` in
:py:func:`xarray.apply_ufunc` when using ``dask='parallelized'``. (:pull:`4060`)
- Fix ``pip install .`` when no ``.git`` directory exists; namely when the xarray source
directory has been rsync'ed by PyCharm Professional for a remote deployment over SSH.
By `Guido Imperiale <https://github.com/crusaderky>`_
Expand Down
238 changes: 116 additions & 122 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import functools
import itertools
import operator
import warnings
from collections import Counter
from distutils.version import LooseVersion
from typing import (
TYPE_CHECKING,
AbstractSet,
Expand Down Expand Up @@ -90,6 +92,12 @@ def all_core_dims(self):
self._all_core_dims = self.all_input_core_dims | self.all_output_core_dims
return self._all_core_dims

@property
def dims_map(self):
return {
core_dim: f"dim{n}" for n, core_dim in enumerate(sorted(self.all_core_dims))
}

@property
def num_inputs(self):
return len(self.input_core_dims)
Expand Down Expand Up @@ -126,14 +134,12 @@ def to_gufunc_string(self):
Unlike __str__, handles dimensions that don't map to Python
identifiers.
"""
all_dims = self.all_core_dims
dims_map = dict(zip(sorted(all_dims), range(len(all_dims))))
input_core_dims = [
["dim%d" % dims_map[dim] for dim in core_dims]
[self.dims_map[dim] for dim in core_dims]
for core_dims in self.input_core_dims
]
output_core_dims = [
["dim%d" % dims_map[dim] for dim in core_dims]
[self.dims_map[dim] for dim in core_dims]
for core_dims in self.output_core_dims
]
alt_signature = type(self)(input_core_dims, output_core_dims)
Expand Down Expand Up @@ -424,7 +430,7 @@ def apply_groupby_func(func, *args):
if any(not first_groupby._group.equals(gb._group) for gb in groupbys[1:]):
raise ValueError(
"apply_ufunc can only perform operations over "
"multiple GroupBy objets at once if they are all "
"multiple GroupBy objects at once if they are all "
"grouped the same way"
)

Expand Down Expand Up @@ -539,16 +545,27 @@ def broadcast_compat_data(
return data


def _vectorize(func, signature, output_dtypes):
if signature.all_core_dims:
func = np.vectorize(
func, otypes=output_dtypes, signature=signature.to_gufunc_string()
)
else:
func = np.vectorize(func, otypes=output_dtypes)

return func


def apply_variable_ufunc(
func,
*args,
signature,
exclude_dims=frozenset(),
dask="forbidden",
output_dtypes=None,
output_sizes=None,
vectorize=False,
keep_attrs=False,
meta=None,
dask_gufunc_kwargs=None,
):
"""Apply a ndarray level function over Variable and/or ndarray objects.
"""
Expand Down Expand Up @@ -579,28 +596,61 @@ def apply_variable_ufunc(
"``.load()`` or ``.compute()``"
)
elif dask == "parallelized":
input_dims = [broadcast_dims + dims for dims in signature.input_core_dims]
numpy_func = func

if dask_gufunc_kwargs is None:
dask_gufunc_kwargs = {}

output_sizes = dask_gufunc_kwargs.pop("output_sizes", {})
if output_sizes:
output_sizes_renamed = {}
for key, value in output_sizes.items():
if key not in signature.all_output_core_dims:
raise ValueError(
f"dimension '{key}' in 'output_sizes' must correspond to output_core_dims"
)
output_sizes_renamed[signature.dims_map[key]] = value
dask_gufunc_kwargs["output_sizes"] = output_sizes_renamed

for key in signature.all_output_core_dims:
if key not in signature.all_input_core_dims and key not in output_sizes:
raise ValueError(
f"dimension '{key}' in 'output_core_dims' needs corresponding (dim, size) in 'output_sizes'"
)

def func(*arrays):
return _apply_blockwise(
import dask.array as da

res = da.apply_gufunc(
numpy_func,
arrays,
input_dims,
output_dims,
signature,
output_dtypes,
output_sizes,
meta,
signature.to_gufunc_string(),
*arrays,
vectorize=vectorize,
output_dtypes=output_dtypes,
**dask_gufunc_kwargs,
)

# todo: covers for https://github.com/dask/dask/pull/6207
# remove when minimal dask version >= 2.17.0
from dask import __version__ as dask_version

if LooseVersion(dask_version) < LooseVersion("2.17.0"):
if signature.num_outputs > 1:
res = tuple(res)

return res

elif dask == "allowed":
pass
else:
raise ValueError(
"unknown setting for dask array handling in "
"apply_ufunc: {}".format(dask)
)
else:
if vectorize:
func = _vectorize(func, signature, output_dtypes=output_dtypes)

result_data = func(*input_data)

if signature.num_outputs == 1:
Expand Down Expand Up @@ -648,90 +698,6 @@ def func(*arrays):
return tuple(output)


def _apply_blockwise(
func,
args,
input_dims,
output_dims,
signature,
output_dtypes,
output_sizes=None,
meta=None,
):
import dask.array

if signature.num_outputs > 1:
raise NotImplementedError(
"multiple outputs from apply_ufunc not yet "
"supported with dask='parallelized'"
)

if output_dtypes is None:
raise ValueError(
"output dtypes (output_dtypes) must be supplied to "
"apply_func when using dask='parallelized'"
)
if not isinstance(output_dtypes, list):
raise TypeError(
"output_dtypes must be a list of objects coercible to "
"numpy dtypes, got {}".format(output_dtypes)
)
if len(output_dtypes) != signature.num_outputs:
raise ValueError(
"apply_ufunc arguments output_dtypes and "
"output_core_dims must have the same length: {} vs {}".format(
len(output_dtypes), signature.num_outputs
)
)
(dtype,) = output_dtypes

if output_sizes is None:
output_sizes = {}

new_dims = signature.all_output_core_dims - signature.all_input_core_dims
if any(dim not in output_sizes for dim in new_dims):
raise ValueError(
"when using dask='parallelized' with apply_ufunc, "
"output core dimensions not found on inputs must "
"have explicitly set sizes with ``output_sizes``: {}".format(new_dims)
)

for n, (data, core_dims) in enumerate(zip(args, signature.input_core_dims)):
if isinstance(data, dask_array_type):
# core dimensions cannot span multiple chunks
for axis, dim in enumerate(core_dims, start=-len(core_dims)):
if len(data.chunks[axis]) != 1:
raise ValueError(
"dimension {!r} on {}th function argument to "
"apply_ufunc with dask='parallelized' consists of "
"multiple chunks, but is also a core dimension. To "
"fix, rechunk into a single dask array chunk along "
"this dimension, i.e., ``.chunk({})``, but beware "
"that this may significantly increase memory usage.".format(
dim, n, {dim: -1}
)
)

(out_ind,) = output_dims

blockwise_args = []
for arg, dims in zip(args, input_dims):
# skip leading dimensions that are implicitly added by broadcasting
ndim = getattr(arg, "ndim", 0)
trimmed_dims = dims[-ndim:] if ndim else ()
blockwise_args.extend([arg, trimmed_dims])

return dask.array.blockwise(
func,
out_ind,
*blockwise_args,
dtype=dtype,
concatenate=True,
new_axes=output_sizes,
meta=meta,
)


def apply_array_ufunc(func, *args, dask="forbidden"):
"""Apply a ndarray level function over ndarray objects."""
if any(isinstance(arg, dask_array_type) for arg in args):
Expand Down Expand Up @@ -771,6 +737,7 @@ def apply_ufunc(
output_dtypes: Sequence = None,
output_sizes: Mapping[Any, int] = None,
meta: Any = None,
dask_gufunc_kwargs: Dict[str, Any] = None,
) -> Any:
"""Apply a vectorized function for unlabeled arrays on xarray objects.
Expand Down Expand Up @@ -857,19 +824,29 @@ def apply_ufunc(
dask arrays:
- 'forbidden' (default): raise an error if a dask array is encountered.
- 'allowed': pass dask arrays directly on to ``func``.
- 'allowed': pass dask arrays directly on to ``func``. Prefer this option if
``func`` natively supports dask arrays.
- 'parallelized': automatically parallelize ``func`` if any of the
inputs are a dask array. If used, the ``output_dtypes`` argument must
also be provided. Multiple output arguments are not yet supported.
inputs are a dask array by using `dask.array.apply_gufunc`. Multiple output
arguments are supported. Only use this option if ``func`` does not natively
support dask arrays (e.g. converts them to numpy arrays).
dask_gufunc_kwargs : dict, optional
Optional keyword arguments passed to ``dask.array.apply_gufunc`` if
dask='parallelized'. Possible keywords are ``output_sizes``, ``allow_rechunk``
and ``meta``.
output_dtypes : list of dtypes, optional
Optional list of output dtypes. Only used if dask='parallelized'.
Optional list of output dtypes. Only used if ``dask='parallelized'`` or
vectorize=True.
output_sizes : dict, optional
Optional mapping from dimension names to sizes for outputs. Only used
if dask='parallelized' and new dimensions (not found on inputs) appear
on outputs.
on outputs. ``output_sizes`` should be given in the ``dask_gufunc_kwargs``
parameter. It will be removed as direct parameter in a future version.
meta : optional
Size-0 object representing the type of array wrapped by dask array. Passed on to
``dask.array.blockwise``.
``dask.array.apply_gufunc``. ``meta`` should be given in the
``dask_gufunc_kwargs`` parameter . It will be removed as direct parameter
a future version.
Returns
-------
Expand Down Expand Up @@ -1006,34 +983,45 @@ def earth_mover_distance(first_samples,
f"Please make {(exclude_dims - signature.all_core_dims)} a core dimension"
)

# handle dask_gufunc_kwargs
if dask == "parallelized":
if dask_gufunc_kwargs is None:
dask_gufunc_kwargs = {}
# todo: remove warnings after deprecation cycle
if meta is not None:
warnings.warn(
"``meta`` should be given in the ``dask_gufunc_kwargs`` parameter."
" It will be removed as direct parameter in a future version.",
FutureWarning,
stacklevel=2,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please set a class (DeprecationWarning) and stacklevel=2 on these warnings? That results in better messages for users.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry to nitpick - shouldn't that be a FutureWarning so that users actually get to see it?

Copy link
Contributor Author

@kmuehlbauer kmuehlbauer Aug 18, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mathause At least in the tests the warnings are issued .

What's the actual difference between DeprecationWarning and FutureWarning (update: just found PendingDeprecationWarning)? And when should they be used? Just to know for future contributions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FutureWarning would be fine, too. We should probably try to come to consensus on a general policy for xarray.

The Python docs have some guidance but the overall recommendation is not really clear to me: https://docs.python.org/3/library/warnings.html#warning-categories

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FutureWarning is for users and DeprecationWarning for library authors (https://docs.python.org/3/library/warnings.html#warning-categories). Which is why you see DeprecationWarning in the test but won't when you execute the code. Took me a while to figure this out when I wanted to deprecate some stuff in my package.

import warnings

def test():
    warnings.warn("DeprecationWarning", DeprecationWarning)
    warnings.warn("FutureWarning", FutureWarning)

If you try this in ipython test() will raise both warnings. But if you save to a file and try

from test_warnings import test
test()

only FutureWarning will appear (I did not know this detail either https://www.python.org/dev/peps/pep-0565/).

Copy link
Contributor Author

@kmuehlbauer kmuehlbauer Aug 18, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mathause @shoyer I'll switch to FutureWarning since this seems to be the only user-visible warning, See https://www.python.org/dev/peps/pep-0565/#additional-use-case-for-futurewarning

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And, thanks for the pointers and explanations.

dask_gufunc_kwargs.setdefault("meta", meta)
if output_sizes is not None:
warnings.warn(
"``output_sizes`` should be given in the ``dask_gufunc_kwargs`` "
"parameter. It will be removed as direct parameter in a future "
"version.",
FutureWarning,
stacklevel=2,
)
dask_gufunc_kwargs.setdefault("output_sizes", output_sizes)

if kwargs:
func = functools.partial(func, **kwargs)

if vectorize:
if meta is None:
# set meta=np.ndarray by default for numpy vectorized functions
# work around dask bug computing meta with vectorized functions: GH5642
meta = np.ndarray

if signature.all_core_dims:
func = np.vectorize(
func, otypes=output_dtypes, signature=signature.to_gufunc_string()
)
else:
func = np.vectorize(func, otypes=output_dtypes)

variables_vfunc = functools.partial(
apply_variable_ufunc,
func,
signature=signature,
exclude_dims=exclude_dims,
keep_attrs=keep_attrs,
dask=dask,
vectorize=vectorize,
output_dtypes=output_dtypes,
output_sizes=output_sizes,
meta=meta,
dask_gufunc_kwargs=dask_gufunc_kwargs,
)

# feed groupby-apply_ufunc through apply_groupby_func
if any(isinstance(a, GroupBy) for a in args):
this_apply = functools.partial(
apply_ufunc,
Expand All @@ -1046,9 +1034,12 @@ def earth_mover_distance(first_samples,
dataset_fill_value=dataset_fill_value,
keep_attrs=keep_attrs,
dask=dask,
meta=meta,
vectorize=vectorize,
output_dtypes=output_dtypes,
dask_gufunc_kwargs=dask_gufunc_kwargs,
)
return apply_groupby_func(this_apply, *args)
# feed datasets apply_variable_ufunc through apply_dataset_vfunc
elif any(is_dict_like(a) for a in args):
return apply_dataset_vfunc(
variables_vfunc,
Expand All @@ -1060,6 +1051,7 @@ def earth_mover_distance(first_samples,
fill_value=dataset_fill_value,
keep_attrs=keep_attrs,
)
# feed DataArray apply_variable_ufunc through apply_dataarray_vfunc
elif any(isinstance(a, DataArray) for a in args):
return apply_dataarray_vfunc(
variables_vfunc,
Expand All @@ -1069,9 +1061,11 @@ def earth_mover_distance(first_samples,
exclude_dims=exclude_dims,
keep_attrs=keep_attrs,
)
# feed Variables directly through apply_variable_ufunc
elif any(isinstance(a, Variable) for a in args):
return variables_vfunc(*args)
else:
# feed anything else through apply_array_ufunc
return apply_array_ufunc(func, *args, dask=dask)


Expand Down
Loading