Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adapt map_blocks to use new Coordinates API #8560

Merged
merged 7 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ class Coordinates(AbstractCoordinates):
:py:class:`~xarray.Coordinates` object is passed, its indexes
will be added to the new created object.
indexes: dict-like, optional
Mapping of where keys are coordinate names and values are
Mapping where keys are coordinate names and values are
:py:class:`~xarray.indexes.Index` objects. If None (default),
pandas indexes will be created for each dimension coordinate.
Passing an empty dictionary will skip this default behavior.
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
try:
from dask.dataframe import DataFrame as DaskDataFrame
except ImportError:
DaskDataFrame = None # type: ignore
DaskDataFrame = None
try:
from dask.delayed import Delayed
except ImportError:
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@
try:
from dask.dataframe import DataFrame as DaskDataFrame
except ImportError:
DaskDataFrame = None # type: ignore
DaskDataFrame = None


# list of attributes of pd.DatetimeIndex that are ndarrays of time info
Expand Down
89 changes: 57 additions & 32 deletions xarray/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,29 @@
import itertools
import operator
from collections.abc import Hashable, Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict

import numpy as np

from xarray.core.alignment import align
from xarray.core.coordinates import Coordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.indexes import Index
from xarray.core.merge import merge
from xarray.core.pycompat import is_dask_collection

if TYPE_CHECKING:
from xarray.core.types import T_Xarray


class ExpectedDict(TypedDict):
shapes: dict[Hashable, int]
coords: set[Hashable]
data_vars: set[Hashable]
indexes: dict[Hashable, Index]


def unzip(iterable):
return zip(*iterable)

Expand All @@ -31,7 +41,9 @@ def assert_chunks_compatible(a: Dataset, b: Dataset):


