
Fix map_blocks HLG layering (pydata#3598)
* Fix map_blocks HLG layering

This fixes an issue with the HighLevelGraph noted in
pydata#3584 and exposed by a recent change in Dask to do more
HLG fusion (a minimal sketch of the layering pattern follows below).

* update

* black

* update
TomAugspurger authored and dcherian committed Dec 7, 2019
1 parent 4c51aa2 commit cafcaee
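
For context, here is a minimal, hypothetical sketch (not taken from this commit) of the layering pattern the fix adopts: one base layer holding the block computation, plus a small getitem layer per output that depends only on that base layer. The layer names and toy tasks below are illustrative assumptions.

import operator

import dask
import dask.array
import numpy as np
from dask.highlevelgraph import HighLevelGraph

# One "computation" layer whose single task stands in for a block result.
base_name = "compute-" + dask.base.tokenize("example")
base_layer = {(base_name, 0): (np.ones, 4)}

# One "getitem"-style layer per output, depending only on the base layer.
out_name = "getitem-x-" + dask.base.tokenize(base_name, "x")
out_layer = {(out_name, 0): (operator.getitem, (base_name, 0), slice(None))}

hlg = HighLevelGraph(
    {base_name: base_layer, out_name: out_layer},  # layers
    {base_name: set(), out_name: {base_name}},     # layer dependencies
)

arr = dask.array.Array(hlg, name=out_name, chunks=((4,),), dtype=float)
print(arr.compute())  # -> array([1., 1., 1., 1.])

In the old code, all of these tasks sat in one layer named after the wrapped function while each output array was built under its own per-variable name; the patch below gives every output its own layer plus an explicit dependency on the computation layer.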
Showing 3 changed files with 36 additions and 3 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
@@ -36,6 +36,8 @@ Bug fixes
 ~~~~~~~~~
 - Fix plotting with transposed 2D non-dimensional coordinates. (:issue:`3138`, :pull:`3441`)
   By `Deepak Cherian <https://github.com/dcherian>`_.
+- Fix issue with Dask-backed datasets raising a ``KeyError`` on some computations involving ``map_blocks`` (:pull:`3598`)
+  By `Tom Augspurger <https://github.com/TomAugspurger>`_.
 
 Documentation
 ~~~~~~~~~~~~~
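
To illustrate the entry above, a short usage sketch of the affected pattern (it mirrors the regression test added at the bottom of this commit; the variable names and chunking are illustrative, and xarray plus dask are assumed to be installed):

import dask.array
import xarray as xr

# Two dask-backed variables on different dimensions, then an identity
# map_blocks call; computing the result previously could raise a KeyError.
ds = xr.Dataset(
    {
        "x": (("a",), dask.array.ones(10, chunks=(5,))),
        "z": (("b",), dask.array.ones(10, chunks=(5,))),
    }
)
mapped = ds.map_blocks(lambda block: block)
xr.testing.assert_equal(mapped.compute(), ds.compute())
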
24 changes: 21 additions & 3 deletions xarray/core/parallel.py
Expand Up @@ -7,12 +7,14 @@
except ImportError:
pass

import collections
import itertools
import operator
from typing import (
Any,
Callable,
Dict,
DefaultDict,
Hashable,
Mapping,
Sequence,
@@ -221,7 +223,12 @@ def _wrapper(func, obj, to_array, args, kwargs):
     indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes}
     indexes.update({k: template.indexes[k] for k in new_indexes})
 
+    # We're building a new HighLevelGraph hlg. We'll have one new layer
+    # for each variable in the dataset, which is the result of the
+    # func applied to the values.
+
     graph: Dict[Any, Any] = {}
+    new_layers: DefaultDict[str, Dict[Any, Any]] = collections.defaultdict(dict)
     gname = "{}-{}".format(
         dask.utils.funcname(func), dask.base.tokenize(dataset, args, kwargs)
     )
@@ -310,9 +317,20 @@ def _wrapper(func, obj, to_array, args, kwargs):
             # unchunked dimensions in the input have one chunk in the result
             key += (0,)
 
-            graph[key] = (operator.getitem, from_wrapper, name)
+            # We're adding multiple new layers to the graph:
+            # The first new layer is the result of the computation on
+            # the array.
+            # Then we add one layer per variable, which extracts the
+            # result for that variable, and depends on just the first new
+            # layer.
+            new_layers[gname_l][key] = (operator.getitem, from_wrapper, name)
 
-    graph = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset])
+    hlg = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset])
+
+    for gname_l, layer in new_layers.items():
+        # This adds in the getitems for each variable in the dataset.
+        hlg.dependencies[gname_l] = {gname}
+        hlg.layers[gname_l] = layer
 
     result = Dataset(coords=indexes, attrs=template.attrs)
     for name, gname_l in var_key_map.items():
@@ -325,7 +343,7 @@ def _wrapper(func, obj, to_array, args, kwargs):
                 var_chunks.append((len(indexes[dim]),))
 
         data = dask.array.Array(
-            graph, name=gname_l, chunks=var_chunks, dtype=template[name].dtype
+            hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype
         )
         result[name] = (dims, data, template[name].attrs)
 
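
The comments in the patch describe the intended layering: one layer holding the computation and one small layer per output variable that just extracts its result. A hedged sketch of one way that structure could be inspected after the fix (.layers and .dependencies are public HighLevelGraph attributes; the one-variable dataset is illustrative):

import dask.array
import xarray as xr

ds = xr.Dataset({"x": (("a",), dask.array.ones(10, chunks=(5,)))})
mapped = ds.map_blocks(lambda block: block)

hlg = mapped.__dask_graph__()  # xarray returns a dask HighLevelGraph here
print(list(hlg.layers))        # the input's layers, the computation layer, and one getitem layer per variable
print(dict(hlg.dependencies))  # each per-variable layer depends on the computation layer
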
13 changes: 13 additions & 0 deletions xarray/tests/test_dask.py
@@ -1189,6 +1189,19 @@ def func(obj):
     assert_identical(expected.compute(), actual.compute())
 
 
+def test_map_blocks_hlg_layers():
+    # regression test for #3599
+    ds = xr.Dataset(
+        {
+            "x": (("a",), dask.array.ones(10, chunks=(5,))),
+            "z": (("b",), dask.array.ones(10, chunks=(5,))),
+        }
+    )
+    mapped = ds.map_blocks(lambda x: x)
+
+    xr.testing.assert_equal(mapped, ds)
+
+
 def test_make_meta(map_ds):
     from ..core.parallel import make_meta
 
