Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add template xarray object kwarg to map_blocks #3816

Merged
merged 29 commits into from
May 6, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
064b583
Allow providing template dataset to map_blocks.
dcherian Feb 25, 2020
1f14b11
Update dimension shape check.
dcherian Feb 26, 2020
045ae2b
Allow user function to add new unindexed dimension.
dcherian Feb 26, 2020
0e37b63
Add docstring for template.
dcherian Feb 26, 2020
5704e84
renaming
dcherian Feb 26, 2020
2c458e1
Raise nice error if adding a new chunked dimension,
dcherian Feb 26, 2020
dced076
Raise nice error message when expected dimension is missing on return…
dcherian Feb 26, 2020
717d900
Revert "Allow user function to add new unindexed dimension."
dcherian Mar 2, 2020
42a9070
Add test + fix output_chunks for dataarray template
dcherian Mar 3, 2020
64ba31f
typing
dcherian Mar 3, 2020
a68cb41
fix test
dcherian Mar 3, 2020
0bc3754
Add nice error messages when result doesn't match template.
dcherian Mar 6, 2020
ee66c88
blacken
dcherian Mar 6, 2020
d52dfd6
Add template kwarg to DataArray.map_blocks & Dataset.map_blocks
dcherian Mar 19, 2020
8ef47f6
minor error message fixes.
dcherian Mar 19, 2020
376f242
docstring updates.
dcherian Mar 19, 2020
d9029eb
bugfix for expected shapes when template is not specified
dcherian Mar 19, 2020
6f69955
Add map_blocks docs.
dcherian Mar 19, 2020
da62535
Update doc/dask.rst
dcherian Mar 19, 2020
4a355e6
Merge remote-tracking branch 'upstream/master' into map-blocks-schema
dcherian Mar 28, 2020
66fe4c4
Merge branch 'map-blocks-schema' of github.com:dcherian/xarray into m…
dcherian Mar 28, 2020
085ce9a
Merge remote-tracking branch 'upstream/master' into map-blocks-schema
dcherian Apr 30, 2020
0f741e2
refactor out slicer for chunks
dcherian May 1, 2020
869e62d
Check expected index values.
dcherian May 1, 2020
334e8d3
Raise nice error when template object does not have required number o…
dcherian May 1, 2020
4dd699e
doc updates.
dcherian May 1, 2020
27eb873
more review comments.
dcherian May 1, 2020
04a2c3c
Mention that attrs are taken from template.
dcherian May 1, 2020
4b92168
Add test and explicit point out that attrs is copied from template
dcherian May 5, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 60 additions & 27 deletions xarray/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def map_blocks(
obj: Union[DataArray, Dataset],
args: Sequence[Any] = (),
kwargs: Mapping[str, Any] = None,
template: T_DSorDA = None,
) -> T_DSorDA:
"""Apply a function to each chunk of a DataArray or Dataset. This function is
experimental and its signature may change.
Expand All @@ -119,14 +120,10 @@ def map_blocks(
corresponding to one chunk along each chunked dimension. ``func`` will be
executed as ``func(obj_subset, *args, **kwargs)``.

The function will be first run on mocked-up data, that looks like 'obj' but
has sizes 0, to determine properties of the returned object such as dtype,
variable names, new dimensions and new indexes (if any).

This function must return either a single DataArray or a single Dataset.

This function cannot change size of existing dimensions, or add new chunked
dimensions.
This function cannot add a new chunked dimension.

obj: DataArray, Dataset
Passed to the function as its first argument, one dask chunk at a time.
args: Sequence
Expand All @@ -135,6 +132,12 @@ def map_blocks(
kwargs: Mapping
Passed verbatim to func after unpacking. xarray objects, if any, will not be
split by chunks. Passing dask collections is not allowed.
template: (optional) DataArray, Dataset
xarray object representing the final result after compute is called. If not provided,
the function will be first run on mocked-up data, that looks like 'obj' but
has sizes 0, to determine properties of the returned object such as dtype,
variable names, new dimensions and new indexes (if any).
'template' must be provided if the function changes the size of existing dimensions.

Returns
-------
Expand Down Expand Up @@ -198,18 +201,26 @@ def map_blocks(
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
"""

def _wrapper(func, obj, to_array, args, kwargs):
def _wrapper(func, obj, to_array, args, kwargs, expected_shapes):
check_shapes = dict(obj.dims)
check_shapes.update(expected_shapes)

if to_array:
obj = dataset_to_dataarray(obj)

result = func(obj, *args, **kwargs)

missing_dimensions = set(expected_shapes) - set(result.sizes)
if missing_dimensions:
raise ValueError(
f"Dimensions {missing_dimensions} missing on returned object."
)

for name, index in result.indexes.items():
if name in obj.indexes:
if len(index) != len(obj.indexes[name]):
if name in check_shapes:
if len(index) != check_shapes[name]:
raise ValueError(
"Length of the %r dimension has changed. This is not allowed."
% name
f"Received dimension {name} of length {len(index)}. Expected length {expected_shapes[name]}."
)

return make_dict(result)
Expand Down Expand Up @@ -245,8 +256,24 @@ def _wrapper(func, obj, to_array, args, kwargs):
input_is_array = False

input_chunks = dataset.chunks
dataset_indexes = set(dataset.indexes)
if template is None:
# infer template by providing zero-shaped arrays
template = infer_template(func, obj, *args, **kwargs)
template_indexes = set(template.indexes)
preserved_indexes = template_indexes & dataset_indexes
new_indexes = template_indexes - dataset_indexes
indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes}
indexes.update({k: template.indexes[k] for k in new_indexes})
output_chunks = input_chunks