def check_result_variables(
result: DataArray | Dataset, expected: Mapping[str, Any], kind: str
result: DataArray | Dataset,
expected: ExpectedDict,
kind: Literal["coords", "data_vars"],
):
if kind == "coords":
nice_str = "coordinate"
Expand Down Expand Up @@ -254,7 +266,7 @@ def _wrapper(
args: list,
kwargs: dict,
arg_is_array: Iterable[bool],
expected: dict,
expected: ExpectedDict,
):
"""
Wrapper function that receives datasets in args; converts to dataarrays when necessary;
Expand Down Expand Up @@ -345,33 +357,45 @@ def _wrapper(
for arg in aligned
)

merged_coordinates = merge([arg.coords for arg in aligned]).coords

_, npargs = unzip(
sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0])
)

# check that chunk sizes are compatible
input_chunks = dict(npargs[0].chunks)
input_indexes = dict(npargs[0]._indexes)
for arg in xarray_objs[1:]:
assert_chunks_compatible(npargs[0], arg)
input_chunks.update(arg.chunks)
input_indexes.update(arg._indexes)

coordinates: Coordinates
if template is None:
# infer template by providing zero-shaped arrays
template = infer_template(func, aligned[0], *args, **kwargs)
template_indexes = set(template._indexes)
preserved_indexes = template_indexes & set(input_indexes)
new_indexes = template_indexes - set(input_indexes)
indexes = {dim: input_indexes[dim] for dim in preserved_indexes}
indexes.update({k: template._indexes[k] for k in new_indexes})
template_coords = set(template.coords)
preserved_coord_vars = template_coords & set(merged_coordinates)
new_coord_vars = template_coords - set(merged_coordinates)

preserved_coords = merged_coordinates.to_dataset()[preserved_coord_vars]
# preserved_coords contains all coordinates bariables that share a dimension
# with any index variable in preserved_indexes
# Drop any unneeded vars in a second pass, this is required for e.g.
# if the mapped function were to drop a non-dimension coordinate variable.
preserved_coords = preserved_coords.drop_vars(
tuple(k for k in preserved_coords.variables if k not in template_coords)
)

coordinates = merge(
(preserved_coords, template.coords.to_dataset()[new_coord_vars])
).coords
output_chunks: Mapping[Hashable, tuple[int, ...]] = {
dim: input_chunks[dim] for dim in template.dims if dim in input_chunks
}

else:
# template xarray object has been provided with proper sizes and chunk shapes
indexes = dict(template._indexes)
coordinates = template.coords
output_chunks = template.chunksizes
if not output_chunks:
raise ValueError(
Expand Down Expand Up @@ -473,6 +497,9 @@ def subset_dataset_to_block(

return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs)

# variable names that depend on the computation. Currently, indexes
# cannot be modified in the mapped function, so we exclude thos
computed_variables = set(template.variables) - set(coordinates.xindexes)
# iterate over all possible chunk combinations
for chunk_tuple in itertools.product(*ichunk.values()):
# mapping from dimension name to chunk index
Expand All @@ -485,29 +512,32 @@ def subset_dataset_to_block(
for isxr, arg in zip(is_xarray, npargs)
]

# expected["shapes", "coords", "data_vars", "indexes"] are used to
# raise nice error messages in _wrapper
expected = {}
# input chunk 0 along a dimension maps to output chunk 0 along the same dimension
# even if length of dimension is changed by the applied function
expected["shapes"] = {
k: output_chunks[k][v] for k, v in chunk_index.items() if k in output_chunks
}
expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment]
expected["coords"] = set(template.coords.keys()) # type: ignore[assignment]
expected["indexes"] = {
dim: indexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)]
for dim in indexes
expected: ExpectedDict = {
# input chunk 0 along a dimension maps to output chunk 0 along the same dimension
# even if length of dimension is changed by the applied function
"shapes": {
k: output_chunks[k][v]
for k, v in chunk_index.items()
if k in output_chunks
},
"data_vars": set(template.data_vars.keys()),
"coords": set(template.coords.keys()),
"indexes": {
dim: coordinates.xindexes[dim][
_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)
]
for dim in coordinates.xindexes
},
}

from_wrapper = (gname,) + chunk_tuple
graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected)

# mapping from variable name to dask graph key
var_key_map: dict[Hashable, str] = {}
for name, variable in template.variables.items():
if name in indexes:
continue
for name in computed_variables:
variable = template.variables[name]
gname_l = f"{name}-{gname}"
var_key_map[name] = gname_l

Expand Down Expand Up @@ -543,12 +573,7 @@ def subset_dataset_to_block(
},
)

# TODO: benbovy - flexible indexes: make it work with custom indexes
# this will need to pass both indexes and coords to the Dataset constructor
result = Dataset(
coords={k: idx.to_pandas_index() for k, idx in indexes.items()},
attrs=template.attrs,
)
result = Dataset(coords=coordinates, attrs=template.attrs)

for index in result._indexes:
result[index].attrs = template[index].attrs
Expand Down
19 changes: 19 additions & 0 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,6 +1367,25 @@ def test_map_blocks_da_ds_with_template(obj):
assert_identical(actual, template)


def test_map_blocks_roundtrip_string_index():
ds = xr.Dataset(
{"data": (["label"], [1, 2, 3])}, coords={"label": ["foo", "bar", "baz"]}
).chunk(label=1)
assert ds.label.dtype == np.dtype("<U3")

mapped = ds.map_blocks(lambda x: x, template=ds)
assert mapped.label.dtype == ds.label.dtype

mapped = ds.map_blocks(lambda x: x, template=None)
assert mapped.label.dtype == ds.label.dtype

mapped = ds.data.map_blocks(lambda x: x, template=ds.data)
assert mapped.label.dtype == ds.label.dtype

mapped = ds.data.map_blocks(lambda x: x, template=None)
assert mapped.label.dtype == ds.label.dtype


def test_map_blocks_template_convert_object():
da = make_da()
func = lambda x: x.to_dataset().isel(x=[1])
Expand Down
Loading