map_blocks #3276

Merged (82 commits, Oct 10, 2019)
Changes shown are from 15 commits.

Commits (82)
b090a9e
map_block attempt 2
dcherian Sep 2, 2019
3948798
Address reviews: errors, args + kwargs support.
dcherian Sep 2, 2019
4f159c8
Works with datasets!
dcherian Sep 3, 2019
9179f0b
remove wrong comment.
dcherian Sep 8, 2019
20c5d5b
Support chunks.
dcherian Sep 8, 2019
b16b237
infer template.
dcherian Sep 8, 2019
43ef2b7
cleanup
dcherian Sep 8, 2019
5ebf738
cleanup2
dcherian Sep 8, 2019
8a460bb
api.rst
dcherian Sep 8, 2019
505f3f0
simple shape change error check.
dcherian Sep 8, 2019
fe1982f
Make test more complicated.
dcherian Sep 8, 2019
066eb59
Fix for when user function doesn't set DataArray.name
dcherian Sep 9, 2019
83eb310
Now _to_temp_dataset works.
dcherian Sep 9, 2019
008ce29
Add whats-new
dcherian Sep 9, 2019
adbe48e
chunks kwarg makes no sense right now.
dcherian Sep 9, 2019
924bf69
review feedback:
dcherian Sep 19, 2019
8aed8e7
Support nondim coords in make_meta.
dcherian Sep 19, 2019
d0797f6
Add Dataset.unify_chunks
dcherian Sep 19, 2019
599b70a
Merge branch 'master' into map_blocks_2
dcherian Sep 19, 2019
765ca5d
doc updates.
dcherian Sep 19, 2019
180bbf2
Merge remote-tracking branch 'upstream/master' into map_blocks_2
dcherian Sep 19, 2019
f0de1db
minor.
dcherian Sep 19, 2019
1251a5d
update comment.
dcherian Sep 19, 2019
47a0e39
More complicated test dataset. Tests fail :X
dcherian Sep 19, 2019
fa44d32
Don't know why compute is needed.
dcherian Sep 19, 2019
a6e84ef
work with DataArray nondim coords.
dcherian Sep 19, 2019
c28b402
fastpath unify_chunks
dcherian Sep 20, 2019
1694d03
comment.
dcherian Sep 20, 2019
cf04ec8
much improved tests.
dcherian Sep 20, 2019
3e9db26
Change args, kwargs syntax.
dcherian Sep 20, 2019
20fdde6
Add dataset, dataarray methods.
dcherian Sep 20, 2019
22e9c4e
api.rst
dcherian Sep 20, 2019
b145787
docstrings.
dcherian Sep 20, 2019
f600c4a
Fix unify_chunks.
dcherian Sep 23, 2019
4af5a67
Move assert_chunks_equal to xarray.testing.
dcherian Sep 23, 2019
3ca4b7b
minor changes.
dcherian Sep 23, 2019
3345d25
Better error handling when inferring returned object
dcherian Sep 23, 2019
54c77dd
wip
dcherian Sep 23, 2019
fb1ff0b
Docstrings + nicer error message.
dcherian Sep 26, 2019
bad0855
wip
dcherian Sep 23, 2019
291e6e6
better to_array
dcherian Sep 23, 2019
b31537c
remove unify_chunks in map_blocks + better tests.
dcherian Sep 26, 2019
72e7913
typing for unify_chunks
dcherian Sep 26, 2019
0a6bbed
address more review comments.
dcherian Sep 26, 2019
210987e
more unify_chunks tests.
dcherian Sep 26, 2019
582e0d5
Just use dask.core.utils.meta_from_array
dcherian Sep 26, 2019
d0fd87e
get tests working. assert_equal needs a lot of fixing.
dcherian Sep 28, 2019
875264a
more unify_chunks test.
dcherian Sep 28, 2019
0f03e37
assert_chunks_equal fixes.
dcherian Sep 28, 2019
8175d73
copy over meta_from_array.
dcherian Sep 28, 2019
6ab8737
minor fixes.
dcherian Sep 28, 2019
08c41b9
raise chunks error earlier and test for map_blocks raising chunk error
dcherian Sep 28, 2019
76bc23c
fix.
dcherian Sep 28, 2019
49d3899
Type annotations
Oct 1, 2019
ae53b85
py35 compat
Oct 1, 2019
f6dfb12
make sure unify_chunks does not compute.
dcherian Oct 1, 2019
c73eda1
Make tests functional by call compute before assert_equal
dcherian Oct 1, 2019
8ad882b
Update whats-new
dcherian Oct 3, 2019
aa4ea00
Merge remote-tracking branch 'upstream/master' into map_blocks_2
dcherian Oct 3, 2019
3cda5ac
Work with attributes.
dcherian Oct 3, 2019
49969a7
Support attrs and name changes.
dcherian Oct 3, 2019
6faf79e
more assert_equal
dcherian Oct 3, 2019
47baf76
test changing coord attribute
dcherian Oct 3, 2019
1295499
Merge remote-tracking branch 'upstream/master' into map_blocks_2
dcherian Oct 3, 2019
ce252f2
fix whats new
dcherian Oct 3, 2019
50ae13f
rework tests to use fixtures (kind of)
dcherian Oct 7, 2019
cdcf221
more review changes.
dcherian Oct 7, 2019
f167537
cleanup
dcherian Oct 7, 2019
4390f73
more review feedback.
dcherian Oct 7, 2019
c936557
fix unify_chunks.
dcherian Oct 8, 2019
e34aafe
Merge remote-tracking branch 'upstream/master' into map_blocks_2
dcherian Oct 9, 2019
2c7938a
read dask_array_compat :)
dcherian Oct 9, 2019
08ed873
Dask 1.2.0 compat.
dcherian Oct 10, 2019
67663aa
Merge remote-tracking branch 'upstream/master' into map_blocks_2
crusaderky Oct 10, 2019
99d61fc
documentation polish
crusaderky Oct 10, 2019
687689e
make_meta reflow
crusaderky Oct 10, 2019
f588cb6
cosmetic
crusaderky Oct 10, 2019
d476e2f
polish
crusaderky Oct 10, 2019
26a6a0d
Fix tests
crusaderky Oct 10, 2019
6491753
isort
crusaderky Oct 10, 2019
b227bea
isort
crusaderky Oct 10, 2019
2a41906
Add func call to docstrings.
dcherian Oct 10, 2019
doc/api.rst (1 change: 1 addition & 0 deletions)
@@ -30,6 +30,7 @@ Top-level functions
zeros_like
ones_like
dot
map_blocks

Dataset
=======
doc/whats-new.rst (3 changes: 3 additions & 0 deletions)
@@ -61,6 +61,9 @@ New functions/methods
This requires `sparse>=0.8.0`. By `Nezar Abdennur <https://github.com/nvictus>`_
and `Guido Imperiale <https://github.com/crusaderky>`_.

- Added :py:func:`~xarray.map_blocks`, modeled after :py:func:`dask.array.map_blocks`.
  By `Deepak Cherian <https://github.com/dcherian>`_.

- :py:meth:`~Dataset.from_dataframe` and :py:meth:`~DataArray.from_series` now
support ``sparse=True`` for converting pandas objects into xarray objects
wrapping sparse arrays. This is particularly useful with sparsely populated
xarray/__init__.py (1 change: 1 addition & 0 deletions)
@@ -17,6 +17,7 @@
from .core.dataarray import DataArray
from .core.merge import merge, MergeError
from .core.options import set_options
from .core.parallel import map_blocks

from .backends.api import (
open_dataset,
xarray/core/parallel.py (275 changes: 275 additions & 0 deletions)
@@ -0,0 +1,275 @@
try:
    import dask
    import dask.array
    from dask.highlevelgraph import HighLevelGraph

except ImportError:
    pass

import itertools
import operator

import numpy as np

from .dataarray import DataArray
from .dataset import Dataset


def _to_array(obj):
    if not isinstance(obj, Dataset):
        raise ValueError("Trying to convert DataArray to DataArray!")

    if len(obj.data_vars) > 1:
        raise ValueError(
            "Trying to convert Dataset with more than one variable to DataArray"
        )

    name = list(obj.data_vars)[0]
    # this should be easier
    da = obj.to_array().squeeze().drop("variable")
    da.name = name
    return da
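# For example (a sketch): _to_array(Dataset({"a": ("x", [1, 2])})) returns a
# DataArray named "a" with dims ("x",); the length-1 "variable" dimension
# added by to_array() is squeezed out and its coordinate dropped.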


def make_meta(obj):
    from dask.array.utils import meta_from_array

    if isinstance(obj, DataArray):
        meta = DataArray(obj.data._meta, dims=obj.dims)
    # elif, not a second if: otherwise the else branch below would overwrite
    # the DataArray meta with the original object
    elif isinstance(obj, Dataset):
        meta = Dataset()
        for name, variable in obj.variables.items():
            if dask.is_dask_collection(variable):
                meta_obj = obj[name].data._meta
            else:
                meta_obj = meta_from_array(obj[name].data)
            meta[name] = DataArray(meta_obj, dims=obj[name].dims)
    else:
        meta = obj

    return meta
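# For intuition (a sketch): a dask-backed DataArray of shape (10, 20) maps to
# a meta DataArray wrapping an empty array of the same ndim and dtype
# (shape (0, 0)), so user functions can be evaluated on it almost for free.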


def infer_template(func, obj, *args, **kwargs):
    """ Infer return object by running the function on meta objects. """
    meta_args = []
    for arg in (obj,) + args:
        meta_args.append(make_meta(arg))

    try:
        template = func(*meta_args, **kwargs)
    except ValueError:
        raise ValueError("Cannot infer object returned by user-provided function.")

    return template
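# Because meta objects contain no data, the call above only exercises the
# structure of func's result (dtype, variable names, dimensions); nothing
# from the real dataset is loaded or computed.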


def _make_dict(x):
    # Dataset.to_dict() is too complicated;
    # this maps variable name to numpy array
    if isinstance(x, DataArray):
        x = x._to_temp_dataset()

    to_return = dict()
    for var in x.variables:
        # if var not in x:
        #     raise ValueError("Variable %r not found in returned object." % var)
        to_return[var] = x[var].values

    return to_return
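# e.g. a sketch: a Dataset with data variable "a" and coordinate "x" becomes
# {"a": <numpy array>, "x": <numpy array>}, which is cheap to move through
# the dask graph and reassemble on the other side.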


def map_blocks(func, obj, *args, **kwargs):
    # review note (Contributor): "type annotations please"
"""
Apply a function to each chunk of a DataArray or Dataset. This function is experimental
and its signature may change.

Parameters
----------
func: callable
User-provided function that should accept DataArrays corresponding to one chunk.
The function will be run on a small piece of data that looks like 'obj' to determine
properties of the returned object such as dtype, variable names,
new dimensions and new indexes (if any).

This function cannot
- change size of existing dimensions.
- add new chunked dimensions.
dcherian marked this conversation as resolved.
Show resolved Hide resolved

obj: DataArray, Dataset
Chunks of this object will be provided to 'func'. The function must not change
shape of the provided DataArray.
args:
Passed on to func.
dcherian marked this conversation as resolved.
Show resolved Hide resolved
kwargs:
Passed on to func.


Returns
-------
DataArray or Dataset

See Also
--------
dask.array.map_blocks
"""

    def _wrapper(func, obj, to_array, args, kwargs):
        if to_array:
            obj = _to_array(obj)

        result = func(obj, *args, **kwargs)

        for name, index in result.indexes.items():
            if name in obj.indexes:
                if len(index) != len(obj.indexes[name]):
                    raise ValueError(
                        "Length of the %r dimension has changed. This is not allowed."
                        % name
                    )

        to_return = _make_dict(result)

        return to_return

    if not dask.is_dask_collection(obj):
        raise ValueError(
            "map_blocks can only be used with dask-backed DataArrays. "
            "Use .chunk() to convert to a Dask array."
        )

    if isinstance(obj, DataArray):
        dataset = obj._to_temp_dataset()
        input_is_array = True
    else:
        dataset = obj
        input_is_array = False

    template = infer_template(func, obj, *args, **kwargs)
    if isinstance(template, DataArray):
        result_is_array = True
        template = template._to_temp_dataset()
    else:
        result_is_array = False

    # If two different variables have different chunking along the same dim,
    # .chunks will raise an error.
    input_chunks = dataset.chunks

    indexes = dict(dataset.indexes)
    for dim in template.indexes:
        if dim not in indexes:
            indexes[dim] = template.indexes[dim]

    graph = {}
    gname = "%s-%s" % (dask.utils.funcname(func), dask.base.tokenize(dataset))

    # map dims to list of chunk indexes
    ichunk = {dim: range(len(input_chunks[dim])) for dim in input_chunks}
    # mapping from chunk index to slice bounds
    chunk_index_bounds = {
        dim: np.cumsum((0,) + input_chunks[dim]) for dim in input_chunks
    }
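    # e.g. a sketch: chunks of (4, 4, 2) along "x" give
    # chunk_index_bounds["x"] == array([0, 4, 8, 10]), so chunk i of an
    # unchunked (numpy) variable spans slice(bounds[i], bounds[i + 1])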

    # iterate over all possible chunk combinations
    for v in itertools.product(*ichunk.values()):
        # zip with ichunk.keys() (not dataset.dims) so that chunk indices stay
        # aligned even when some dimensions of the dataset are unchunked
        chunk_index_dict = dict(zip(ichunk.keys(), v))

        # this will become [[name1, variable1],
        #                   [name2, variable2],
        #                   ...]
        # which is passed to dict and then to Dataset
        data_vars = []
        coords = []

        for name, variable in dataset.variables.items():
            # make a task that creates tuple of (dims, chunk)
            if dask.is_dask_collection(variable.data):
                var_dask_keys = variable.__dask_keys__()

                # recursively index into dask_keys nested list to get chunk
                chunk = var_dask_keys
                for dim in variable.dims:
                    chunk = chunk[chunk_index_dict[dim]]

                task_name = ("tuple-" + dask.base.tokenize(chunk),) + v
                graph[task_name] = (tuple, [variable.dims, chunk])
            else:
                # numpy array with possibly chunked dimensions;
                # index into variable appropriately
                subsetter = dict()
                for dim in variable.dims:
                    if dim in chunk_index_dict:
                        which_chunk = chunk_index_dict[dim]
                        subsetter[dim] = slice(
                            chunk_index_bounds[dim][which_chunk],
                            chunk_index_bounds[dim][which_chunk + 1],
                        )

                subset = variable.isel(subsetter)
                task_name = (name + dask.base.tokenize(subset),) + v
                graph[task_name] = (tuple, [subset.dims, subset])

            # this task creates dict mapping variable name to above tuple
            if name in dataset.data_vars:
                data_vars.append([name, task_name])
            if name in dataset.coords:
                coords.append([name, task_name])

        from_wrapper = (gname,) + v
        graph[from_wrapper] = (
            _wrapper,
            func,
            (Dataset, (dict, data_vars), (dict, coords), dataset.attrs),
            input_is_array,
            args,
            kwargs,
        )
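        # At this point the subgraph for one chunk looks roughly like this
        # (a sketch; tokens abbreviated):
        #   ("tuple-<tok>", 0, 0): (tuple, [("x", "y"), ("a-<tok>", 0, 0)])
        #   ("<funcname>-<tok>", 0, 0): (_wrapper, func, (Dataset, ...), ...)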

        # mapping from variable name to dask graph key
        var_key_map = {}
        for name, variable in template.variables.items():
            var_dims = variable.dims
            # cannot tokenize "name" because the hash of <this-array> is not invariant!
            # This happens when the user function does not set a name on the returned DataArray
            gname_l = "%s-%s" % (gname, name)
            var_key_map[name] = gname_l

            key = (gname_l,)
            for dim in var_dims:
                if dim in chunk_index_dict:
                    key += (chunk_index_dict[dim],)
                else:
                    # unchunked dimensions in the input have one chunk in the result
                    key += (0,)

            graph[key] = (operator.getitem, from_wrapper, name)

    graph = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset])
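    # from_collections records the input dataset's dask graph(s) as
    # dependencies of the new layer, so the result reuses, rather than
    # duplicates, the tasks that produce the input chunks.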

    result = Dataset()
    for var, key in var_key_map.items():
        # indexes need to be known;
        # otherwise compute is called when the DataArray is created
        if var in indexes:
            result[var] = indexes[var]
            continue

        dims = template[var].dims
        var_chunks = []
        for dim in dims:
            if dim in input_chunks:
                var_chunks.append(input_chunks[dim])
            elif dim in indexes:
                var_chunks.append((len(indexes[dim]),))

        data = dask.array.Array(
            graph, name=key, chunks=var_chunks, dtype=template[var].dtype
        )
        result[var] = DataArray(data=data, dims=dims, name=var)

    if result_is_array:
        result = _to_array(result)

    return result
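
To illustrate the API added by this file, here is a minimal usage sketch (not
part of the diff; the array and the function are made up for illustration, and
it assumes dask is installed). Because func is applied to every chunk
independently, it must be an operation whose per-chunk result equals its
global result, such as an elementwise one:

import numpy as np
import xarray as xr

da = xr.DataArray(np.ones((10, 20)), dims=["x", "y"]).chunk({"x": 4, "y": 5})

def scale(block, factor=1.0):
    # elementwise, so chunk-by-chunk application equals global application
    return block * factor

result = xr.map_blocks(scale, da, factor=3.0)  # lazy: only builds a dask graph
xr.testing.assert_equal(result.compute(), (da * 3.0).compute())
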
xarray/tests/test_dask.py (74 changes: 74 additions & 0 deletions)
@@ -878,3 +878,77 @@ def test_dask_layers_and_dependencies():
    assert set(x.foo.__dask_graph__().dependencies).issuperset(
        ds.__dask_graph__().dependencies
    )


def make_da():
    return xr.DataArray(
        np.ones((10, 20)),
        dims=["x", "y"],
        coords={"x": np.arange(10), "y": np.arange(100, 120)},
        name="a",
    ).chunk({"x": 4, "y": 5})


def make_ds():
    map_ds = xr.Dataset()
    map_ds["a"] = map_da
    map_ds["b"] = map_ds.a + 50
    map_ds["c"] = map_ds.x + 20
    map_ds = map_ds.chunk({"x": 4, "y": 5})
    map_ds["d"] = ("z", [1, 1, 1, 1])
    map_ds["z"] = [0, 1, 2, 3]
    map_ds["e"] = map_ds.x + map_ds.y + map_ds.z
    map_ds.attrs["test"] = "test"

    return map_ds
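# Note the fixture deliberately mixes dask-backed variables ("a", "b", "c")
# with numpy variables added after chunking ("d", "e"), since map_blocks must
# handle datasets containing both.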


# work around mypy error
# xarray/tests/test_dask.py:888: error: Dict entry 0 has incompatible type "str": "int"; expected "Hashable": "Union[None, Number, Tuple[Number, ...]]"
map_da = make_da()
map_ds = make_ds()


def simple_func(obj):
    result = obj.x + 5 * obj.y
    return result


def complicated_func(obj):
    new = obj.copy()
    new = (
        new[["a", "b"]]
        .rename({"a": "new_var1"})
        .expand_dims(k=[0, 1, 2])
        .transpose("k", "y", "x")
    )
    new["b"] = new.b.astype("int32")
    return new


def test_map_blocks_error():
    def bad_func(darray):
        return (darray * darray.x + 5 * darray.y)[:1, :1]

    with raises_regex(ValueError, "Length of the.* has changed."):
        xr.map_blocks(bad_func, map_da).compute()


@pytest.mark.parametrize(
    "func, obj",
    [[simple_func, map_da], [simple_func, map_ds], [complicated_func, map_ds]],
)
def test_map_blocks(func, obj):
    actual = xr.map_blocks(func, obj)
    expected = func(obj)
    xr.testing.assert_equal(expected, actual)


@pytest.mark.parametrize("obj", [map_da, map_ds])
def test_map_blocks_args(obj):
    import operator

    expected = obj + 10
    actual = xr.map_blocks(operator.add, obj, 10)
    xr.testing.assert_equal(expected, actual)
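
Keyword arguments are forwarded to func the same way; a sketch of an analogous
test (the helper function is hypothetical, not part of the diff) could look like:

def test_map_blocks_kwargs_sketch():
    def add_offset(block, offset=0):
        return block + offset

    expected = map_da + 10
    actual = xr.map_blocks(add_offset, map_da, offset=10)
    xr.testing.assert_equal(expected, actual)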