-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: use dask.array.apply_gufunc
in xr.apply_ufunc
#4060
Changes from 42 commits
c18655f
3bf1d75
fff7660
d8bcb15
a17ca32
5f3f847
c9fd698
fc30ee8
820a4cd
9a266dc
76375b4
6cff763
eb953d2
4ee835d
ddeb1ea
f645bbd
c9e30af
bdfdd74
f9cb53c
3b652fb
13f5c1d
512d55d
5ad5063
9f51c77
41feeb3
ada9cf0
1bdf2bb
1401551
fb64ed9
b0b5e2e
d0b4fb2
1bca9de
1d5a8bb
91a9d5a
035aa17
e116fd0
4db4444
9897c51
d4902a6
5a1f15e
a05bd18
35ae2a9
2fc6272
4cb059e
4a48acd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,9 @@ | |
import functools | ||
import itertools | ||
import operator | ||
import warnings | ||
from collections import Counter | ||
from distutils.version import LooseVersion | ||
from typing import ( | ||
TYPE_CHECKING, | ||
AbstractSet, | ||
|
@@ -90,6 +92,12 @@ def all_core_dims(self): | |
self._all_core_dims = self.all_input_core_dims | self.all_output_core_dims | ||
return self._all_core_dims | ||
|
||
@property | ||
def dims_map(self): | ||
return { | ||
core_dim: f"dim{n}" for n, core_dim in enumerate(sorted(self.all_core_dims)) | ||
} | ||
|
||
@property | ||
def num_inputs(self): | ||
return len(self.input_core_dims) | ||
|
@@ -126,14 +134,12 @@ def to_gufunc_string(self): | |
Unlike __str__, handles dimensions that don't map to Python | ||
identifiers. | ||
""" | ||
all_dims = self.all_core_dims | ||
dims_map = dict(zip(sorted(all_dims), range(len(all_dims)))) | ||
input_core_dims = [ | ||
["dim%d" % dims_map[dim] for dim in core_dims] | ||
[self.dims_map[dim] for dim in core_dims] | ||
for core_dims in self.input_core_dims | ||
] | ||
output_core_dims = [ | ||
["dim%d" % dims_map[dim] for dim in core_dims] | ||
[self.dims_map[dim] for dim in core_dims] | ||
for core_dims in self.output_core_dims | ||
] | ||
alt_signature = type(self)(input_core_dims, output_core_dims) | ||
|
@@ -424,7 +430,7 @@ def apply_groupby_func(func, *args): | |
if any(not first_groupby._group.equals(gb._group) for gb in groupbys[1:]): | ||
raise ValueError( | ||
"apply_ufunc can only perform operations over " | ||
"multiple GroupBy objets at once if they are all " | ||
"multiple GroupBy objects at once if they are all " | ||
"grouped the same way" | ||
) | ||
|
||
|
@@ -539,16 +545,27 @@ def broadcast_compat_data( | |
return data | ||
|
||
|
||
def _vectorize(func, signature, output_dtypes): | ||
if signature.all_core_dims: | ||
func = np.vectorize( | ||
func, otypes=output_dtypes, signature=signature.to_gufunc_string() | ||
) | ||
else: | ||
func = np.vectorize(func, otypes=output_dtypes) | ||
|
||
return func | ||
|
||
|
||
def apply_variable_ufunc( | ||
func, | ||
*args, | ||
signature, | ||
exclude_dims=frozenset(), | ||
dask="forbidden", | ||
output_dtypes=None, | ||
output_sizes=None, | ||
vectorize=False, | ||
keep_attrs=False, | ||
meta=None, | ||
dask_gufunc_kwargs=None, | ||
): | ||
"""Apply a ndarray level function over Variable and/or ndarray objects. | ||
""" | ||
|
@@ -579,28 +596,61 @@ def apply_variable_ufunc( | |
"``.load()`` or ``.compute()``" | ||
) | ||
elif dask == "parallelized": | ||
input_dims = [broadcast_dims + dims for dims in signature.input_core_dims] | ||
numpy_func = func | ||
|
||
if dask_gufunc_kwargs is None: | ||
dask_gufunc_kwargs = {} | ||
|
||
output_sizes = dask_gufunc_kwargs.pop("output_sizes", {}) | ||
if output_sizes: | ||
output_sizes_renamed = {} | ||
for key, value in output_sizes.items(): | ||
if key not in signature.all_output_core_dims: | ||
raise ValueError( | ||
f"dimension '{key}' in 'output_sizes' must correspond to output_core_dims" | ||
) | ||
output_sizes_renamed[signature.dims_map[key]] = value | ||
dask_gufunc_kwargs["output_sizes"] = output_sizes_renamed | ||
|
||
for key in signature.all_output_core_dims: | ||
if key not in signature.all_input_core_dims and key not in output_sizes: | ||
raise ValueError( | ||
f"dimension '{key}' in 'output_core_dims' needs corresponding (dim, size) in 'output_sizes'" | ||
) | ||
|
||
def func(*arrays): | ||
return _apply_blockwise( | ||
import dask.array as da | ||
|
||
res = da.apply_gufunc( | ||
numpy_func, | ||
arrays, | ||
input_dims, | ||
output_dims, | ||
signature, | ||
output_dtypes, | ||
output_sizes, | ||
meta, | ||
signature.to_gufunc_string(), | ||
*arrays, | ||
vectorize=vectorize, | ||
output_dtypes=output_dtypes, | ||
**dask_gufunc_kwargs, | ||
) | ||
|
||
# todo: covers for https://github.com/dask/dask/pull/6207 | ||
# remove when minimal dask version >= 2.17.0 | ||
from dask import __version__ as dask_version | ||
|
||
if LooseVersion(dask_version) < LooseVersion("2.17.0"): | ||
if signature.num_outputs > 1: | ||
res = tuple(res) | ||
|
||
return res | ||
|
||
elif dask == "allowed": | ||
pass | ||
else: | ||
raise ValueError( | ||
"unknown setting for dask array handling in " | ||
"apply_ufunc: {}".format(dask) | ||
) | ||
else: | ||
if vectorize: | ||
func = _vectorize(func, signature, output_dtypes=output_dtypes) | ||
|
||
result_data = func(*input_data) | ||
|
||
if signature.num_outputs == 1: | ||
|
@@ -648,90 +698,6 @@ def func(*arrays): | |
return tuple(output) | ||
|
||
|
||
def _apply_blockwise( | ||
func, | ||
args, | ||
input_dims, | ||
output_dims, | ||
signature, | ||
output_dtypes, | ||
output_sizes=None, | ||
meta=None, | ||
): | ||
import dask.array | ||
|
||
if signature.num_outputs > 1: | ||
raise NotImplementedError( | ||
"multiple outputs from apply_ufunc not yet " | ||
"supported with dask='parallelized'" | ||
) | ||
|
||
if output_dtypes is None: | ||
raise ValueError( | ||
"output dtypes (output_dtypes) must be supplied to " | ||
"apply_func when using dask='parallelized'" | ||
) | ||
if not isinstance(output_dtypes, list): | ||
raise TypeError( | ||
"output_dtypes must be a list of objects coercible to " | ||
"numpy dtypes, got {}".format(output_dtypes) | ||
) | ||
if len(output_dtypes) != signature.num_outputs: | ||
raise ValueError( | ||
"apply_ufunc arguments output_dtypes and " | ||
"output_core_dims must have the same length: {} vs {}".format( | ||
len(output_dtypes), signature.num_outputs | ||
) | ||
) | ||
(dtype,) = output_dtypes | ||
|
||
if output_sizes is None: | ||
output_sizes = {} | ||
|
||
new_dims = signature.all_output_core_dims - signature.all_input_core_dims | ||
if any(dim not in output_sizes for dim in new_dims): | ||
raise ValueError( | ||
"when using dask='parallelized' with apply_ufunc, " | ||
"output core dimensions not found on inputs must " | ||
"have explicitly set sizes with ``output_sizes``: {}".format(new_dims) | ||
) | ||
|
||
for n, (data, core_dims) in enumerate(zip(args, signature.input_core_dims)): | ||
if isinstance(data, dask_array_type): | ||
# core dimensions cannot span multiple chunks | ||
for axis, dim in enumerate(core_dims, start=-len(core_dims)): | ||
if len(data.chunks[axis]) != 1: | ||
raise ValueError( | ||
"dimension {!r} on {}th function argument to " | ||
"apply_ufunc with dask='parallelized' consists of " | ||
"multiple chunks, but is also a core dimension. To " | ||
"fix, rechunk into a single dask array chunk along " | ||
"this dimension, i.e., ``.chunk({})``, but beware " | ||
"that this may significantly increase memory usage.".format( | ||
dim, n, {dim: -1} | ||
) | ||
) | ||
|
||
(out_ind,) = output_dims | ||
|
||
blockwise_args = [] | ||
for arg, dims in zip(args, input_dims): | ||
# skip leading dimensions that are implicitly added by broadcasting | ||
ndim = getattr(arg, "ndim", 0) | ||
trimmed_dims = dims[-ndim:] if ndim else () | ||
blockwise_args.extend([arg, trimmed_dims]) | ||
|
||
return dask.array.blockwise( | ||
func, | ||
out_ind, | ||
*blockwise_args, | ||
dtype=dtype, | ||
concatenate=True, | ||
new_axes=output_sizes, | ||
meta=meta, | ||
) | ||
|
||
|
||
def apply_array_ufunc(func, *args, dask="forbidden"): | ||
"""Apply a ndarray level function over ndarray objects.""" | ||
if any(isinstance(arg, dask_array_type) for arg in args): | ||
|
@@ -771,6 +737,7 @@ def apply_ufunc( | |
output_dtypes: Sequence = None, | ||
output_sizes: Mapping[Any, int] = None, | ||
meta: Any = None, | ||
dask_gufunc_kwargs: Dict[str, Any] = None, | ||
) -> Any: | ||
"""Apply a vectorized function for unlabeled arrays on xarray objects. | ||
|
||
|
@@ -857,19 +824,29 @@ def apply_ufunc( | |
dask arrays: | ||
|
||
- 'forbidden' (default): raise an error if a dask array is encountered. | ||
- 'allowed': pass dask arrays directly on to ``func``. | ||
- 'allowed': pass dask arrays directly on to ``func``. Prefer this option if | ||
``func`` natively supports dask arrays. | ||
- 'parallelized': automatically parallelize ``func`` if any of the | ||
inputs are a dask array. If used, the ``output_dtypes`` argument must | ||
also be provided. Multiple output arguments are not yet supported. | ||
inputs are a dask array by using `dask.array.apply_gufunc`. Multiple output | ||
arguments are supported. Only use this option if ``func`` does not natively | ||
support dask arrays (e.g. converts them to numpy arrays). | ||
dask_gufunc_kwargs : dict, optional | ||
Optional keyword arguments passed to ``dask.array.apply_gufunc`` if | ||
dask='parallelized'. Possible keywords are ``output_sizes``, ``allow_rechunk`` | ||
and ``meta``. | ||
output_dtypes : list of dtypes, optional | ||
Optional list of output dtypes. Only used if dask='parallelized'. | ||
Optional list of output dtypes. Only used if ``dask='parallelized'`` or | ||
vectorize=True. | ||
output_sizes : dict, optional | ||
Optional mapping from dimension names to sizes for outputs. Only used | ||
if dask='parallelized' and new dimensions (not found on inputs) appear | ||
on outputs. | ||
on outputs. ``output_sizes`` should be given in the ``dask_gufunc_kwargs`` | ||
parameter. It will be removed as direct parameter in a future version. | ||
meta : optional | ||
Size-0 object representing the type of array wrapped by dask array. Passed on to | ||
``dask.array.blockwise``. | ||
``dask.array.apply_gufunc``. ``meta`` should be given in the | ||
``dask_gufunc_kwargs`` parameter. It will be removed as direct parameter in | ||
a future version. | ||
|
||
Returns | ||
------- | ||
|
@@ -1006,34 +983,41 @@ def earth_mover_distance(first_samples, | |
f"Please make {(exclude_dims - signature.all_core_dims)} a core dimension" | ||
) | ||
|
||
# handle dask_gufunc_kwargs | ||
if dask == "parallelized": | ||
if dask_gufunc_kwargs is None: | ||
dask_gufunc_kwargs = {} | ||
# todo: remove warnings after deprecation cycle | ||
if meta is not None: | ||
warnings.warn( | ||
"``meta`` should be given in the ``dask_gufunc_kwargs`` parameter." | ||
" It will be removed as direct parameter in a future version." | ||
) | ||
dask_gufunc_kwargs.setdefault("meta", meta) | ||
if output_sizes is not None: | ||
warnings.warn( | ||
"``output_sizes`` should be given in the ``dask_gufunc_kwargs`` " | ||
"parameter. It will be removed as direct parameter in a future " | ||
"version." | ||
) | ||
dask_gufunc_kwargs.setdefault("output_sizes", output_sizes) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I'm not sure I understand. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think Stephan refers to the problem you ran into with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it be worth to combine Like output_core_dims=[[("sign", 2)]] or output_core_dims=[{"sign": 2}] There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I copy my comment over for completeness:
|
||
|
||
if kwargs: | ||
func = functools.partial(func, **kwargs) | ||
|
||
if vectorize: | ||
if meta is None: | ||
# set meta=np.ndarray by default for numpy vectorized functions | ||
# work around dask bug computing meta with vectorized functions: GH5642 | ||
meta = np.ndarray | ||
|
||
if signature.all_core_dims: | ||
func = np.vectorize( | ||
func, otypes=output_dtypes, signature=signature.to_gufunc_string() | ||
) | ||
else: | ||
func = np.vectorize(func, otypes=output_dtypes) | ||
|
||
variables_vfunc = functools.partial( | ||
apply_variable_ufunc, | ||
func, | ||
signature=signature, | ||
exclude_dims=exclude_dims, | ||
keep_attrs=keep_attrs, | ||
dask=dask, | ||
vectorize=vectorize, | ||
output_dtypes=output_dtypes, | ||
output_sizes=output_sizes, | ||
meta=meta, | ||
dask_gufunc_kwargs=dask_gufunc_kwargs, | ||
) | ||
|
||
# feed groupby-apply_ufunc through apply_groupby_func | ||
if any(isinstance(a, GroupBy) for a in args): | ||
this_apply = functools.partial( | ||
apply_ufunc, | ||
|
@@ -1046,9 +1030,12 @@ def earth_mover_distance(first_samples, | |
dataset_fill_value=dataset_fill_value, | ||
keep_attrs=keep_attrs, | ||
dask=dask, | ||
meta=meta, | ||
vectorize=vectorize, | ||
output_dtypes=output_dtypes, | ||
dask_gufunc_kwargs=dask_gufunc_kwargs, | ||
) | ||
return apply_groupby_func(this_apply, *args) | ||
# feed datasets apply_variable_ufunc through apply_dataset_vfunc | ||
elif any(is_dict_like(a) for a in args): | ||
return apply_dataset_vfunc( | ||
variables_vfunc, | ||
|
@@ -1060,6 +1047,7 @@ def earth_mover_distance(first_samples, | |
fill_value=dataset_fill_value, | ||
keep_attrs=keep_attrs, | ||
) | ||
# feed DataArray apply_variable_ufunc through apply_dataarray_vfunc | ||
elif any(isinstance(a, DataArray) for a in args): | ||
return apply_dataarray_vfunc( | ||
variables_vfunc, | ||
|
@@ -1069,9 +1057,11 @@ def earth_mover_distance(first_samples, | |
exclude_dims=exclude_dims, | ||
keep_attrs=keep_attrs, | ||
) | ||
# feed Variables directly through apply_variable_ufunc | ||
elif any(isinstance(a, Variable) for a in args): | ||
return variables_vfunc(*args) | ||
else: | ||
# feed anything else through apply_array_ufunc | ||
return apply_array_ufunc(func, *args, dask=dask) | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please set a class (DeprecationWarning) and stacklevel=2 on these warnings? That results in better messages for users.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry to nitpick - shouldn't that be a
FutureWarning
so that users actually get to see it?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mathause At least in the tests the warnings are issued.
What's the actual difference between
DeprecationWarning
andFutureWarning
(update: just foundPendingDeprecationWarning
)? And when should they be used? Just to know for future contributions.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FutureWarning would be fine, too. We should probably try to come to consensus on a general policy for xarray.
The Python docs have some guidance but the overall recommendation is not really clear to me: https://docs.python.org/3/library/warnings.html#warning-categories
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FutureWarning
is for users andDeprecationWarning
for library authors (https://docs.python.org/3/library/warnings.html#warning-categories). Which is why you seeDeprecationWarning
in the test but won't when you execute the code. Took me a while to figure this out when I wanted to deprecate some stuff in my package.If you try this in ipython
test()
will raise both warnings. But if you save to a file and tryonly
FutureWarning
will appear (I did not know this detail either https://www.python.org/dev/peps/pep-0565/).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mathause @shoyer I'll switch to
FutureWarning
since this seems to be the only user-visible warning, See https://www.python.org/dev/peps/pep-0565/#additional-use-case-for-futurewarningThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And, thanks for the pointers and explanations.