From 064b583fcdaa250ce3c6bcca2e33d2e1eccad61a Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 25 Feb 2020 14:45:14 -0700 Subject: [PATCH 01/26] Allow providing template dataset to map_blocks. --- xarray/core/parallel.py | 45 +++++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index facfa06b23c..115a95d7b99 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -107,6 +107,7 @@ def map_blocks( obj: Union[DataArray, Dataset], args: Sequence[Any] = (), kwargs: Mapping[str, Any] = None, + template: T_DSorDA = None, ) -> T_DSorDA: """Apply a function to each chunk of a DataArray or Dataset. This function is experimental and its signature may change. @@ -204,13 +205,14 @@ def _wrapper(func, obj, to_array, args, kwargs): result = func(obj, *args, **kwargs) - for name, index in result.indexes.items(): - if name in obj.indexes: - if len(index) != len(obj.indexes[name]): - raise ValueError( - "Length of the %r dimension has changed. This is not allowed." - % name - ) + # Make this check using the template so that we can raise nice error messages + # for name, index in result.indexes.items(): + # if name in obj.indexes: + # if len(index) != len(obj.indexes[name]): + # raise ValueError( + # "Length of the %r dimension has changed. This is not allowed." 
+ # % name + # ) return make_dict(result) @@ -245,8 +247,24 @@ def _wrapper(func, obj, to_array, args, kwargs): input_is_array = False input_chunks = dataset.chunks + dataset_indexes = set(dataset.indexes) + if template is None: + # infer template by providing zero-shaped arrays + template = infer_template(func, obj, *args, **kwargs) + template_indexes = set(template.indexes) + preserved_indexes = template_indexes & dataset_indexes + new_indexes = template_indexes - dataset_indexes + indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes} + indexes.update({k: template.indexes[k] for k in new_indexes}) + output_chunks = input_chunks + + else: + # template xarray object has been provided with proper sizes and chunk shapes + template_indexes = set(template.indexes) + indexes = {dim: dataset.indexes[dim] for dim in dataset_indexes} + indexes.update({k: template.indexes[k] for k in template_indexes}) + output_chunks = template.chunks - template: Union[DataArray, Dataset] = infer_template(func, obj, *args, **kwargs) if isinstance(template, DataArray): result_is_array = True template_name = template.name @@ -258,13 +276,6 @@ def _wrapper(func, obj, to_array, args, kwargs): f"func output must be DataArray or Dataset; got {type(template)}" ) - template_indexes = set(template.indexes) - dataset_indexes = set(dataset.indexes) - preserved_indexes = template_indexes & dataset_indexes - new_indexes = template_indexes - dataset_indexes - indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes} - indexes.update({k: template.indexes[k] for k in new_indexes}) - # We're building a new HighLevelGraph hlg. We'll have one new layer # for each variable in the dataset, which is the result of the # func applied to the values. 
@@ -379,8 +390,8 @@ def _wrapper(func, obj, to_array, args, kwargs): dims = template[name].dims var_chunks = [] for dim in dims: - if dim in input_chunks: - var_chunks.append(input_chunks[dim]) + if dim in output_chunks: + var_chunks.append(output_chunks[dim]) elif dim in indexes: var_chunks.append((len(indexes[dim]),)) From 1f14b11d7a4104679a75970f0e32520f25935818 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 25 Feb 2020 20:34:46 -0700 Subject: [PATCH 02/26] Update dimension shape check. This accounts for dimension sizes being changed by the applied function. --- xarray/core/parallel.py | 25 ++++++++++++++++--------- xarray/tests/test_dask.py | 2 +- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 115a95d7b99..17f13576417 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -199,20 +199,21 @@ def map_blocks( * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 """ - def _wrapper(func, obj, to_array, args, kwargs): + def _wrapper(func, obj, to_array, args, kwargs, expected_shapes): + check_shapes = dict(obj.dims) + check_shapes.update(expected_shapes) + if to_array: obj = dataset_to_dataarray(obj) result = func(obj, *args, **kwargs) - # Make this check using the template so that we can raise nice error messages - # for name, index in result.indexes.items(): - # if name in obj.indexes: - # if len(index) != len(obj.indexes[name]): - # raise ValueError( - # "Length of the %r dimension has changed. This is not allowed." - # % name - # ) + for name, index in result.indexes.items(): + if name in check_shapes: + if len(index) != check_shapes[name]: + raise ValueError( + f"Received dimension {name} of length {len(index)}. Expected length {expected_shapes[name]}." 
+ ) return make_dict(result) @@ -344,6 +345,11 @@ def _wrapper(func, obj, to_array, args, kwargs): else: data_vars.append([name, chunk_variable_task]) + # input chunk 0 along a dimension maps to output chunk 0 along the same dimension + # even if length of dimension is changed by the applied function + # expected_shapes is used to raise nice error messages in _wrapper + expected_shapes = {k: output_chunks[k][v] for k, v in input_chunk_index.items()} + from_wrapper = (gname,) + v graph[from_wrapper] = ( _wrapper, @@ -352,6 +358,7 @@ def _wrapper(func, obj, to_array, args, kwargs): input_is_array, args, kwargs, + expected_shapes, ) # mapping from variable name to dask graph key diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 8fb54c4ee84..857b37672b7 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1039,7 +1039,7 @@ def test_map_blocks_error(map_da, map_ds): def bad_func(darray): return (darray * darray.x + 5 * darray.y)[:1, :1] - with raises_regex(ValueError, "Length of the.* has changed."): + with raises_regex(ValueError, "Received dimension.*"): xr.map_blocks(bad_func, map_da).compute() def returns_numpy(darray): From 045ae2b1bf939515e0a38c960d0cdc7974bcfa37 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 25 Feb 2020 20:35:46 -0700 Subject: [PATCH 03/26] Allow user function to add new unindexed dimension. 
--- xarray/core/parallel.py | 3 +++ xarray/tests/test_dask.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 17f13576417..2f4de742c4a 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -401,6 +401,9 @@ def _wrapper(func, obj, to_array, args, kwargs, expected_shapes): var_chunks.append(output_chunks[dim]) elif dim in indexes: var_chunks.append((len(indexes[dim]),)) + elif dim in template.dims: + # new unindexed dimension + var_chunks.append((template.sizes[dim],)) data = dask.array.Array( hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 857b37672b7..76920a1498a 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1147,6 +1147,7 @@ def test_map_blocks_to_array(map_ds): lambda x: x.to_dataset(), lambda x: x.drop_vars("x"), lambda x: x.expand_dims(k=[1, 2, 3]), + lambda x: x.expand_dims(k=3), lambda x: x.assign_coords(new_coord=("y", x.y * 2)), lambda x: x.astype(np.int32), # TODO: [lambda x: x.isel(x=1).drop_vars("x"), map_da], @@ -1167,6 +1168,7 @@ def test_map_blocks_da_transformations(func, map_da): lambda x: x.drop_vars("a"), lambda x: x.drop_vars("x"), lambda x: x.expand_dims(k=[1, 2, 3]), + lambda x: x.expand_dims(k=3), lambda x: x.rename({"a": "new1", "b": "new2"}), # TODO: [lambda x: x.isel(x=1)], ], From 0e37b63520e018d25a2658b061d42af3d1490c5b Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 25 Feb 2020 20:41:18 -0700 Subject: [PATCH 04/26] Add docstring for template. --- xarray/core/parallel.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 2f4de742c4a..ad3a10b8308 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -120,14 +120,8 @@ def map_blocks( corresponding to one chunk along each chunked dimension. 
``func`` will be executed as ``func(obj_subset, *args, **kwargs)``. - The function will be first run on mocked-up data, that looks like 'obj' but - has sizes 0, to determine properties of the returned object such as dtype, - variable names, new dimensions and new indexes (if any). - This function must return either a single DataArray or a single Dataset. - This function cannot change size of existing dimensions, or add new chunked - dimensions. obj: DataArray, Dataset Passed to the function as its first argument, one dask chunk at a time. args: Sequence @@ -136,6 +130,13 @@ def map_blocks( kwargs: Mapping Passed verbatim to func after unpacking. xarray objects, if any, will not be split by chunks. Passing dask collections is not allowed. + template: (optional) DataArray, Dataset + xarray object representing the final result after compute is called. If not provided, + the function will be first run on mocked-up data, that looks like 'obj' but + has sizes 0, to determine properties of the returned object such as dtype, + variable names, new dimensions and new indexes (if any). + 'template' must be provided if the function changes the size of existing dimensions, + or adds new chunked dimensions. 
Returns ------- From 5704e843519379fedcb72748d94509cc53469074 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 25 Feb 2020 20:59:38 -0700 Subject: [PATCH 05/26] renaming --- xarray/core/parallel.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index ad3a10b8308..742a9b8f1a4 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -297,7 +297,7 @@ def _wrapper(func, obj, to_array, args, kwargs, expected_shapes): # iterate over all possible chunk combinations for v in itertools.product(*ichunk.values()): - chunk_index_dict = dict(zip(dataset.dims, v)) + input_chunk_index = dict(zip(dataset.dims, v)) # this will become [[name1, variable1], # [name2, variable2], @@ -312,7 +312,7 @@ def _wrapper(func, obj, to_array, args, kwargs, expected_shapes): # recursively index into dask_keys nested list to get chunk chunk = variable.__dask_keys__() for dim in variable.dims: - chunk = chunk[chunk_index_dict[dim]] + chunk = chunk[input_chunk_index[dim]] chunk_variable_task = (f"{gname}-{chunk[0]}",) + v graph[chunk_variable_task] = ( @@ -324,8 +324,8 @@ def _wrapper(func, obj, to_array, args, kwargs, expected_shapes): # index into variable appropriately subsetter = {} for dim in variable.dims: - if dim in chunk_index_dict: - which_chunk = chunk_index_dict[dim] + if dim in input_chunk_index: + which_chunk = input_chunk_index[dim] subsetter[dim] = slice( chunk_index_bounds[dim][which_chunk], chunk_index_bounds[dim][which_chunk + 1], @@ -372,8 +372,8 @@ def _wrapper(func, obj, to_array, args, kwargs, expected_shapes): key: Tuple[Any, ...] 
= (gname_l,) for dim in variable.dims: - if dim in chunk_index_dict: - key += (chunk_index_dict[dim],) + if dim in input_chunk_index: + key += (input_chunk_index[dim],) else: # unchunked dimensions in the input have one chunk in the result key += (0,) From 2c458e1df8c7d9802c51675b872922c2828665b9 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 25 Feb 2020 21:08:44 -0700 Subject: [PATCH 06/26] Raise nice error if adding a new chunked dimension, --- xarray/core/parallel.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 742a9b8f1a4..75ad566e315 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -122,6 +122,8 @@ def map_blocks( This function must return either a single DataArray or a single Dataset. + This function cannot add a new chunked dimension. + obj: DataArray, Dataset Passed to the function as its first argument, one dask chunk at a time. args: Sequence @@ -135,8 +137,7 @@ def map_blocks( the function will be first run on mocked-up data, that looks like 'obj' but has sizes 0, to determine properties of the returned object such as dtype, variable names, new dimensions and new indexes (if any). - 'template' must be provided if the function changes the size of existing dimensions, - or adds new chunked dimensions. + 'template' must be provided if the function changes the size of existing dimensions. Returns ------- @@ -374,6 +375,10 @@ def _wrapper(func, obj, to_array, args, kwargs, expected_shapes): for dim in variable.dims: if dim in input_chunk_index: key += (input_chunk_index[dim],) + elif dim in output_chunks: + raise ValueError( + f"Function is attempting to add a new chunked dimension {dim}. This is not allowed." 
+ ) else: # unchunked dimensions in the input have one chunk in the result key += (0,) From dced076fb81cda6f78369c8cab2521e76b610722 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 25 Feb 2020 21:30:16 -0700 Subject: [PATCH 07/26] Raise nice error message when expected dimension is missing on returned object --- xarray/core/parallel.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 75ad566e315..d8306da4a9b 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -210,6 +210,12 @@ def _wrapper(func, obj, to_array, args, kwargs, expected_shapes): result = func(obj, *args, **kwargs) + missing_dimensions = set(expected_shapes) - set(result.sizes) + if missing_dimensions: + raise ValueError( + f"Dimensions {missing_dimensions} missing on returned object." + ) + for name, index in result.indexes.items(): if name in check_shapes: if len(index) != check_shapes[name]: From 717d9000c37e539e7e8764c081376f1447860218 Mon Sep 17 00:00:00 2001 From: dcherian Date: Mon, 2 Mar 2020 18:39:41 +0530 Subject: [PATCH 08/26] Revert "Allow user function to add new unindexed dimension." This reverts commit 045ae2b1bf939515e0a38c960d0cdc7974bcfa37. 
--- xarray/core/parallel.py | 3 --- xarray/tests/test_dask.py | 2 -- 2 files changed, 5 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index d8306da4a9b..3f912da4620 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -413,9 +413,6 @@ def _wrapper(func, obj, to_array, args, kwargs, expected_shapes): var_chunks.append(output_chunks[dim]) elif dim in indexes: var_chunks.append((len(indexes[dim]),)) - elif dim in template.dims: - # new unindexed dimension - var_chunks.append((template.sizes[dim],)) data = dask.array.Array( hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 76920a1498a..857b37672b7 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1147,7 +1147,6 @@ def test_map_blocks_to_array(map_ds): lambda x: x.to_dataset(), lambda x: x.drop_vars("x"), lambda x: x.expand_dims(k=[1, 2, 3]), - lambda x: x.expand_dims(k=3), lambda x: x.assign_coords(new_coord=("y", x.y * 2)), lambda x: x.astype(np.int32), # TODO: [lambda x: x.isel(x=1).drop_vars("x"), map_da], @@ -1168,7 +1167,6 @@ def test_map_blocks_da_transformations(func, map_da): lambda x: x.drop_vars("a"), lambda x: x.drop_vars("x"), lambda x: x.expand_dims(k=[1, 2, 3]), - lambda x: x.expand_dims(k=3), lambda x: x.rename({"a": "new1", "b": "new2"}), # TODO: [lambda x: x.isel(x=1)], ], From 42a90701e06e43e3149c90f979c64d14eb5fb023 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 3 Mar 2020 22:22:47 +0530 Subject: [PATCH 09/26] Add test + fix output_chunks for dataarray template --- xarray/core/parallel.py | 5 ++++- xarray/tests/test_dask.py | 14 ++++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 3f912da4620..ce1e98bbfef 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -272,7 +272,10 @@ def _wrapper(func, obj, to_array, args, kwargs, expected_shapes): 
template_indexes = set(template.indexes) indexes = {dim: dataset.indexes[dim] for dim in dataset_indexes} indexes.update({k: template.indexes[k] for k in template_indexes}) - output_chunks = template.chunks + if isinstance(template, DataArray): + output_chunks = dict(zip(template.dims, template.chunks)) + else: + output_chunks = template.chunks if isinstance(template, DataArray): result_is_array = True diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 857b37672b7..361f68f9afc 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1149,7 +1149,6 @@ def test_map_blocks_to_array(map_ds): lambda x: x.expand_dims(k=[1, 2, 3]), lambda x: x.assign_coords(new_coord=("y", x.y * 2)), lambda x: x.astype(np.int32), - # TODO: [lambda x: x.isel(x=1).drop_vars("x"), map_da], ], ) def test_map_blocks_da_transformations(func, map_da): @@ -1168,7 +1167,6 @@ def test_map_blocks_da_transformations(func, map_da): lambda x: x.drop_vars("x"), lambda x: x.expand_dims(k=[1, 2, 3]), lambda x: x.rename({"a": "new1", "b": "new2"}), - # TODO: [lambda x: x.isel(x=1)], ], ) def test_map_blocks_ds_transformations(func, map_ds): @@ -1178,6 +1176,18 @@ def test_map_blocks_ds_transformations(func, map_ds): assert_identical(actual, func(map_ds)) +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) +def test_map_blocks_da_ds_with_template(obj): + func = lambda x: x.isel(x=[1]) + template = obj.isel(x=[1, 5, 9]) + with raise_if_dask_computes(): + actual = xr.map_blocks(func, obj, template=template) + assert_identical(actual, template) + + with raises_regex(ValueError, "Dimensions {'x'} missing"): + xr.map_blocks(lambda x: x.isel(x=1), obj, template=template).compute() + + @pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_object_method(obj): def func(obj): From 64ba31feb4626e0d2d0c8f00be9ba2cd65ecea60 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 3 Mar 2020 22:31:44 +0530 Subject: [PATCH 10/26] typing --- 
xarray/core/parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index ce1e98bbfef..60d29a5ec70 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -107,7 +107,7 @@ def map_blocks( obj: Union[DataArray, Dataset], args: Sequence[Any] = (), kwargs: Mapping[str, Any] = None, - template: T_DSorDA = None, + template: Union[DataArray, Dataset] = None, ) -> T_DSorDA: """Apply a function to each chunk of a DataArray or Dataset. This function is experimental and its signature may change. @@ -273,7 +273,7 @@ def _wrapper(func, obj, to_array, args, kwargs, expected_shapes): indexes = {dim: dataset.indexes[dim] for dim in dataset_indexes} indexes.update({k: template.indexes[k] for k in template_indexes}) if isinstance(template, DataArray): - output_chunks = dict(zip(template.dims, template.chunks)) + output_chunks = dict(zip(template.dims, template.chunks)) # type: ignore else: output_chunks = template.chunks From a68cb4174f73d138095efe63f84e5f39de225fb3 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 3 Mar 2020 22:57:23 +0530 Subject: [PATCH 11/26] fix test --- xarray/core/parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 60d29a5ec70..71cd8ce9b30 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -324,7 +324,7 @@ def _wrapper(func, obj, to_array, args, kwargs, expected_shapes): for dim in variable.dims: chunk = chunk[input_chunk_index[dim]] - chunk_variable_task = (f"{gname}-{chunk[0]}",) + v + chunk_variable_task = (f"{gname}-{name}-{chunk[0]}",) + v graph[chunk_variable_task] = ( tuple, [variable.dims, chunk, variable.attrs], From 0bc375429549056327fa854935498fd7a2b7453f Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 6 Mar 2020 17:24:33 +0530 Subject: [PATCH 12/26] Add nice error messages when result doesn't match template. 
--- xarray/core/parallel.py | 49 +++++++++++++++++++++++++++++++++------ xarray/tests/test_dask.py | 15 +++++++++++- 2 files changed, 56 insertions(+), 8 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 71cd8ce9b30..5e87fb98e29 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -31,6 +31,30 @@ T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset) +def check_result_variables( + result: Union[DataArray, Dataset], expected: dict, kind: str +): + + if kind == "coords": + nice_str = "coordinate" + elif kind == "data_vars": + nice_str = "data" + + # check that coords and data variables are as expected + missing = expected[kind] - set(getattr(result, kind)) + if missing: + raise ValueError( + "Result from applying user function does not contain " + f"{nice_str} variables {missing}." + ) + extra = set(getattr(result, kind)) - expected[kind] + if extra: + raise ValueError( + "Result from applying user function has unexpected " + f"{nice_str} variables {extra}." + ) + + def dataset_to_dataarray(obj: Dataset) -> DataArray: if not isinstance(obj, Dataset): raise TypeError("Expected Dataset, got %s" % type(obj)) @@ -201,28 +225,34 @@ def map_blocks( * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 """ - def _wrapper(func, obj, to_array, args, kwargs, expected_shapes): + def _wrapper(func, obj, to_array, args, kwargs, expected): check_shapes = dict(obj.dims) - check_shapes.update(expected_shapes) + check_shapes.update(expected["shapes"]) if to_array: obj = dataset_to_dataarray(obj) result = func(obj, *args, **kwargs) - missing_dimensions = set(expected_shapes) - set(result.sizes) + # check all dims are present + missing_dimensions = set(expected["shapes"]) - set(result.sizes) if missing_dimensions: raise ValueError( f"Dimensions {missing_dimensions} missing on returned object." 
) + # check that index lengths are as expected for name, index in result.indexes.items(): if name in check_shapes: if len(index) != check_shapes[name]: raise ValueError( - f"Received dimension {name} of length {len(index)}. Expected length {expected_shapes[name]}." + f"Received dimension {name!r} of length {len(index)}. Expected length {check_shapes[name]}." ) + check_result_variables(result, expected, "coords") + if isinstance(result, Dataset): + check_result_variables(result, expected, "data_vars") + return make_dict(result) if not isinstance(args, Sequence): @@ -356,10 +386,15 @@ def _wrapper(func, obj, to_array, args, kwargs, expected_shapes): else: data_vars.append([name, chunk_variable_task]) + # expected["shapes", "coords", "data_vars"] are used to raise nice error messages in _wrapper + expected = {} # input chunk 0 along a dimension maps to output chunk 0 along the same dimension # even if length of dimension is changed by the applied function - # expected_shapes is used to raise nice error messages in _wrapper - expected_shapes = {k: output_chunks[k][v] for k, v in input_chunk_index.items()} + expected["shapes"] = { + k: output_chunks[k][v] for k, v in input_chunk_index.items() + } + expected["data_vars"] = set(template.data_vars.keys()) # type: ignore + expected["coords"] = set(template.coords.keys()) # type: ignore from_wrapper = (gname,) + v graph[from_wrapper] = ( @@ -369,7 +404,7 @@ def _wrapper(func, obj, to_array, args, kwargs, expected_shapes): input_is_array, args, kwargs, - expected_shapes, + expected, ) # mapping from variable name to dask graph key diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 361f68f9afc..3d268fdbe8e 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1184,8 +1184,21 @@ def test_map_blocks_da_ds_with_template(obj): actual = xr.map_blocks(func, obj, template=template) assert_identical(actual, template) + +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) +def 
test_map_blocks_errors_bad_template(obj): + with raises_regex(ValueError, "unexpected coordinate variables"): + xr.map_blocks(lambda x: x.assign_coords(a=10), obj, template=obj).compute() + with raises_regex(ValueError, "does not contain coordinate variables"): + xr.map_blocks(lambda x: x.drop_vars("cxy"), obj, template=obj).compute() with raises_regex(ValueError, "Dimensions {'x'} missing"): - xr.map_blocks(lambda x: x.isel(x=1), obj, template=template).compute() + xr.map_blocks(lambda x: x.isel(x=1), obj, template=obj).compute() + with raises_regex(ValueError, "Received dimension 'x' of length 1"): + xr.map_blocks(lambda x: x.isel(x=[1]), obj, template=obj).compute() + +def test_map_blocks_errors_bad_template_2(map_ds): + with raises_regex(ValueError, "unexpected data variables {'xyz'}"): + xr.map_blocks(lambda x: x.assign(xyz=1), map_ds, template=map_ds).compute() @pytest.mark.parametrize("obj", [make_da(), make_ds()]) From ee66c8871175181b81209784e0b9191ed7719373 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 6 Mar 2020 17:32:01 +0530 Subject: [PATCH 13/26] blacken --- xarray/tests/test_dask.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 3d268fdbe8e..1fe8c79808b 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1196,6 +1196,7 @@ def test_map_blocks_errors_bad_template(obj): with raises_regex(ValueError, "Received dimension 'x' of length 1"): xr.map_blocks(lambda x: x.isel(x=[1]), obj, template=obj).compute() + def test_map_blocks_errors_bad_template_2(map_ds): with raises_regex(ValueError, "unexpected data variables {'xyz'}"): xr.map_blocks(lambda x: x.assign(xyz=1), map_ds, template=map_ds).compute() From d52dfd601d3dbec806ea02bc44d814612d8d2241 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 19 Mar 2020 05:23:28 -0600 Subject: [PATCH 14/26] Add template kwarg to DataArray.map_blocks & Dataset.map_blocks --- xarray/core/dataarray.py | 24 +++++++++++++----------- 
xarray/core/dataset.py | 24 +++++++++++++----------- xarray/tests/test_dask.py | 4 ++++ 3 files changed, 30 insertions(+), 22 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 062cc6342df..59ab3177c6b 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3182,6 +3182,7 @@ def map_blocks( func: "Callable[..., T_DSorDA]", args: Sequence[Any] = (), kwargs: Mapping[str, Any] = None, + template: Union["DataArray", "Dataset"] = None, ) -> "T_DSorDA": """ Apply a function to each chunk of this DataArray. This method is experimental @@ -3190,19 +3191,14 @@ def map_blocks( Parameters ---------- func: callable - User-provided function that accepts a DataArray as its first parameter. The - function will receive a subset of this DataArray, corresponding to one chunk - along each chunked dimension. ``func`` will be executed as - ``func(obj_subset, *args, **kwargs)``. - - The function will be first run on mocked-up data, that looks like this array - but has sizes 0, to determine properties of the returned object such as - dtype, variable names, new dimensions and new indexes (if any). + User-provided function that accepts a DataArray or Dataset as its first + parameter. The function will receive a subset of 'obj' (see below), + corresponding to one chunk along each chunked dimension. ``func`` will be + executed as ``func(obj_subset, *args, **kwargs)``. This function must return either a single DataArray or a single Dataset. - This function cannot change size of existing dimensions, or add new chunked - dimensions. + This function cannot add a new chunked dimension. args: Sequence Passed verbatim to func after unpacking, after the sliced DataArray. xarray objects, if any, will not be split by chunks. Passing dask collections is @@ -3210,6 +3206,12 @@ def map_blocks( kwargs: Mapping Passed verbatim to func after unpacking. xarray objects, if any, will not be split by chunks. Passing dask collections is not allowed. 
+ template: (optional) DataArray, Dataset + xarray object representing the final result after compute is called. If not provided, + the function will be first run on mocked-up data, that looks like 'obj' but + has sizes 0, to determine properties of the returned object such as dtype, + variable names, new dimensions and new indexes (if any). + 'template' must be provided if the function changes the size of existing dimensions. Returns ------- @@ -3232,7 +3234,7 @@ def map_blocks( """ from .parallel import map_blocks - return map_blocks(func, self, args, kwargs) + return map_blocks(func, self, args, kwargs, template) # this needs to be at the end, or mypy will confuse with `str` # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 7252dd2f3df..7c6f58d8e00 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5642,6 +5642,7 @@ def map_blocks( func: "Callable[..., T_DSorDA]", args: Sequence[Any] = (), kwargs: Mapping[str, Any] = None, + template: Union["DataArray", "Dataset"] = None, ) -> "T_DSorDA": """ Apply a function to each chunk of this Dataset. This method is experimental and @@ -5650,19 +5651,14 @@ def map_blocks( Parameters ---------- func: callable - User-provided function that accepts a Dataset as its first parameter. The - function will receive a subset of this Dataset, corresponding to one chunk - along each chunked dimension. ``func`` will be executed as - ``func(obj_subset, *args, **kwargs)``. - - The function will be first run on mocked-up data, that looks like this - Dataset but has sizes 0, to determine properties of the returned object such - as dtype, variable names, new dimensions and new indexes (if any). + User-provided function that accepts a DataArray or Dataset as its first + parameter. The function will receive a subset of 'obj' (see below), + corresponding to one chunk along each chunked dimension. 
``func`` will be + executed as ``func(obj_subset, *args, **kwargs)``. This function must return either a single DataArray or a single Dataset. - This function cannot change size of existing dimensions, or add new chunked - dimensions. + This function cannot add a new chunked dimension. args: Sequence Passed verbatim to func after unpacking, after the sliced DataArray. xarray objects, if any, will not be split by chunks. Passing dask collections is @@ -5670,6 +5666,12 @@ def map_blocks( kwargs: Mapping Passed verbatim to func after unpacking. xarray objects, if any, will not be split by chunks. Passing dask collections is not allowed. + template: (optional) DataArray, Dataset + xarray object representing the final result after compute is called. If not provided, + the function will be first run on mocked-up data, that looks like 'obj' but + has sizes 0, to determine properties of the returned object such as dtype, + variable names, new dimensions and new indexes (if any). + 'template' must be provided if the function changes the size of existing dimensions. 
Returns ------- @@ -5692,7 +5694,7 @@ def map_blocks( """ from .parallel import map_blocks - return map_blocks(func, self, args, kwargs) + return map_blocks(func, self, args, kwargs, template) ops.inject_all_ops_and_reduce_methods(Dataset, array_only=False) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 1fe8c79808b..69463f95d89 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1184,6 +1184,10 @@ def test_map_blocks_da_ds_with_template(obj): actual = xr.map_blocks(func, obj, template=template) assert_identical(actual, template) + with raise_if_dask_computes(): + actual = obj.map_blocks(func, template=template) + assert_identical(actual, template) + @pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_errors_bad_template(obj): From 8ef47f6430a74219cafb3113a0481af022a98e2f Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 19 Mar 2020 05:35:13 -0600 Subject: [PATCH 15/26] minor error message fixes. --- xarray/core/parallel.py | 3 ++- xarray/tests/test_dask.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 5e87fb98e29..0770168a7aa 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -104,7 +104,8 @@ def infer_template( template = func(*meta_args, **kwargs) except Exception as e: raise Exception( - "Cannot infer object returned from running user provided function." + "Cannot infer object returned from running user provided function. " + "Please supply the 'template' kwarg to map_blocks." 
) from e if not isinstance(template, (Dataset, DataArray)): diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 69463f95d89..866ddeedac5 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1039,7 +1039,7 @@ def test_map_blocks_error(map_da, map_ds): def bad_func(darray): return (darray * darray.x + 5 * darray.y)[:1, :1] - with raises_regex(ValueError, "Received dimension.*"): + with raises_regex(ValueError, "Received dimension 'x' of length 1"): xr.map_blocks(bad_func, map_da).compute() def returns_numpy(darray): From 376f24220d056d16e928124e19c9ef84646dfa02 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 19 Mar 2020 05:42:40 -0600 Subject: [PATCH 16/26] docstring updates. --- xarray/core/dataarray.py | 14 ++++++++------ xarray/core/dataset.py | 14 ++++++++------ xarray/core/parallel.py | 6 ++++-- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 59ab3177c6b..e3c81541760 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3185,16 +3185,18 @@ def map_blocks( template: Union["DataArray", "Dataset"] = None, ) -> "T_DSorDA": """ - Apply a function to each chunk of this DataArray. This method is experimental - and its signature may change. + Apply a function to each block of this DataArray. + + .. warning:: + This method is experimental and its signature may change. Parameters ---------- func: callable - User-provided function that accepts a DataArray or Dataset as its first - parameter. The function will receive a subset of 'obj' (see below), - corresponding to one chunk along each chunked dimension. ``func`` will be - executed as ``func(obj_subset, *args, **kwargs)``. + User-provided function that accepts a DataArray as its first + parameter. The function will receive a subset, i.e. one block, of this DataArray + (see below), corresponding to one chunk along each chunked dimension. 
``func`` will be + executed as ``func(block_subset, *args, **kwargs)``. This function must return either a single DataArray or a single Dataset. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 7c6f58d8e00..11db1c7006b 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5645,16 +5645,18 @@ def map_blocks( template: Union["DataArray", "Dataset"] = None, ) -> "T_DSorDA": """ - Apply a function to each chunk of this Dataset. This method is experimental and - its signature may change. + Apply a function to each block of this Dataset. + + .. warning:: + This method is experimental and its signature may change. Parameters ---------- func: callable - User-provided function that accepts a DataArray or Dataset as its first - parameter. The function will receive a subset of 'obj' (see below), - corresponding to one chunk along each chunked dimension. ``func`` will be - executed as ``func(obj_subset, *args, **kwargs)``. + User-provided function that accepts a Dataset as its first + parameter. The function will receive a subset, i.e. one block, of this Dataset + (see below), corresponding to one chunk along each chunked dimension. ``func`` will be + executed as ``func(block_subset, *args, **kwargs)``. This function must return either a single DataArray or a single Dataset. diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 0770168a7aa..48fc23f1cec 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -134,8 +134,10 @@ def map_blocks( kwargs: Mapping[str, Any] = None, template: Union[DataArray, Dataset] = None, ) -> T_DSorDA: - """Apply a function to each chunk of a DataArray or Dataset. This function is - experimental and its signature may change. + """Apply a function to each block of a DataArray or Dataset. + + .. warning:: + This function is experimental and its signature may change. 
Parameters ---------- From d9029ebcd85db884df160e320e6b21e80750958f Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 19 Mar 2020 06:20:07 -0600 Subject: [PATCH 17/26] bugfix for expected shapes when template is not specified --- xarray/core/parallel.py | 10 +++++++--- xarray/tests/test_dask.py | 2 ++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 48fc23f1cec..29ac84c6df5 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -298,7 +298,9 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): new_indexes = template_indexes - dataset_indexes indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes} indexes.update({k: template.indexes[k] for k in new_indexes}) - output_chunks = input_chunks + output_chunks = { + dim: input_chunks[dim] for dim in template.dims if dim in input_chunks + } else: # template xarray object has been provided with proper sizes and chunk shapes @@ -308,7 +310,7 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): if isinstance(template, DataArray): output_chunks = dict(zip(template.dims, template.chunks)) # type: ignore else: - output_chunks = template.chunks + output_chunks = template.chunks # type: ignore if isinstance(template, DataArray): result_is_array = True @@ -394,7 +396,9 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): # input chunk 0 along a dimension maps to output chunk 0 along the same dimension # even if length of dimension is changed by the applied function expected["shapes"] = { - k: output_chunks[k][v] for k, v in input_chunk_index.items() + k: output_chunks[k][v] + for k, v in input_chunk_index.items() + if k in output_chunks } expected["data_vars"] = set(template.data_vars.keys()) # type: ignore expected["coords"] = set(template.coords.keys()) # type: ignore diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 866ddeedac5..5bcc5d37f44 100644 --- a/xarray/tests/test_dask.py +++ 
b/xarray/tests/test_dask.py @@ -1149,6 +1149,7 @@ def test_map_blocks_to_array(map_ds): lambda x: x.expand_dims(k=[1, 2, 3]), lambda x: x.assign_coords(new_coord=("y", x.y * 2)), lambda x: x.astype(np.int32), + lambda x: x.x, ], ) def test_map_blocks_da_transformations(func, map_da): @@ -1167,6 +1168,7 @@ def test_map_blocks_da_transformations(func, map_da): lambda x: x.drop_vars("x"), lambda x: x.expand_dims(k=[1, 2, 3]), lambda x: x.rename({"a": "new1", "b": "new2"}), + lambda x: x.x, ], ) def test_map_blocks_ds_transformations(func, map_ds): From 6f699556e509e79b9af52da2c675169d0da0f852 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 19 Mar 2020 07:01:21 -0600 Subject: [PATCH 18/26] Add map_blocks docs. --- doc/api.rst | 5 ++- doc/dask.rst | 106 +++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 105 insertions(+), 6 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 4492d882355..8514dff8264 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -171,6 +171,7 @@ Computation Dataset.quantile Dataset.differentiate Dataset.integrate + Dataset.map_blocks **Aggregation**: :py:attr:`~Dataset.all` @@ -350,6 +351,8 @@ Computation DataArray.differentiate DataArray.integrate DataArray.str + DataArray.map_blocks + **Aggregation**: :py:attr:`~DataArray.all` @@ -507,7 +510,6 @@ Dataset methods Dataset.load Dataset.chunk Dataset.unify_chunks - Dataset.map_blocks Dataset.filter_by_attrs Dataset.info @@ -539,7 +541,6 @@ DataArray methods DataArray.load DataArray.chunk DataArray.unify_chunks - DataArray.map_blocks Coordinates objects =================== diff --git a/doc/dask.rst b/doc/dask.rst index 07b3939af6e..727daacb938 100644 --- a/doc/dask.rst +++ b/doc/dask.rst @@ -274,12 +274,21 @@ loaded into Dask or not: .. 
_dask.automatic-parallelization: -Automatic parallelization -------------------------- +Automatic parallelization with ``apply_ufunc`` and ``map_blocks`` +----------------------------------------------------------------- Almost all of xarray's built-in operations work on Dask arrays. If you want to -use a function that isn't wrapped by xarray, one option is to extract Dask -arrays from xarray objects (``.data``) and use Dask directly. +use a function that isn't wrapped by xarray, and have it applied in parallel on +each block of your xarray object, you have three options: + +1. One option is to extract Dask arrays from xarray objects (``.data``) and use Dask directly. +2. Use :py:func:`~xarray.apply_ufunc` to apply functions that consume and return NumPy arrays. +3. Use :py:func:`~xarray.map_blocks`, :py:meth:`Dataset.map_blocks` or :py:meth:`DataArray.map_blocks` + to apply functions that consume and return xarray objects. + + +``apply_ufunc`` +~~~~~~~~~~~~~~~ Another option is to use xarray's :py:func:`~xarray.apply_ufunc`, which can automate `embarrassingly parallel @@ -382,6 +391,95 @@ application. structure of a problem, unlike the generic speedups offered by ``dask='parallelized'``. + +``map_blocks`` +~~~~~~~~~~~~~~ + +Functions that consume and return xarray objects can be easily applied in parallel using :py:func:`map_blocks`. Your function will receive an xarray Dataset or DataArray subset to one chunk +along each chunked dimension. + +.. ipython:: python + + ds.temperature + +This DataArray has 3 chunks each with length 10 along the time dimension. A function applied with :py:func:`map_blocks` will receive a DataArray corresponding to a single block of shape 10x180x180 +(time x latitude x longitude). The following snippet illustrates how to check the shape of the object +received by the applied function. + +.. 
ipython:: python
+
+    def func(da):
+        print(da.sizes)
+        return da.time
+
+    mapped = xr.map_blocks(func, ds.temperature)
+    mapped
+
+Notice that the :py:meth:`map_blocks` call printed
+``Frozen({'time': 0, 'latitude': 0, 'longitude': 0})`` to screen.
+``func`` received 0-sized blocks! :py:meth:`map_blocks` needs to know what the final result
+looks like in terms of dimensions, shapes etc. It does so by running the provided function on 0-shaped
+inputs (*automated inference*). This works in many cases, but not all. If automatic inference does not
+work for your function, provide the ``template`` kwarg (see below).
+
+In this case, automatic inference has worked so let's check that the result is as expected.
+
+.. ipython:: python
+
+    mapped.compute(scheduler="single-threaded")
+    mapped.identical(ds.time)
+
+Note that we use ``.compute(scheduler="single-threaded")``.
+This executes the Dask graph in `serial` using a for loop, but allows for printing to screen and other
+debugging techniques. We can easily see that our function is receiving blocks of shape 10x180x180 and
+the returned result is identical to ``ds.time`` as expected.
+
+
+Here is a common example where automated inference will not work.
+
+.. ipython:: python
+    :okexcept:
+
+    def func(da):
+        print(da.sizes)
+        return da.isel(time=[1])
+
+    mapped = xr.map_blocks(func, ds.temperature)
+
+``func`` cannot be run on 0-shaped inputs because it is not possible to extract element 1 along a
+dimension of size 0. In this case we need to tell :py:func:`map_blocks` what the returned result looks
+like using the ``template`` kwarg. ``template`` must be an xarray Dataset or DataArray (depending on
+what the function returns) with dimensions, shapes, chunk sizes, coordinate variables *and* data
+variables that look exactly like the expected result. The variables should be dask-backed and hence not
+incur much memory cost.
+
+.. 
ipython:: python + + template = ds.temperature.isel(time=[1, 11, 21]) + mapped = xr.map_blocks(func, ds.temperature, template=template) + + +Notice that the 0-shaped sizes were not printed to screen. Since ``template`` has been provided +:py:func:`map_blocks` does not need to infer it by running ``func`` on 0-shaped inputs. + +.. ipython:: python + + mapped.identical(template) + + +:py:func:`map_blocks` also allows passing ``args`` and ``kwargs`` down to the user function ``func``. +``func`` will be executed as ``func(block_xarray, *args, **kwargs)`` so ``args`` must be a list and ``kwargs`` must be a dictionary. + +.. ipython:: python + + def func(obj, a, b=0): + return obj + a + b + + mapped = ds.map_blocks(func, args=[10], kwargs={"b": 10}) + expected = ds + 10 + 10 + mapped.identical(expected) + + Chunking and performance ------------------------ From da62535d73d3e06d27281cc194618c18ddc4d22c Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 19 Mar 2020 17:21:23 -0600 Subject: [PATCH 19/26] Update doc/dask.rst Co-Authored-By: Joe Hamman --- doc/dask.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/dask.rst b/doc/dask.rst index 727daacb938..f11edcdbcf0 100644 --- a/doc/dask.rst +++ b/doc/dask.rst @@ -281,7 +281,7 @@ Almost all of xarray's built-in operations work on Dask arrays. If you want to use a function that isn't wrapped by xarray, and have it applied in parallel on each block of your xarray object, you have three options: -1. One option is to extract Dask arrays from xarray objects (``.data``) and use Dask directly. +1. Extract Dask arrays from xarray objects (``.data``) and use Dask directly. 2. Use :py:func:`~xarray.apply_ufunc` to apply functions that consume and return NumPy arrays. 3. Use :py:func:`~xarray.map_blocks`, :py:meth:`Dataset.map_blocks` or :py:meth:`DataArray.map_blocks` to apply functions that consume and return xarray objects. 
From 0f741e2ccb3f518f6f15bf5a66bd03d27fb8f667 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 1 May 2020 10:44:51 -0600 Subject: [PATCH 20/26] refactor out slicer for chunks --- xarray/core/parallel.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 957ae569d87..97651e56caf 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -127,6 +127,18 @@ def make_dict(x: Union[DataArray, Dataset]) -> Dict[Hashable, Any]: return {k: v.data for k, v in x.variables.items()} +def _get_chunk_slicer( + dim: Hashable, input_chunk_index: Mapping, chunk_index_bounds: Mapping +): + if dim in input_chunk_index: + which_chunk = input_chunk_index[dim] + return slice( + chunk_index_bounds[dim][which_chunk], + chunk_index_bounds[dim][which_chunk + 1], + ) + return slice(None) + + def map_blocks( func: Callable[..., T_DSorDA], obj: Union[DataArray, Dataset], @@ -370,15 +382,10 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): else: # non-dask array with possibly chunked dimensions # index into variable appropriately - subsetter = {} - for dim in variable.dims: - if dim in input_chunk_index: - which_chunk = input_chunk_index[dim] - subsetter[dim] = slice( - chunk_index_bounds[dim][which_chunk], - chunk_index_bounds[dim][which_chunk + 1], - ) - + subsetter = { + dim: _get_chunk_slicer(dim, input_chunk_index, chunk_index_bounds) + for dim in variable.dims + } subset = variable.isel(subsetter) chunk_variable_task = ( "{}-{}".format(gname, dask.base.tokenize(subset)), From 869e62db89a0c1288f323bef8a026b54c6e5d6e1 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 1 May 2020 11:14:34 -0600 Subject: [PATCH 21/26] Check expected index values. 
--- xarray/core/parallel.py | 53 +++++++++++++++++++++++---------------- xarray/tests/test_dask.py | 12 +++++++++ 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 97651e56caf..44b3d692386 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -32,7 +32,7 @@ def check_result_variables( - result: Union[DataArray, Dataset], expected: dict, kind: str + result: Union[DataArray, Dataset], expected: Mapping[str, Any], kind: str ): if kind == "coords": @@ -127,15 +127,10 @@ def make_dict(x: Union[DataArray, Dataset]) -> Dict[Hashable, Any]: return {k: v.data for k, v in x.variables.items()} -def _get_chunk_slicer( - dim: Hashable, input_chunk_index: Mapping, chunk_index_bounds: Mapping -): - if dim in input_chunk_index: - which_chunk = input_chunk_index[dim] - return slice( - chunk_index_bounds[dim][which_chunk], - chunk_index_bounds[dim][which_chunk + 1], - ) +def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping): + if dim in chunk_index: + which_chunk = chunk_index[dim] + return slice(chunk_bounds[dim][which_chunk], chunk_bounds[dim][which_chunk + 1]) return slice(None) @@ -259,20 +254,31 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): f"Dimensions {missing_dimensions} missing on returned object." ) - # check that index lengths are as expected + # check that index lengths and values are as expected for name, index in result.indexes.items(): if name in check_shapes: if len(index) != check_shapes[name]: raise ValueError( f"Received dimension {name!r} of length {len(index)}. Expected length {check_shapes[name]}." ) + if name in expected["indexes"]: + expected_index = expected["indexes"][name] + if not index.equals(expected_index): + raise ValueError( + f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead." 
+ ) + # check that all expected variables were returned check_result_variables(result, expected, "coords") if isinstance(result, Dataset): check_result_variables(result, expected, "data_vars") return make_dict(result) + if template is not None and not isinstance(template, (DataArray, Dataset)): + raise TypeError( + f"template must be a DataArray or Dataset. Received {type(template).__name__} instead." + ) if not isinstance(args, Sequence): raise TypeError("args must be a sequence (for example, a list or tuple).") if kwargs is None: @@ -351,13 +357,16 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): # map dims to list of chunk indexes ichunk = {dim: range(len(chunks_v)) for dim, chunks_v in input_chunks.items()} # mapping from chunk index to slice bounds - chunk_index_bounds = { + input_chunk_bounds = { dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in input_chunks.items() } + output_chunk_bounds = { + dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in output_chunks.items() + } # iterate over all possible chunk combinations for v in itertools.product(*ichunk.values()): - input_chunk_index = dict(zip(dataset.dims, v)) + chunk_index = dict(zip(dataset.dims, v)) # this will become [[name1, variable1], # [name2, variable2], @@ -372,7 +381,7 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): # recursively index into dask_keys nested list to get chunk chunk = variable.__dask_keys__() for dim in variable.dims: - chunk = chunk[input_chunk_index[dim]] + chunk = chunk[chunk_index[dim]] chunk_variable_task = (f"{gname}-{name}-{chunk[0]}",) + v graph[chunk_variable_task] = ( @@ -383,7 +392,7 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): # non-dask array with possibly chunked dimensions # index into variable appropriately subsetter = { - dim: _get_chunk_slicer(dim, input_chunk_index, chunk_index_bounds) + dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds) for dim in variable.dims } subset = variable.isel(subsetter) @@ 
-401,17 +410,19 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): else: data_vars.append([name, chunk_variable_task]) - # expected["shapes", "coords", "data_vars"] are used to raise nice error messages in _wrapper + # expected["shapes", "coords", "data_vars", "indexes"] are used to raise nice error messages in _wrapper expected = {} # input chunk 0 along a dimension maps to output chunk 0 along the same dimension # even if length of dimension is changed by the applied function expected["shapes"] = { - k: output_chunks[k][v] - for k, v in input_chunk_index.items() - if k in output_chunks + k: output_chunks[k][v] for k, v in chunk_index.items() if k in output_chunks } expected["data_vars"] = set(template.data_vars.keys()) # type: ignore expected["coords"] = set(template.coords.keys()) # type: ignore + expected["indexes"] = { + dim: indexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)] + for dim in indexes + } from_wrapper = (gname,) + v graph[from_wrapper] = ( @@ -434,8 +445,8 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): key: Tuple[Any, ...] = (gname_l,) for dim in variable.dims: - if dim in input_chunk_index: - key += (input_chunk_index[dim],) + if dim in chunk_index: + key += (chunk_index[dim],) elif dim in output_chunks: raise ValueError( f"Function is attempting to add a new chunked dimension {dim}. This is not allowed." 
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 10230fc8dee..7b2370e9b99 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1203,6 +1203,18 @@ def test_map_blocks_errors_bad_template(obj): xr.map_blocks(lambda x: x.isel(x=1), obj, template=obj).compute() with raises_regex(ValueError, "Received dimension 'x' of length 1"): xr.map_blocks(lambda x: x.isel(x=[1]), obj, template=obj).compute() + with raises_regex(TypeError, "must be a DataArray"): + xr.map_blocks(lambda x: x.isel(x=[1]), obj, template=(obj,)).compute() + with raises_regex(ValueError, "map_blocks requires that one block"): + xr.map_blocks( + lambda x: x.isel(x=[1]).assign_coords(x=10), obj, template=obj.isel(x=[1]) + ).compute() + with raises_regex(ValueError, "Expected index 'x' to be"): + xr.map_blocks( + lambda a: a.isel(x=[1]).assign_coords(x=[120]), # assign bad index values + obj, + template=obj.isel(x=[1, 5, 9]), + ).compute() def test_map_blocks_errors_bad_template_2(map_ds): From 334e8d32b1c695e3ae73021c3adbc17c342342de Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 1 May 2020 11:14:53 -0600 Subject: [PATCH 22/26] Raise nice error when template object does not have required number of chunks --- xarray/core/parallel.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 44b3d692386..608a31e4b4e 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -333,6 +333,14 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): else: output_chunks = template.chunks # type: ignore + for dim in output_chunks: + if len(input_chunks[dim]) != len(output_chunks[dim]): + raise ValueError( + "map_blocks requires that one block of the input maps to one block of output. " + f"Expected number of output chunks along dimension {dim!r} to be {len(input_chunks[dim])}. " + f"Received {len(output_chunks[dim])} instead. 
Please provide template (if not provided), or " + "fix the provided template." + ) if isinstance(template, DataArray): result_is_array = True template_name = template.name From 4dd699ed85aebb8e9fcd3773f7ee9a348ea46333 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 1 May 2020 11:15:13 -0600 Subject: [PATCH 23/26] doc updates. --- doc/dask.rst | 12 +++++++----- doc/whats-new.rst | 5 +++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/doc/dask.rst b/doc/dask.rst index f9d6769d856..fd1be6f8204 100644 --- a/doc/dask.rst +++ b/doc/dask.rst @@ -413,15 +413,17 @@ application. ``map_blocks`` ~~~~~~~~~~~~~~ -Functions that consume and return xarray objects can be easily applied in parallel using :py:func:`map_blocks`. Your function will receive an xarray Dataset or DataArray subset to one chunk +Functions that consume and return xarray objects can be easily applied in parallel using :py:func:`map_blocks`. +Your function will receive an xarray Dataset or DataArray subset to one chunk along each chunked dimension. .. ipython:: python ds.temperature -This DataArray has 3 chunks each with length 10 along the time dimension. A function applied with :py:func:`map_blocks` will receive a DataArray corresponding to a single block of shape 10x180x180 -(time x latitude x longitude). The following snippet illustrates how to check the shape of the object +This DataArray has 3 chunks each with length 10 along the time dimension. +At compute time, a function applied with :py:func:`map_blocks` will receive a DataArray corresponding to a single block of shape 10x180x180 +(time x latitude x longitude) with values loaded. The following snippet illustrates how to check the shape of the object received by the applied function. .. ipython:: python @@ -444,10 +446,10 @@ In this case, automatic inference has worked so let's check that the result is a .. 
ipython:: python - mapped.compute(scheduler="single-threaded") + mapped.load(scheduler="single-threaded") mapped.identical(ds.time) -Note that we use ``.compute(scheduler="single-threaded")``. +Note that we use ``.load(scheduler="single-threaded")`` to execute the computation. This executes the Dask graph in `serial` using a for loop, but allows for printing to screen and other debugging techniques. We can easily see that our function is receiving blocks of shape 10x180x180 and the returned result is identical to ``ds.time`` as expected. diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 051a41a57e5..29f10d4c87f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -59,6 +59,9 @@ New Features the :py:class:`~core.accessor_dt.DatetimeAccessor` (:pull:`3935`). This feature requires cftime version 1.1.0 or greater. By `Spencer Clark `_. +- :py:meth:`map_blocks` now accepts a ``template`` kwarg. This allows use cases + where the result of a computation could not be inferred automatically. + By `Deepak Cherian `_ Bug fixes ~~~~~~~~~ @@ -115,6 +118,8 @@ Documentation By `Matthias Riße `_. - Apply ``black`` to all the code in the documentation (:pull:`4012`) By `Justus Magin `_. +- Narrative documentation now describes :py:meth:`map_blocks`. :ref:`dask.automatic-parallelization`. + By `Deepak Cherian `_. Internal Changes ~~~~~~~~~~~~~~~~ From 27eb873a2c519954ae92c9b4da5d2dee98057c6a Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 1 May 2020 15:07:10 -0600 Subject: [PATCH 24/26] more review comments. 
--- xarray/core/parallel.py | 10 ++++------ xarray/tests/test_dask.py | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 608a31e4b4e..cc7a15feee7 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -334,13 +334,14 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): output_chunks = template.chunks # type: ignore for dim in output_chunks: - if len(input_chunks[dim]) != len(output_chunks[dim]): + if dim in input_chunks and len(input_chunks[dim]) != len(output_chunks[dim]): raise ValueError( "map_blocks requires that one block of the input maps to one block of output. " f"Expected number of output chunks along dimension {dim!r} to be {len(input_chunks[dim])}. " - f"Received {len(output_chunks[dim])} instead. Please provide template (if not provided), or " + f"Received {len(output_chunks[dim])} instead. Please provide template if not provided, or " "fix the provided template." ) + if isinstance(template, DataArray): result_is_array = True template_name = template.name @@ -455,12 +456,9 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): for dim in variable.dims: if dim in chunk_index: key += (chunk_index[dim],) - elif dim in output_chunks: - raise ValueError( - f"Function is attempting to add a new chunked dimension {dim}. This is not allowed." 
- ) else: # unchunked dimensions in the input have one chunk in the result + # output can have new dimensions with exactly one chunk key += (0,) # We're adding multiple new layers to the graph: diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 7b2370e9b99..a2f9374e5df 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1193,6 +1193,22 @@ def test_map_blocks_da_ds_with_template(obj): assert_identical(actual, template) +def test_map_blocks_template_convert_object(): + da = make_da() + func = lambda x: x.to_dataset().isel(x=[1]) + template = da.to_dataset().isel(x=[1, 5, 9]) + with raise_if_dask_computes(): + actual = xr.map_blocks(func, da, template=template) + assert_identical(actual, template) + + ds = da.to_dataset() + func = lambda x: x.to_array().isel(x=[1]) + template = ds.to_array().isel(x=[1, 5, 9]) + with raise_if_dask_computes(): + actual = xr.map_blocks(func, ds, template=template) + assert_identical(actual, template) + + @pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_errors_bad_template(obj): with raises_regex(ValueError, "unexpected coordinate variables"): From 04a2c3c1c16d332233602e3718248902c767b6ea Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 1 May 2020 15:26:31 -0600 Subject: [PATCH 25/26] Mention that attrs are taken from template. --- doc/dask.rst | 2 +- xarray/core/parallel.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/dask.rst b/doc/dask.rst index fd1be6f8204..217df7b5ba7 100644 --- a/doc/dask.rst +++ b/doc/dask.rst @@ -469,7 +469,7 @@ Here is a common example where automated inference will not work. ``func`` cannot be run on 0-shaped inputs because it is not possible to extract element 1 along a dimension of size 0. In this case we need to tell :py:func:`map_blocks` what the returned result looks like using the ``template`` kwarg. 
``template`` must be an xarray Dataset or DataArray (depending on -what the function returns) with dimensions, shapes, chunk sizes, coordinate variables *and* data +what the function returns) with dimensions, shapes, chunk sizes, attributes, coordinate variables *and* data variables that look exactly like the expected result. The variables should be dask-backed and hence not incur much memory cost. diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index cc7a15feee7..2bf01076422 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -170,7 +170,7 @@ def map_blocks( xarray object representing the final result after compute is called. If not provided, the function will be first run on mocked-up data, that looks like 'obj' but has sizes 0, to determine properties of the returned object such as dtype, - variable names, new dimensions and new indexes (if any). + variable names, attributes, new dimensions and new indexes (if any). 'template' must be provided if the function changes the size of existing dimensions. Returns From 4b92168ae71a1d34e63888f654b794dee4296377 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 5 May 2020 12:03:31 -0600 Subject: [PATCH 26/26] Add test and explicit point out that attrs is copied from template --- doc/dask.rst | 6 ++++++ xarray/core/parallel.py | 3 +++ xarray/tests/test_dask.py | 5 +++++ 3 files changed, 14 insertions(+) diff --git a/doc/dask.rst b/doc/dask.rst index 217df7b5ba7..df223982ba4 100644 --- a/doc/dask.rst +++ b/doc/dask.rst @@ -473,6 +473,12 @@ what the function returns) with dimensions, shapes, chunk sizes, attributes, coo variables that look exactly like the expected result. The variables should be dask-backed and hence not incur much memory cost. +.. note:: + + Note that when ``template`` is provided, ``attrs`` from ``template`` are copied over to the result. Any + ``attrs`` set in ``func`` will be ignored. + + .. 
ipython:: python template = ds.temperature.isel(time=[1, 11, 21]) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 2bf01076422..d91dfb4a275 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -172,6 +172,9 @@ def map_blocks( has sizes 0, to determine properties of the returned object such as dtype, variable names, attributes, new dimensions and new indexes (if any). 'template' must be provided if the function changes the size of existing dimensions. + When provided, `attrs` on variables in `template` are copied over to the result. Any + `attrs` set by `func` will be ignored. + Returns ------- diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index a2f9374e5df..75beb3757ca 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1109,6 +1109,11 @@ def add_attrs(obj): assert_identical(actual, expected) + # when template is specified, attrs are copied from template, not set by function + with raise_if_dask_computes(): + actual = xr.map_blocks(add_attrs, obj, template=obj) + assert_identical(actual, obj) + def test_map_blocks_change_name(map_da): def change_name(obj):