Skip to content

Commit

Permalink
speed up map_blocks (#4149)
Browse files Browse the repository at this point in the history
* replace the object array with generator expressions and zip/enumerate

* remove a leftover grouping pair of parentheses

* reuse is_array instead of comparing again
  • Loading branch information
keewis authored Jun 12, 2020
1 parent 8f688ea commit 59a2397
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions xarray/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,8 @@
T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)


def to_object_array(iterable):
# using empty_like calls compute
npargs = np.empty((len(iterable),), dtype=np.object)
npargs[:] = iterable
return npargs
def unzip(iterable):
return zip(*iterable)


def assert_chunks_compatible(a: Dataset, b: Dataset):
Expand Down Expand Up @@ -335,23 +332,33 @@ def _wrapper(
if not dask.is_dask_collection(obj):
return func(obj, *args, **kwargs)

npargs = to_object_array([obj] + list(args))
is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in npargs]
is_array = [isinstance(arg, DataArray) for arg in npargs]
all_args = [obj] + list(args)
is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in all_args]
is_array = [isinstance(arg, DataArray) for arg in all_args]

# there should be a better way to group this. partition?
xarray_indices, xarray_objs = unzip(
(index, arg) for index, arg in enumerate(all_args) if is_xarray[index]
)
others = [
(index, arg) for index, arg in enumerate(all_args) if not is_xarray[index]
]

# all xarray objects must be aligned. This is consistent with apply_ufunc.
aligned = align(*npargs[is_xarray], join="exact")
# assigning to object arrays works better when RHS is object array
# https://stackoverflow.com/questions/43645135/boolean-indexing-assignment-of-a-numpy-array-to-a-numpy-array
npargs[is_xarray] = to_object_array(aligned)
npargs[is_array] = to_object_array(
[dataarray_to_dataset(da) for da in npargs[is_array]]
aligned = align(*xarray_objs, join="exact")
xarray_objs = tuple(
dataarray_to_dataset(arg) if is_da else arg
for is_da, arg in zip(is_array, aligned)
)

_, npargs = unzip(
sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0])
)

# check that chunk sizes are compatible
input_chunks = dict(npargs[0].chunks)
input_indexes = dict(npargs[0].indexes)
for arg in npargs[1:][is_xarray[1:]]:
for arg in xarray_objs[1:]:
assert_chunks_compatible(npargs[0], arg)
input_chunks.update(arg.chunks)
input_indexes.update(arg.indexes)
Expand Down

0 comments on commit 59a2397

Please sign in to comment.