map_blocks #3276

Merged: 82 commits from map_blocks_2 into master, Oct 10, 2019
Commits
b090a9e
map_block attempt 2
dcherian Sep 2, 2019
3948798
Address reviews: errors, args + kwargs support.
dcherian Sep 2, 2019
4f159c8
Works with datasets!
dcherian Sep 3, 2019
9179f0b
remove wrong comment.
dcherian Sep 8, 2019
20c5d5b
Support chunks.
dcherian Sep 8, 2019
b16b237
infer template.
dcherian Sep 8, 2019
43ef2b7
cleanup
dcherian Sep 8, 2019
5ebf738
cleanup2
dcherian Sep 8, 2019
8a460bb
api.rst
dcherian Sep 8, 2019
505f3f0
simple shape change error check.
dcherian Sep 8, 2019
fe1982f
Make test more complicated.
dcherian Sep 8, 2019
066eb59
Fix for when user function doesn't set DataArray.name
dcherian Sep 9, 2019
83eb310
Now _to_temp_dataset works.
dcherian Sep 9, 2019
008ce29
Add whats-new
dcherian Sep 9, 2019
adbe48e
chunks kwarg makes no sense right now.
dcherian Sep 9, 2019
924bf69
review feedback:
dcherian Sep 19, 2019
8aed8e7
Support nondim coords in make_meta.
dcherian Sep 19, 2019
d0797f6
Add Dataset.unify_chunks
dcherian Sep 19, 2019
599b70a
Merge branch 'master' into map_blocks_2
dcherian Sep 19, 2019
765ca5d
doc updates.
dcherian Sep 19, 2019
180bbf2
Merge remote-tracking branch 'upstream/master' into map_blocks_2
dcherian Sep 19, 2019
f0de1db
minor.
dcherian Sep 19, 2019
1251a5d
update comment.
dcherian Sep 19, 2019
47a0e39
More complicated test dataset. Tests fail :X
dcherian Sep 19, 2019
fa44d32
Don't know why compute is needed.
dcherian Sep 19, 2019
a6e84ef
work with DataArray nondim coords.
dcherian Sep 19, 2019
c28b402
fastpath unify_chunks
dcherian Sep 20, 2019
1694d03
comment.
dcherian Sep 20, 2019
cf04ec8
much improved tests.
dcherian Sep 20, 2019
3e9db26
Change args, kwargs syntax.
dcherian Sep 20, 2019
20fdde6
Add dataset, dataarray methods.
dcherian Sep 20, 2019
22e9c4e
api.rst
dcherian Sep 20, 2019
b145787
docstrings.
dcherian Sep 20, 2019
f600c4a
Fix unify_chunks.
dcherian Sep 23, 2019
4af5a67
Move assert_chunks_equal to xarray.testing.
dcherian Sep 23, 2019
3ca4b7b
minor changes.
dcherian Sep 23, 2019
3345d25
Better error handling when inferring returned object
dcherian Sep 23, 2019
54c77dd
wip
dcherian Sep 23, 2019
fb1ff0b
Docstrings + nicer error message.
dcherian Sep 26, 2019
bad0855
wip
dcherian Sep 23, 2019
291e6e6
better to_array
dcherian Sep 23, 2019
b31537c
remove unify_chunks in map_blocks + better tests.
dcherian Sep 26, 2019
72e7913
typing for unify_chunks
dcherian Sep 26, 2019
0a6bbed
address more review comments.
dcherian Sep 26, 2019
210987e
more unify_chunks tests.
dcherian Sep 26, 2019
582e0d5
Just use dask.core.utils.meta_from_array
dcherian Sep 26, 2019
d0fd87e
get tests working. assert_equal needs a lot of fixing.
dcherian Sep 28, 2019
875264a
more unify_chunks test.
dcherian Sep 28, 2019
0f03e37
assert_chunks_equal fixes.
dcherian Sep 28, 2019
8175d73
copy over meta_from_array.
dcherian Sep 28, 2019
6ab8737
minor fixes.
dcherian Sep 28, 2019
08c41b9
raise chunks error earlier and test for map_blocks raising chunk error
dcherian Sep 28, 2019
76bc23c
fix.
dcherian Sep 28, 2019
49d3899
Type annotations
Oct 1, 2019
ae53b85
py35 compat
Oct 1, 2019
f6dfb12
make sure unify_chunks does not compute.
dcherian Oct 1, 2019
c73eda1
Make tests functional by call compute before assert_equal
dcherian Oct 1, 2019
8ad882b
Update whats-new
dcherian Oct 3, 2019
aa4ea00
Merge remote-tracking branch 'upstream/master' into map_blocks_2
dcherian Oct 3, 2019
3cda5ac
Work with attributes.
dcherian Oct 3, 2019
49969a7
Support attrs and name changes.
dcherian Oct 3, 2019
6faf79e
more assert_equal
dcherian Oct 3, 2019
47baf76
test changing coord attribute
dcherian Oct 3, 2019
1295499
Merge remote-tracking branch 'upstream/master' into map_blocks_2
dcherian Oct 3, 2019
ce252f2
fix whats new
dcherian Oct 3, 2019
50ae13f
rework tests to use fixtures (kind of)
dcherian Oct 7, 2019
cdcf221
more review changes.
dcherian Oct 7, 2019
f167537
cleanup
dcherian Oct 7, 2019
4390f73
more review feedback.
dcherian Oct 7, 2019
c936557
fix unify_chunks.
dcherian Oct 8, 2019
e34aafe
Merge remote-tracking branch 'upstream/master' into map_blocks_2
dcherian Oct 9, 2019
2c7938a
read dask_array_compat :)
dcherian Oct 9, 2019
08ed873
Dask 1.2.0 compat.
dcherian Oct 10, 2019
67663aa
Merge remote-tracking branch 'upstream/master' into map_blocks_2
crusaderky Oct 10, 2019
99d61fc
documentation polish
crusaderky Oct 10, 2019
687689e
make_meta reflow
crusaderky Oct 10, 2019
f588cb6
cosmetic
crusaderky Oct 10, 2019
d476e2f
polish
crusaderky Oct 10, 2019
26a6a0d
Fix tests
crusaderky Oct 10, 2019
6491753
isort
crusaderky Oct 10, 2019
b227bea
isort
crusaderky Oct 10, 2019
2a41906
Add func call to docstrings.
dcherian Oct 10, 2019
1 change: 1 addition & 0 deletions xarray/__init__.py
@@ -17,6 +17,7 @@
from .core.dataarray import DataArray
from .core.merge import merge, MergeError
from .core.options import set_options
from .core.parallel import map_blocks

from .backends.api import (
open_dataset,
141 changes: 141 additions & 0 deletions xarray/core/parallel.py
@@ -0,0 +1,141 @@
try:
import dask
import dask.array
from dask.highlevelgraph import HighLevelGraph

except ImportError:
pass

import itertools
import numpy as np

from .dataarray import DataArray
from .dataset import Dataset


def map_blocks(func, obj, *args, **kwargs):
Contributor comment: type annotations please
"""
Apply a function to each chunk of a DataArray or Dataset.

Parameters
----------
func: callable
User-provided function that should accept DataArrays corresponding to one chunk.
obj: DataArray, Dataset
Chunks of this object will be provided to 'func'. The function must not change
the shape of the provided DataArray.
args, kwargs:
Passed on to func.

Returns
-------
DataArray

See Also
--------
dask.array.map_blocks
"""

def _wrapper(func, obj, to_array, args, kwargs):
if to_array:
# this should be easier
obj = obj.to_array().squeeze().drop("variable")

result = func(obj, *args, **kwargs)

if not isinstance(result, type(obj)):
raise ValueError("Result is not the same type as input.")
if result.shape != obj.shape:
raise ValueError("Result does not have the same shape as input.")

return result

# if not isinstance(obj, DataArray):
# raise ValueError("map_blocks can only be used with DataArrays at present.")

if isinstance(obj, DataArray):
dataset = obj._to_temp_dataset()
to_array = True
else:
dataset = obj
to_array = False

dataset_dims = list(dataset.dims)

graph = {}
gname = "map-%s-%s" % (dask.utils.funcname(func), dask.base.tokenize(dataset))

# map dims to list of chunk indexes
# If two different variables have different chunking along the same dim
# .chunks will raise an error.
chunks = dataset.chunks
ichunk = {dim: range(len(chunks[dim])) for dim in chunks}
# mapping from chunk index to slice bounds
chunk_index_bounds = {dim: np.cumsum((0,) + chunks[dim]) for dim in chunks}

# iterate over all possible chunk combinations
for v in itertools.product(*ichunk.values()):
chunk_index_dict = dict(zip(dataset_dims, v))

# this will become [[name1, variable1],
# [name2, variable2],
# ...]
# which is passed to dict and then to Dataset
data_vars = []
coords = []

for name, variable in dataset.variables.items():
# make a task that creates tuple of (dims, chunk)
if dask.is_dask_collection(variable.data):
var_dask_keys = variable.__dask_keys__()

# recursively index into dask_keys nested list to get chunk
chunk = var_dask_keys
for dim in variable.dims:
chunk = chunk[chunk_index_dict[dim]]

task_name = ("tuple-" + dask.base.tokenize(chunk),) + v
graph[task_name] = (tuple, [variable.dims, chunk])
else:
# numpy array with possibly chunked dimensions
# index into variable appropriately
subsetter = dict()
for dim in variable.dims:
if dim in chunk_index_dict:
which_chunk = chunk_index_dict[dim]
subsetter[dim] = slice(
chunk_index_bounds[dim][which_chunk],
chunk_index_bounds[dim][which_chunk + 1],
)

subset = variable.isel(subsetter)
task_name = (name + dask.base.tokenize(subset),) + v
graph[task_name] = (tuple, [subset.dims, subset])

# this task creates dict mapping variable name to above tuple
if name in dataset.data_vars:
data_vars.append([name, task_name])
if name in dataset.coords:
coords.append([name, task_name])

graph[(gname,) + v] = (
_wrapper,
func,
(Dataset, (dict, data_vars), (dict, coords), dataset.attrs),
to_array,
args,
kwargs,
)

final_graph = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset])

if isinstance(obj, DataArray):
result = DataArray(
dask.array.Array(
final_graph, name=gname, chunks=obj.data.chunks, meta=obj.data._meta
),
dims=obj.dims,
coords=obj.coords,
)

return result
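The graph construction above iterates over every combination of chunk indices (`itertools.product`) and converts each chunk index into slice bounds via a cumulative sum of chunk sizes. A minimal dask-free sketch of that bookkeeping, using hypothetical chunk sizes in place of `dataset.chunks`:

```python
import itertools

import numpy as np

# hypothetical per-dimension chunk sizes, shaped like dataset.chunks
chunks = {"x": (4, 4, 2), "y": (5, 5, 5, 5)}

# map each dim to the range of its chunk indices
ichunk = {dim: range(len(sizes)) for dim, sizes in chunks.items()}

# cumulative sums give the slice bounds between consecutive chunks
chunk_index_bounds = {dim: np.cumsum((0,) + sizes) for dim, sizes in chunks.items()}

# one subsetter dict per chunk combination, as the loop body builds for
# non-dask variables
slices = []
for v in itertools.product(*ichunk.values()):
    chunk_index_dict = dict(zip(chunks, v))
    subsetter = {
        dim: slice(
            int(chunk_index_bounds[dim][i]),
            int(chunk_index_bounds[dim][i + 1]),
        )
        for dim, i in chunk_index_dict.items()
    }
    slices.append(subsetter)

print(len(slices))  # 3 chunks along x times 4 along y -> 12 tasks
print(slices[0])    # {'x': slice(0, 4, None), 'y': slice(0, 5, None)}
```

Each `subsetter` corresponds to one task in the graph; `map_blocks` builds one `_wrapper` call per combination.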
28 changes: 28 additions & 0 deletions xarray/tests/test_dask.py
@@ -878,3 +878,31 @@ def test_dask_layers_and_dependencies():
assert set(x.foo.__dask_graph__().dependencies).issuperset(
ds.__dask_graph__().dependencies
)


def test_map_blocks():
darray = xr.DataArray(
dask.array.ones((10, 20), chunks=[4, 5]),
dims=["x", "y"],
coords={"x": np.arange(10), "y": np.arange(100, 120)},
)
darray.name = None

def good_func(darray):
return darray * darray.x + 5 * darray.y

def bad_func(darray):
return (darray * darray.x + 5 * darray.y)[:1, :1]

actual = xr.map_blocks(good_func, darray)
expected = good_func(darray)
xr.testing.assert_equal(expected, actual)

with raises_regex(ValueError, "not have the same shape"):
xr.map_blocks(bad_func, darray).compute()

import operator

expected = darray + 10
actual = xr.map_blocks(operator.add, darray, 10)
xr.testing.assert_equal(expected, actual)
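The `bad_func` case above exercises the shape check in `_wrapper`. The same validation pattern can be sketched without dask or xarray, using a hypothetical `_check_result` helper over plain numpy arrays:

```python
import numpy as np


def _check_result(func, block, *args, **kwargs):
    # mirror _wrapper's checks: result must match the input block's
    # type and shape
    result = func(block, *args, **kwargs)
    if not isinstance(result, type(block)):
        raise ValueError("Result is not the same type as input.")
    if result.shape != block.shape:
        raise ValueError("Result does not have the same shape as input.")
    return result


block = np.ones((4, 5))

# shape-preserving function passes the check
good = _check_result(lambda b: b * 2 + 5, block)

# a function that shrinks the block is rejected
err = None
try:
    _check_result(lambda b: b[:1, :1], block)
except ValueError as e:
    err = str(e)
print(err)  # Result does not have the same shape as input.
```

Raising inside the wrapper means the error surfaces lazily, at `.compute()` time, which is why the test calls `.compute()` on the `bad_func` result.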