else:
# template xarray object has been provided with proper sizes and chunk shapes
dcherian marked this conversation as resolved.
Show resolved Hide resolved
template_indexes = set(template.indexes)
indexes = {dim: dataset.indexes[dim] for dim in dataset_indexes}
indexes.update({k: template.indexes[k] for k in template_indexes})
output_chunks = template.chunks

template: Union[DataArray, Dataset] = infer_template(func, obj, *args, **kwargs)
if isinstance(template, DataArray):
result_is_array = True
template_name = template.name
Expand All @@ -258,13 +285,6 @@ def _wrapper(func, obj, to_array, args, kwargs):
f"func output must be DataArray or Dataset; got {type(template)}"
)

template_indexes = set(template.indexes)
dataset_indexes = set(dataset.indexes)
preserved_indexes = template_indexes & dataset_indexes
new_indexes = template_indexes - dataset_indexes
indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes}
indexes.update({k: template.indexes[k] for k in new_indexes})

# We're building a new HighLevelGraph hlg. We'll have one new layer
# for each variable in the dataset, which is the result of the
# func applied to the values.
Expand All @@ -284,7 +304,7 @@ def _wrapper(func, obj, to_array, args, kwargs):

# iterate over all possible chunk combinations
for v in itertools.product(*ichunk.values()):
chunk_index_dict = dict(zip(dataset.dims, v))
input_chunk_index = dict(zip(dataset.dims, v))

# this will become [[name1, variable1],
# [name2, variable2],
Expand All @@ -299,7 +319,7 @@ def _wrapper(func, obj, to_array, args, kwargs):
# recursively index into dask_keys nested list to get chunk
chunk = variable.__dask_keys__()
for dim in variable.dims:
chunk = chunk[chunk_index_dict[dim]]
chunk = chunk[input_chunk_index[dim]]

chunk_variable_task = (f"{gname}-{chunk[0]}",) + v
graph[chunk_variable_task] = (
Expand All @@ -311,8 +331,8 @@ def _wrapper(func, obj, to_array, args, kwargs):
# index into variable appropriately
subsetter = {}
for dim in variable.dims:
if dim in chunk_index_dict:
which_chunk = chunk_index_dict[dim]
if dim in input_chunk_index:
which_chunk = input_chunk_index[dim]
subsetter[dim] = slice(
chunk_index_bounds[dim][which_chunk],
chunk_index_bounds[dim][which_chunk + 1],
Expand All @@ -333,6 +353,11 @@ def _wrapper(func, obj, to_array, args, kwargs):
else:
data_vars.append([name, chunk_variable_task])

# input chunk 0 along a dimension maps to output chunk 0 along the same dimension
# even if length of dimension is changed by the applied function
# expected_shapes is used to raise nice error messages in _wrapper
expected_shapes = {k: output_chunks[k][v] for k, v in input_chunk_index.items()}

from_wrapper = (gname,) + v
graph[from_wrapper] = (
_wrapper,
Expand All @@ -341,6 +366,7 @@ def _wrapper(func, obj, to_array, args, kwargs):
input_is_array,
args,
kwargs,
expected_shapes,
)

# mapping from variable name to dask graph key
Expand All @@ -353,8 +379,12 @@ def _wrapper(func, obj, to_array, args, kwargs):

key: Tuple[Any, ...] = (gname_l,)
for dim in variable.dims:
if dim in chunk_index_dict:
key += (chunk_index_dict[dim],)
if dim in input_chunk_index:
key += (input_chunk_index[dim],)
elif dim in output_chunks:
raise ValueError(
f"Function is attempting to add a new chunked dimension {dim}. This is not allowed."
)
else:
# unchunked dimensions in the input have one chunk in the result
key += (0,)
Expand All @@ -379,10 +409,13 @@ def _wrapper(func, obj, to_array, args, kwargs):
dims = template[name].dims
var_chunks = []
for dim in dims:
if dim in input_chunks:
var_chunks.append(input_chunks[dim])
if dim in output_chunks:
var_chunks.append(output_chunks[dim])
elif dim in indexes:
var_chunks.append((len(indexes[dim]),))
elif dim in template.dims:
dcherian marked this conversation as resolved.
Show resolved Hide resolved
# new unindexed dimension
var_chunks.append((template.sizes[dim],))

data = dask.array.Array(
hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype
Expand Down
4 changes: 3 additions & 1 deletion xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,7 @@ def test_map_blocks_error(map_da, map_ds):
def bad_func(darray):
return (darray * darray.x + 5 * darray.y)[:1, :1]

with raises_regex(ValueError, "Length of the.* has changed."):
with raises_regex(ValueError, "Received dimension.*"):
dcherian marked this conversation as resolved.
Show resolved Hide resolved
xr.map_blocks(bad_func, map_da).compute()

def returns_numpy(darray):
Expand Down Expand Up @@ -1147,6 +1147,7 @@ def test_map_blocks_to_array(map_ds):
lambda x: x.to_dataset(),
lambda x: x.drop_vars("x"),
lambda x: x.expand_dims(k=[1, 2, 3]),
lambda x: x.expand_dims(k=3),
lambda x: x.assign_coords(new_coord=("y", x.y * 2)),
lambda x: x.astype(np.int32),
# TODO: [lambda x: x.isel(x=1).drop_vars("x"), map_da],
Expand All @@ -1167,6 +1168,7 @@ def test_map_blocks_da_transformations(func, map_da):
lambda x: x.drop_vars("a"),
lambda x: x.drop_vars("x"),
lambda x: x.expand_dims(k=[1, 2, 3]),
lambda x: x.expand_dims(k=3),
lambda x: x.rename({"a": "new1", "b": "new2"}),
# TODO: [lambda x: x.isel(x=1)],
],
Expand Down