From b6409f0627d813065b58f67e6244cbe47f84090c Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 21 Mar 2020 19:51:06 +0000 Subject: [PATCH] map_blocks: allow user function to add new unindexed dimension. (#3817) --- doc/whats-new.rst | 3 ++- xarray/core/parallel.py | 3 +++ xarray/tests/test_dask.py | 2 ++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ac80524a3c4..86272cf8710 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -43,6 +43,8 @@ New Features arguments and should instead pass a single list of dimensions. (:pull:`3802`) By `Maximilian Roos `_ +- :py:func:`map_blocks` can now apply functions that add new unindexed dimensions. + By `Deepak Cherian `_ - The new ``Dataset._repr_html_`` and ``DataArray._repr_html_`` (introduced in 0.14.1) is now on by default. To disable, use ``xarray.set_options(display_style="text")``. @@ -60,7 +62,6 @@ New Features (:issue:`3843`, :pull:`3844`) By `Aaron Spring `_. - Bug fixes ~~~~~~~~~ - Fix :py:meth:`Dataset.interp` when indexing array shares coordinates with the diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 8429d0f71ad..6f1668f698f 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -386,6 +386,9 @@ def _wrapper(func, obj, to_array, args, kwargs): var_chunks.append(input_chunks[dim]) elif dim in indexes: var_chunks.append((len(indexes[dim]),)) + elif dim in template.dims: + # new unindexed dimension + var_chunks.append((template.sizes[dim],)) data = dask.array.Array( hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 4f7e3910f82..923b35e5946 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1147,6 +1147,7 @@ def test_map_blocks_to_array(map_ds): lambda x: x.to_dataset(), lambda x: x.drop_vars("x"), lambda x: x.expand_dims(k=[1, 2, 3]), + lambda x: x.expand_dims(k=3), lambda x: x.assign_coords(new_coord=("y", x.y * 2)), lambda x: x.astype(np.int32), # TODO: [lambda x: x.isel(x=1).drop_vars("x"), map_da], @@ -1167,6 +1168,7 @@ def test_map_blocks_da_transformations(func, map_da): lambda x: x.drop_vars("a"), lambda x: x.drop_vars("x"), lambda x: x.expand_dims(k=[1, 2, 3]), + lambda x: x.expand_dims(k=3), lambda x: x.rename({"a": "new1", "b": "new2"}), # TODO: [lambda x: x.isel(x=1)], ],