diff --git a/doc/api.rst b/doc/api.rst index b37c84e7a81..8ec6843d24a 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -173,6 +173,7 @@ Computation Dataset.quantile Dataset.differentiate Dataset.integrate + Dataset.map_blocks Dataset.polyfit **Aggregation**: @@ -358,6 +359,8 @@ Computation DataArray.integrate DataArray.polyfit DataArray.str + DataArray.map_blocks + **Aggregation**: :py:attr:`~DataArray.all` @@ -518,7 +521,6 @@ Dataset methods Dataset.load Dataset.chunk Dataset.unify_chunks - Dataset.map_blocks Dataset.filter_by_attrs Dataset.info @@ -550,7 +552,6 @@ DataArray methods DataArray.load DataArray.chunk DataArray.unify_chunks - DataArray.map_blocks Coordinates objects =================== diff --git a/doc/dask.rst b/doc/dask.rst index 2248de9c0d8..df223982ba4 100644 --- a/doc/dask.rst +++ b/doc/dask.rst @@ -284,12 +284,21 @@ loaded into Dask or not: .. _dask.automatic-parallelization: -Automatic parallelization -------------------------- +Automatic parallelization with ``apply_ufunc`` and ``map_blocks`` +----------------------------------------------------------------- Almost all of xarray's built-in operations work on Dask arrays. If you want to -use a function that isn't wrapped by xarray, one option is to extract Dask -arrays from xarray objects (``.data``) and use Dask directly. +use a function that isn't wrapped by xarray, and have it applied in parallel on +each block of your xarray object, you have three options: + +1. Extract Dask arrays from xarray objects (``.data``) and use Dask directly. +2. Use :py:func:`~xarray.apply_ufunc` to apply functions that consume and return NumPy arrays. +3. Use :py:func:`~xarray.map_blocks`, :py:meth:`Dataset.map_blocks` or :py:meth:`DataArray.map_blocks` + to apply functions that consume and return xarray objects. + + +``apply_ufunc`` +~~~~~~~~~~~~~~~ Another option is to use xarray's :py:func:`~xarray.apply_ufunc`, which can automate `embarrassingly parallel @@ -400,6 +409,103 @@ application. structure of a problem, unlike the generic speedups offered by ``dask='parallelized'``. + +``map_blocks`` +~~~~~~~~~~~~~~ + +Functions that consume and return xarray objects can be easily applied in parallel using :py:func:`map_blocks`. +Your function will receive an xarray Dataset or DataArray subset to one chunk +along each chunked dimension. + +.. ipython:: python + + ds.temperature + +This DataArray has 3 chunks each with length 10 along the time dimension. +At compute time, a function applied with :py:func:`map_blocks` will receive a DataArray corresponding to a single block of shape 10x180x180 +(time x latitude x longitude) with values loaded. The following snippet illustrates how to check the shape of the object +received by the applied function. + +.. ipython:: python + + def func(da): + print(da.sizes) + return da.time + + mapped = xr.map_blocks(func, ds.temperature) + mapped + +Notice that the :py:meth:`map_blocks` call printed +``Frozen({'time': 0, 'latitude': 0, 'longitude': 0})`` to screen. +``func`` is received 0-sized blocks! :py:meth:`map_blocks` needs to know what the final result +looks like in terms of dimensions, shapes etc. It does so by running the provided function on 0-shaped +inputs (*automated inference*). This works in many cases, but not all. If automatic inference does not +work for your function, provide the ``template`` kwarg (see below). + +In this case, automatic inference has worked so let's check that the result is as expected. + +.. ipython:: python + + mapped.load(scheduler="single-threaded") + mapped.identical(ds.time) + +Note that we use ``.load(scheduler="single-threaded")`` to execute the computation. +This executes the Dask graph in `serial` using a for loop, but allows for printing to screen and other +debugging techniques. We can easily see that our function is receiving blocks of shape 10x180x180 and +the returned result is identical to ``ds.time`` as expected. + + +Here is a common example where automated inference will not work. + +.. ipython:: python + :okexcept: + + def func(da): + print(da.sizes) + return da.isel(time=[1]) + + mapped = xr.map_blocks(func, ds.temperature) + +``func`` cannot be run on 0-shaped inputs because it is not possible to extract element 1 along a +dimension of size 0. In this case we need to tell :py:func:`map_blocks` what the returned result looks +like using the ``template`` kwarg. ``template`` must be an xarray Dataset or DataArray (depending on +what the function returns) with dimensions, shapes, chunk sizes, attributes, coordinate variables *and* data +variables that look exactly like the expected result. The variables should be dask-backed and hence not +incur much memory cost. + +.. note:: + + Note that when ``template`` is provided, ``attrs`` from ``template`` are copied over to the result. Any + ``attrs`` set in ``func`` will be ignored. + + +.. ipython:: python + + template = ds.temperature.isel(time=[1, 11, 21]) + mapped = xr.map_blocks(func, ds.temperature, template=template) + + +Notice that the 0-shaped sizes were not printed to screen. Since ``template`` has been provided +:py:func:`map_blocks` does not need to infer it by running ``func`` on 0-shaped inputs. + +.. ipython:: python + + mapped.identical(template) + + +:py:func:`map_blocks` also allows passing ``args`` and ``kwargs`` down to the user function ``func``. +``func`` will be executed as ``func(block_xarray, *args, **kwargs)`` so ``args`` must be a list and ``kwargs`` must be a dictionary. + +.. ipython:: python + + def func(obj, a, b=0): + return obj + a + b + + mapped = ds.map_blocks(func, args=[10], kwargs={"b": 10}) + expected = ds + 10 + 10 + mapped.identical(expected) + + Chunking and performance ------------------------ diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 051a41a57e5..29f10d4c87f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -59,6 +59,9 @@ New Features the :py:class:`~core.accessor_dt.DatetimeAccessor` (:pull:`3935`). This feature requires cftime version 1.1.0 or greater. By `Spencer Clark `_. +- :py:meth:`map_blocks` now accepts a ``template`` kwarg. This allows use cases + where the result of a computation could not be inferred automatically. + By `Deepak Cherian `_ Bug fixes ~~~~~~~~~ @@ -115,6 +118,8 @@ Documentation By `Matthias Riße `_. - Apply ``black`` to all the code in the documentation (:pull:`4012`) By `Justus Magin `_. +- Narrative documentation now describes :py:meth:`map_blocks`. :ref:`dask.automatic-parallelization`. + By `Deepak Cherian `_. Internal Changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 5ced7e251c4..45eee2d89bf 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3260,27 +3260,25 @@ def map_blocks( func: "Callable[..., T_DSorDA]", args: Sequence[Any] = (), kwargs: Mapping[str, Any] = None, + template: Union["DataArray", "Dataset"] = None, ) -> "T_DSorDA": """ - Apply a function to each chunk of this DataArray. This method is experimental - and its signature may change. + Apply a function to each block of this DataArray. + + .. warning:: + This method is experimental and its signature may change. Parameters ---------- func: callable - User-provided function that accepts a DataArray as its first parameter. The - function will receive a subset of this DataArray, corresponding to one chunk - along each chunked dimension. ``func`` will be executed as - ``func(obj_subset, *args, **kwargs)``. - - The function will be first run on mocked-up data, that looks like this array - but has sizes 0, to determine properties of the returned object such as - dtype, variable names, new dimensions and new indexes (if any). + User-provided function that accepts a DataArray as its first + parameter. The function will receive a subset, i.e. one block, of this DataArray + (see below), corresponding to one chunk along each chunked dimension. ``func`` will be + executed as ``func(block_subset, *args, **kwargs)``. This function must return either a single DataArray or a single Dataset. - This function cannot change size of existing dimensions, or add new chunked - dimensions. + This function cannot add a new chunked dimension. args: Sequence Passed verbatim to func after unpacking, after the sliced DataArray. xarray objects, if any, will not be split by chunks. Passing dask collections is @@ -3288,6 +3286,12 @@ def map_blocks( kwargs: Mapping Passed verbatim to func after unpacking. xarray objects, if any, will not be split by chunks. Passing dask collections is not allowed. + template: (optional) DataArray, Dataset + xarray object representing the final result after compute is called. If not provided, + the function will be first run on mocked-up data, that looks like 'obj' but + has sizes 0, to determine properties of the returned object such as dtype, + variable names, new dimensions and new indexes (if any). + 'template' must be provided if the function changes the size of existing dimensions. Returns ------- @@ -3310,7 +3314,7 @@ def map_blocks( """ from .parallel import map_blocks - return map_blocks(func, self, args, kwargs) + return map_blocks(func, self, args, kwargs, template) def polyfit( self, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index dd7871eaf3a..2ff585acb7f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5708,27 +5708,25 @@ def map_blocks( func: "Callable[..., T_DSorDA]", args: Sequence[Any] = (), kwargs: Mapping[str, Any] = None, + template: Union["DataArray", "Dataset"] = None, ) -> "T_DSorDA": """ - Apply a function to each chunk of this Dataset. This method is experimental and - its signature may change. + Apply a function to each block of this Dataset. + + .. warning:: + This method is experimental and its signature may change. Parameters ---------- func: callable - User-provided function that accepts a Dataset as its first parameter. The - function will receive a subset of this Dataset, corresponding to one chunk - along each chunked dimension. ``func`` will be executed as - ``func(obj_subset, *args, **kwargs)``. - - The function will be first run on mocked-up data, that looks like this - Dataset but has sizes 0, to determine properties of the returned object such - as dtype, variable names, new dimensions and new indexes (if any). + User-provided function that accepts a Dataset as its first + parameter. The function will receive a subset, i.e. one block, of this Dataset + (see below), corresponding to one chunk along each chunked dimension. ``func`` will be + executed as ``func(block_subset, *args, **kwargs)``. This function must return either a single DataArray or a single Dataset. - This function cannot change size of existing dimensions, or add new chunked - dimensions. + This function cannot add a new chunked dimension. args: Sequence Passed verbatim to func after unpacking, after the sliced DataArray. xarray objects, if any, will not be split by chunks. Passing dask collections is @@ -5736,6 +5734,12 @@ def map_blocks( kwargs: Mapping Passed verbatim to func after unpacking. xarray objects, if any, will not be split by chunks. Passing dask collections is not allowed. + template: (optional) DataArray, Dataset + xarray object representing the final result after compute is called. If not provided, + the function will be first run on mocked-up data, that looks like 'obj' but + has sizes 0, to determine properties of the returned object such as dtype, + variable names, new dimensions and new indexes (if any). + 'template' must be provided if the function changes the size of existing dimensions. Returns ------- @@ -5758,7 +5762,7 @@ def map_blocks( """ from .parallel import map_blocks - return map_blocks(func, self, args, kwargs) + return map_blocks(func, self, args, kwargs, template) def polyfit( self, @@ -5933,7 +5937,7 @@ def polyfit( "The number of data points must exceed order to scale the covariance matrix." ) fac = residuals / (x.shape[0] - order) - covariance = xr.DataArray(Vbase, dims=("cov_i", "cov_j"),) * fac + covariance = xr.DataArray(Vbase, dims=("cov_i", "cov_j")) * fac variables[name + "polyfit_covariance"] = covariance return Dataset(data_vars=variables, attrs=self.attrs.copy()) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 6f1668f698f..d91dfb4a275 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -31,6 +31,30 @@ T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset) +def check_result_variables( + result: Union[DataArray, Dataset], expected: Mapping[str, Any], kind: str +): + + if kind == "coords": + nice_str = "coordinate" + elif kind == "data_vars": + nice_str = "data" + + # check that coords and data variables are as expected + missing = expected[kind] - set(getattr(result, kind)) + if missing: + raise ValueError( + "Result from applying user function does not contain " + f"{nice_str} variables {missing}." + ) + extra = set(getattr(result, kind)) - expected[kind] + if extra: + raise ValueError( + "Result from applying user function has unexpected " + f"{nice_str} variables {extra}." + ) + + def dataset_to_dataarray(obj: Dataset) -> DataArray: if not isinstance(obj, Dataset): raise TypeError("Expected Dataset, got %s" % type(obj)) @@ -80,7 +104,8 @@ def infer_template( template = func(*meta_args, **kwargs) except Exception as e: raise Exception( - "Cannot infer object returned from running user provided function." + "Cannot infer object returned from running user provided function. " + "Please supply the 'template' kwarg to map_blocks." ) from e if not isinstance(template, (Dataset, DataArray)): @@ -102,14 +127,24 @@ def make_dict(x: Union[DataArray, Dataset]) -> Dict[Hashable, Any]: return {k: v.data for k, v in x.variables.items()} +def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping): + if dim in chunk_index: + which_chunk = chunk_index[dim] + return slice(chunk_bounds[dim][which_chunk], chunk_bounds[dim][which_chunk + 1]) + return slice(None) + + def map_blocks( func: Callable[..., T_DSorDA], obj: Union[DataArray, Dataset], args: Sequence[Any] = (), kwargs: Mapping[str, Any] = None, + template: Union[DataArray, Dataset] = None, ) -> T_DSorDA: - """Apply a function to each chunk of a DataArray or Dataset. This function is - experimental and its signature may change. + """Apply a function to each block of a DataArray or Dataset. + + .. warning:: + This function is experimental and its signature may change. Parameters ---------- @@ -119,14 +154,10 @@ def map_blocks( corresponding to one chunk along each chunked dimension. ``func`` will be executed as ``func(obj_subset, *args, **kwargs)``. - The function will be first run on mocked-up data, that looks like 'obj' but - has sizes 0, to determine properties of the returned object such as dtype, - variable names, new dimensions and new indexes (if any). - This function must return either a single DataArray or a single Dataset. - This function cannot change size of existing dimensions, or add new chunked - dimensions. + This function cannot add a new chunked dimension. + obj: DataArray, Dataset Passed to the function as its first argument, one dask chunk at a time. args: Sequence @@ -135,6 +166,15 @@ def map_blocks( kwargs: Mapping Passed verbatim to func after unpacking. xarray objects, if any, will not be split by chunks. Passing dask collections is not allowed. + template: (optional) DataArray, Dataset + xarray object representing the final result after compute is called. If not provided, + the function will be first run on mocked-up data, that looks like 'obj' but + has sizes 0, to determine properties of the returned object such as dtype, + variable names, attributes, new dimensions and new indexes (if any). + 'template' must be provided if the function changes the size of existing dimensions. + When provided, `attrs` on variables in `template` are copied over to the result. Any + `attrs` set by `func` will be ignored. + Returns ------- @@ -201,22 +241,47 @@ def map_blocks( * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 """ - def _wrapper(func, obj, to_array, args, kwargs): + def _wrapper(func, obj, to_array, args, kwargs, expected): + check_shapes = dict(obj.dims) + check_shapes.update(expected["shapes"]) + if to_array: obj = dataset_to_dataarray(obj) result = func(obj, *args, **kwargs) + # check all dims are present + missing_dimensions = set(expected["shapes"]) - set(result.sizes) + if missing_dimensions: + raise ValueError( + f"Dimensions {missing_dimensions} missing on returned object." + ) + + # check that index lengths and values are as expected for name, index in result.indexes.items(): - if name in obj.indexes: - if len(index) != len(obj.indexes[name]): + if name in check_shapes: + if len(index) != check_shapes[name]: raise ValueError( - "Length of the %r dimension has changed. This is not allowed." - % name + f"Received dimension {name!r} of length {len(index)}. Expected length {check_shapes[name]}." ) + if name in expected["indexes"]: + expected_index = expected["indexes"][name] + if not index.equals(expected_index): + raise ValueError( + f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead." + ) + + # check that all expected variables were returned + check_result_variables(result, expected, "coords") + if isinstance(result, Dataset): + check_result_variables(result, expected, "data_vars") return make_dict(result) + if template is not None and not isinstance(template, (DataArray, Dataset)): + raise TypeError( + f"template must be a DataArray or Dataset. Received {type(template).__name__} instead." + ) if not isinstance(args, Sequence): raise TypeError("args must be a sequence (for example, a list or tuple).") if kwargs is None: @@ -248,8 +313,38 @@ def _wrapper(func, obj, to_array, args, kwargs): input_is_array = False input_chunks = dataset.chunks + dataset_indexes = set(dataset.indexes) + if template is None: + # infer template by providing zero-shaped arrays + template = infer_template(func, obj, *args, **kwargs) + template_indexes = set(template.indexes) + preserved_indexes = template_indexes & dataset_indexes + new_indexes = template_indexes - dataset_indexes + indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes} + indexes.update({k: template.indexes[k] for k in new_indexes}) + output_chunks = { + dim: input_chunks[dim] for dim in template.dims if dim in input_chunks + } + + else: + # template xarray object has been provided with proper sizes and chunk shapes + template_indexes = set(template.indexes) + indexes = {dim: dataset.indexes[dim] for dim in dataset_indexes} + indexes.update({k: template.indexes[k] for k in template_indexes}) + if isinstance(template, DataArray): + output_chunks = dict(zip(template.dims, template.chunks)) # type: ignore + else: + output_chunks = template.chunks # type: ignore + + for dim in output_chunks: + if dim in input_chunks and len(input_chunks[dim]) != len(output_chunks[dim]): + raise ValueError( + "map_blocks requires that one block of the input maps to one block of output. " + f"Expected number of output chunks along dimension {dim!r} to be {len(input_chunks[dim])}. " + f"Received {len(output_chunks[dim])} instead. Please provide template if not provided, or " + "fix the provided template." + ) - template: Union[DataArray, Dataset] = infer_template(func, obj, *args, **kwargs) if isinstance(template, DataArray): result_is_array = True template_name = template.name @@ -261,13 +356,6 @@ def _wrapper(func, obj, to_array, args, kwargs): f"func output must be DataArray or Dataset; got {type(template)}" ) - template_indexes = set(template.indexes) - dataset_indexes = set(dataset.indexes) - preserved_indexes = template_indexes & dataset_indexes - new_indexes = template_indexes - dataset_indexes - indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes} - indexes.update({k: template.indexes[k] for k in new_indexes}) - # We're building a new HighLevelGraph hlg. We'll have one new layer # for each variable in the dataset, which is the result of the # func applied to the values. @@ -281,13 +369,16 @@ def _wrapper(func, obj, to_array, args, kwargs): # map dims to list of chunk indexes ichunk = {dim: range(len(chunks_v)) for dim, chunks_v in input_chunks.items()} # mapping from chunk index to slice bounds - chunk_index_bounds = { + input_chunk_bounds = { dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in input_chunks.items() } + output_chunk_bounds = { + dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in output_chunks.items() + } # iterate over all possible chunk combinations for v in itertools.product(*ichunk.values()): - chunk_index_dict = dict(zip(dataset.dims, v)) + chunk_index = dict(zip(dataset.dims, v)) # this will become [[name1, variable1], # [name2, variable2], @@ -302,9 +393,9 @@ def _wrapper(func, obj, to_array, args, kwargs): # recursively index into dask_keys nested list to get chunk chunk = variable.__dask_keys__() for dim in variable.dims: - chunk = chunk[chunk_index_dict[dim]] + chunk = chunk[chunk_index[dim]] - chunk_variable_task = (f"{gname}-{chunk[0]}",) + v + chunk_variable_task = (f"{gname}-{name}-{chunk[0]}",) + v graph[chunk_variable_task] = ( tuple, [variable.dims, chunk, variable.attrs], @@ -312,15 +403,10 @@ def _wrapper(func, obj, to_array, args, kwargs): else: # non-dask array with possibly chunked dimensions # index into variable appropriately - subsetter = {} - for dim in variable.dims: - if dim in chunk_index_dict: - which_chunk = chunk_index_dict[dim] - subsetter[dim] = slice( - chunk_index_bounds[dim][which_chunk], - chunk_index_bounds[dim][which_chunk + 1], - ) - + subsetter = { + dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds) + for dim in variable.dims + } subset = variable.isel(subsetter) chunk_variable_task = ( "{}-{}".format(gname, dask.base.tokenize(subset)), @@ -336,6 +422,20 @@ def _wrapper(func, obj, to_array, args, kwargs): else: data_vars.append([name, chunk_variable_task]) + # expected["shapes", "coords", "data_vars", "indexes"] are used to raise nice error messages in _wrapper + expected = {} + # input chunk 0 along a dimension maps to output chunk 0 along the same dimension + # even if length of dimension is changed by the applied function + expected["shapes"] = { + k: output_chunks[k][v] for k, v in chunk_index.items() if k in output_chunks + } + expected["data_vars"] = set(template.data_vars.keys()) # type: ignore + expected["coords"] = set(template.coords.keys()) # type: ignore + expected["indexes"] = { + dim: indexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)] + for dim in indexes + } + from_wrapper = (gname,) + v graph[from_wrapper] = ( _wrapper, @@ -344,6 +444,7 @@ def _wrapper(func, obj, to_array, args, kwargs): input_is_array, args, kwargs, + expected, ) # mapping from variable name to dask graph key @@ -356,10 +457,11 @@ def _wrapper(func, obj, to_array, args, kwargs): key: Tuple[Any, ...] = (gname_l,) for dim in variable.dims: - if dim in chunk_index_dict: - key += (chunk_index_dict[dim],) + if dim in chunk_index: + key += (chunk_index[dim],) else: # unchunked dimensions in the input have one chunk in the result + # output can have new dimensions with exactly one chunk key += (0,) # We're adding multiple new layers to the graph: @@ -382,8 +484,8 @@ def _wrapper(func, obj, to_array, args, kwargs): dims = template[name].dims var_chunks = [] for dim in dims: - if dim in input_chunks: - var_chunks.append(input_chunks[dim]) + if dim in output_chunks: + var_chunks.append(output_chunks[dim]) elif dim in indexes: var_chunks.append((len(indexes[dim]),)) elif dim in template.dims: diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 538dbbfb58b..75beb3757ca 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1039,7 +1039,7 @@ def test_map_blocks_error(map_da, map_ds): def bad_func(darray): return (darray * darray.x + 5 * darray.y)[:1, :1] - with raises_regex(ValueError, "Length of the.* has changed."): + with raises_regex(ValueError, "Received dimension 'x' of length 1"): xr.map_blocks(bad_func, map_da).compute() def returns_numpy(darray): @@ -1109,6 +1109,11 @@ def add_attrs(obj): assert_identical(actual, expected) + # when template is specified, attrs are copied from template, not set by function + with raise_if_dask_computes(): + actual = xr.map_blocks(add_attrs, obj, template=obj) + assert_identical(actual, obj) + def test_map_blocks_change_name(map_da): def change_name(obj): @@ -1150,7 +1155,7 @@ def test_map_blocks_to_array(map_ds): lambda x: x.expand_dims(k=3), lambda x: x.assign_coords(new_coord=("y", x.y * 2)), lambda x: x.astype(np.int32), - # TODO: [lambda x: x.isel(x=1).drop_vars("x"), map_da], + lambda x: x.x, ], ) def test_map_blocks_da_transformations(func, map_da): @@ -1170,7 +1175,7 @@ def test_map_blocks_da_transformations(func, map_da): lambda x: x.expand_dims(k=[1, 2, 3]), lambda x: x.expand_dims(k=3), lambda x: x.rename({"a": "new1", "b": "new2"}), - # TODO: [lambda x: x.isel(x=1)], + lambda x: x.x, ], ) def test_map_blocks_ds_transformations(func, map_ds): @@ -1180,6 +1185,64 @@ def test_map_blocks_ds_transformations(func, map_ds): assert_identical(actual, func(map_ds)) +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) +def test_map_blocks_da_ds_with_template(obj): + func = lambda x: x.isel(x=[1]) + template = obj.isel(x=[1, 5, 9]) + with raise_if_dask_computes(): + actual = xr.map_blocks(func, obj, template=template) + assert_identical(actual, template) + + with raise_if_dask_computes(): + actual = obj.map_blocks(func, template=template) + assert_identical(actual, template) + + +def test_map_blocks_template_convert_object(): + da = make_da() + func = lambda x: x.to_dataset().isel(x=[1]) + template = da.to_dataset().isel(x=[1, 5, 9]) + with raise_if_dask_computes(): + actual = xr.map_blocks(func, da, template=template) + assert_identical(actual, template) + + ds = da.to_dataset() + func = lambda x: x.to_array().isel(x=[1]) + template = ds.to_array().isel(x=[1, 5, 9]) + with raise_if_dask_computes(): + actual = xr.map_blocks(func, ds, template=template) + assert_identical(actual, template) + + +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) +def test_map_blocks_errors_bad_template(obj): + with raises_regex(ValueError, "unexpected coordinate variables"): + xr.map_blocks(lambda x: x.assign_coords(a=10), obj, template=obj).compute() + with raises_regex(ValueError, "does not contain coordinate variables"): + xr.map_blocks(lambda x: x.drop_vars("cxy"), obj, template=obj).compute() + with raises_regex(ValueError, "Dimensions {'x'} missing"): + xr.map_blocks(lambda x: x.isel(x=1), obj, template=obj).compute() + with raises_regex(ValueError, "Received dimension 'x' of length 1"): + xr.map_blocks(lambda x: x.isel(x=[1]), obj, template=obj).compute() + with raises_regex(TypeError, "must be a DataArray"): + xr.map_blocks(lambda x: x.isel(x=[1]), obj, template=(obj,)).compute() + with raises_regex(ValueError, "map_blocks requires that one block"): + xr.map_blocks( + lambda x: x.isel(x=[1]).assign_coords(x=10), obj, template=obj.isel(x=[1]) + ).compute() + with raises_regex(ValueError, "Expected index 'x' to be"): + xr.map_blocks( + lambda a: a.isel(x=[1]).assign_coords(x=[120]), # assign bad index values + obj, + template=obj.isel(x=[1, 5, 9]), + ).compute() + + +def test_map_blocks_errors_bad_template_2(map_ds): + with raises_regex(ValueError, "unexpected data variables {'xyz'}"): + xr.map_blocks(lambda x: x.assign(xyz=1), map_ds, template=map_ds).compute() + + @pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_object_method(obj): def func(obj):