Skip to content

Commit

Permalink
Improve concat performance (#7824)
Browse files Browse the repository at this point in the history
* 1. var_idx very slow

* 2. slow any

* Add test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* 3. Slow array_type called multiple times

* 4. Can use fastpath for variable.concat?

* 5. slow init of pd.unique

* typos

* Update concat.py

* Update merge.py

* 6. Avoid recalculating in loops

* 7. No need to transpose 1d arrays.

* 8. speed up dask_dataframe

* Update dataset.py

* Update dataset.py

* Update dataset.py

* Add dask combine test with many variables

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update combine.py

* Update combine.py

* Update combine.py

* list not needed

* dim is usually string, might be faster to check for that

* first_var.dims doesn't change and can be defined 1 time

* mask bad points rather than append good points

* reduce duplicated code

* don't think id() is required here.

* get dtype directly instead of through result_dtype

* seems better to delete rather than append,

* use internal fastpath if it's a dataset, values should be fine then

* Change isinstance order.

* use fastpath if already xarray objtect

* Update variable.py

* Update dtypes.py

* typing fixes

* more typing fixes

* test undoing as_compatible_data

* undo concat_dim_length deletion

* Update xarray/core/concat.py

* Remove .copy and sum

* Update concat.py

* Use OrderedSet

* Apply suggestions from code review

* Update whats-new.rst

* Update xarray/core/concat.py

* no need to check arrays if cupy isnt even installed

* Update whats-new.rst

* Add concat comment

* minimize diff

* revert sketchy

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Illviljan and pre-commit-ci[bot] authored Jun 2, 2023
1 parent 960f15c commit c9d89e2
Show file tree
Hide file tree
Showing 12 changed files with 131 additions and 51 deletions.
43 changes: 42 additions & 1 deletion asv_bench/benchmarks/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,49 @@

import xarray as xr

from . import requires_dask

class Combine:

class Combine1d:
"""Benchmark concatenating and merging large datasets"""

def setup(self) -> None:
"""Create 2 datasets with two different variables"""

t_size = 8000
t = np.arange(t_size)
data = np.random.randn(t_size)

self.dsA0 = xr.Dataset({"A": xr.DataArray(data, coords={"T": t}, dims=("T"))})
self.dsA1 = xr.Dataset(
{"A": xr.DataArray(data, coords={"T": t + t_size}, dims=("T"))}
)

def time_combine_by_coords(self) -> None:
"""Also has to load and arrange t coordinate"""
datasets = [self.dsA0, self.dsA1]

xr.combine_by_coords(datasets)


class Combine1dDask(Combine1d):
"""Benchmark concatenating and merging large datasets"""

def setup(self) -> None:
"""Create 2 datasets with two different variables"""
requires_dask()

t_size = 8000
t = np.arange(t_size)
var = xr.Variable(dims=("T",), data=np.random.randn(t_size)).chunk()

data_vars = {f"long_name_{v}": ("T", var) for v in range(500)}

self.dsA0 = xr.Dataset(data_vars, coords={"T": t})
self.dsA1 = xr.Dataset(data_vars, coords={"T": t + t_size})


class Combine3d:
"""Benchmark concatenating and merging large datasets"""

def setup(self):
Expand Down
3 changes: 2 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ Deprecations

Performance
~~~~~~~~~~~

- Improve concatenation performance (:issue:`7833`, :pull:`7824`).
By `Jimmy Westling <https://github.com/illviljan>`_.

Bug fixes
~~~~~~~~~
Expand Down
10 changes: 5 additions & 5 deletions xarray/core/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,18 +970,18 @@ def combine_by_coords(

# Perform the multidimensional combine on each group of data variables
# before merging back together
concatenated_grouped_by_data_vars = []
for vars, datasets_with_same_vars in grouped_by_vars:
concatenated = _combine_single_variable_hypercube(
list(datasets_with_same_vars),
concatenated_grouped_by_data_vars = tuple(
_combine_single_variable_hypercube(
tuple(datasets_with_same_vars),
fill_value=fill_value,
data_vars=data_vars,
coords=coords,
compat=compat,
join=join,
combine_attrs=combine_attrs,
)
concatenated_grouped_by_data_vars.append(concatenated)
for vars, datasets_with_same_vars in grouped_by_vars
)

return merge(
concatenated_grouped_by_data_vars,
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, .
int or tuple of int
Axis number or numbers corresponding to the given dimensions.
"""
if isinstance(dim, Iterable) and not isinstance(dim, str):
if not isinstance(dim, str) and isinstance(dim, Iterable):
return tuple(self._get_axis_num(d) for d in dim)
else:
return self._get_axis_num(dim)
Expand Down
34 changes: 21 additions & 13 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Hashable, Iterable
from typing import TYPE_CHECKING, Any, Union, cast, overload

import numpy as np
import pandas as pd

from xarray.core import dtypes, utils
Expand Down Expand Up @@ -517,7 +518,7 @@ def _dataset_concat(
if variables_to_merge:
grouped = {
k: v
for k, v in collect_variables_and_indexes(list(datasets)).items()
for k, v in collect_variables_and_indexes(datasets).items()
if k in variables_to_merge
}
merged_vars, merged_indexes = merge_collected(
Expand All @@ -543,7 +544,7 @@ def ensure_common_dims(vars, concat_dim_lengths):
# ensure each variable with the given name shares the same
# dimensions and the same shape for all of them except along the
# concat dimension
common_dims = tuple(pd.unique([d for v in vars for d in v.dims]))
common_dims = tuple(utils.OrderedSet(d for v in vars for d in v.dims))
if dim not in common_dims:
common_dims = (dim,) + common_dims
for var, dim_len in zip(vars, concat_dim_lengths):
Expand All @@ -568,38 +569,45 @@ def get_indexes(name):
yield PandasIndex(data, dim, coord_dtype=var.dtype)

# create concatenation index, needed for later reindexing
concat_index = list(range(sum(concat_dim_lengths)))
file_start_indexes = np.append(0, np.cumsum(concat_dim_lengths))
concat_index = np.arange(file_start_indexes[-1])
concat_index_size = concat_index.size
variable_index_mask = np.ones(concat_index_size, dtype=bool)

# stack up each variable and/or index to fill-out the dataset (in order)
# n.b. this loop preserves variable order, needed for groupby.
ndatasets = len(datasets)
for name in vars_order:
if name in concat_over and name not in result_indexes:
variables = []
variable_index = []
# Initialize the mask to all True then set False if any name is missing in
# the datasets:
variable_index_mask.fill(True)
var_concat_dim_length = []
for i, ds in enumerate(datasets):
if name in ds.variables:
variables.append(ds[name].variable)
# add to variable index, needed for reindexing
var_idx = [
sum(concat_dim_lengths[:i]) + k
for k in range(concat_dim_lengths[i])
]
variable_index.extend(var_idx)
var_concat_dim_length.append(len(var_idx))
var_concat_dim_length.append(concat_dim_lengths[i])
else:
# raise if coordinate not in all datasets
if name in coord_names:
raise ValueError(
f"coordinate {name!r} not present in all datasets."
)

# Mask out the indexes without the name:
start = file_start_indexes[i]
end = file_start_indexes[i + 1]
variable_index_mask[slice(start, end)] = False

variable_index = concat_index[variable_index_mask]
vars = ensure_common_dims(variables, var_concat_dim_length)

# Try to concatenate the indexes, concatenate the variables when no index
# is found on all datasets.
indexes: list[Index] = list(get_indexes(name))
if indexes:
if len(indexes) < len(datasets):
if len(indexes) < ndatasets:
raise ValueError(
f"{name!r} must have either an index or no index in all datasets, "
f"found {len(indexes)}/{len(datasets)} datasets with an index."
Expand All @@ -623,7 +631,7 @@ def get_indexes(name):
vars, dim, positions, combine_attrs=combine_attrs
)
# reindex if variable is not present in all datasets
if len(variable_index) < len(concat_index):
if len(variable_index) < concat_index_size:
combined_var = reindex_variables(
variables={name: combined_var},
dim_pos_indexers={
Expand Down
6 changes: 3 additions & 3 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ def __init__(
)

if isinstance(coords, Dataset):
coords = coords.variables
coords = coords._variables

variables, coord_names, dims, indexes, _ = merge_data_and_coords(
data_vars, coords, compat="broadcast_equals"
Expand Down Expand Up @@ -1399,8 +1399,8 @@ def _construct_dataarray(self, name: Hashable) -> DataArray:
coords: dict[Hashable, Variable] = {}
# preserve ordering
for k in self._variables:
if k in self._coord_names and set(self.variables[k].dims) <= needed_dims:
coords[k] = self.variables[k]
if k in self._coord_names and set(self._variables[k].dims) <= needed_dims:
coords[k] = self._variables[k]

indexes = filter_indexes_from_coords(self._indexes, set(coords))

Expand Down
14 changes: 8 additions & 6 deletions xarray/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ def __eq__(self, other):
# instead of following NumPy's own type-promotion rules. These type promotion
# rules match pandas instead. For reference, see the NumPy type hierarchy:
# https://numpy.org/doc/stable/reference/arrays.scalars.html
PROMOTE_TO_OBJECT = [
{np.number, np.character}, # numpy promotes to character
{np.bool_, np.character}, # numpy promotes to character
{np.bytes_, np.unicode_}, # numpy promotes to unicode
]
PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = (
(np.number, np.character), # numpy promotes to character
(np.bool_, np.character), # numpy promotes to character
(np.bytes_, np.unicode_), # numpy promotes to unicode
)


def maybe_promote(dtype):
Expand Down Expand Up @@ -156,7 +156,9 @@ def is_datetime_like(dtype):
return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)


def result_type(*arrays_and_dtypes):
def result_type(
*arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
) -> np.dtype:
"""Like np.result_type, but with type promotion rules matching pandas.
Examples of changed behavior:
Expand Down
5 changes: 4 additions & 1 deletion xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,10 @@ def asarray(data, xp=np):

def as_shared_dtype(scalars_or_arrays, xp=np):
"""Cast a arrays to a shared dtype using xarray's type promotion rules."""
if any(isinstance(x, array_type("cupy")) for x in scalars_or_arrays):
array_type_cupy = array_type("cupy")
if array_type_cupy and any(
isinstance(x, array_type_cupy) for x in scalars_or_arrays
):
import cupy as cp

arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,7 +1495,7 @@ def filter_indexes_from_coords(
of coordinate names.
"""
filtered_indexes: dict[Any, Index] = dict(**indexes)
filtered_indexes: dict[Any, Index] = dict(indexes)

index_coord_names: dict[Hashable, set[Hashable]] = defaultdict(set)
for name, idx in indexes.items():
Expand Down
27 changes: 19 additions & 8 deletions xarray/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,11 @@ def _assert_prioritized_valid(


def merge_collected(
grouped: dict[Hashable, list[MergeElement]],
grouped: dict[Any, list[MergeElement]],
prioritized: Mapping[Any, MergeElement] | None = None,
compat: CompatOptions = "minimal",
combine_attrs: CombineAttrsOptions = "override",
equals: dict[Hashable, bool] | None = None,
equals: dict[Any, bool] | None = None,
) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]:
"""Merge dicts of variables, while resolving conflicts appropriately.
Expand Down Expand Up @@ -306,7 +306,7 @@ def merge_collected(


def collect_variables_and_indexes(
list_of_mappings: list[DatasetLike],
list_of_mappings: Iterable[DatasetLike],
indexes: Mapping[Any, Any] | None = None,
) -> dict[Hashable, list[MergeElement]]:
"""Collect variables and indexes from list of mappings of xarray objects.
Expand Down Expand Up @@ -556,7 +556,12 @@ def merge_coords(
return variables, out_indexes


def merge_data_and_coords(data_vars, coords, compat="broadcast_equals", join="outer"):
def merge_data_and_coords(
data_vars: Mapping[Any, Any],
coords: Mapping[Any, Any],
compat: CompatOptions = "broadcast_equals",
join: JoinOptions = "outer",
) -> _MergeResult:
"""Used in Dataset.__init__."""
indexes, coords = _create_indexes_from_coords(coords, data_vars)
objects = [data_vars, coords]
Expand All @@ -570,7 +575,9 @@ def merge_data_and_coords(data_vars, coords, compat="broadcast_equals", join="ou
)


def _create_indexes_from_coords(coords, data_vars=None):
def _create_indexes_from_coords(
coords: Mapping[Any, Any], data_vars: Mapping[Any, Any] | None = None
) -> tuple[dict, dict]:
"""Maybe create default indexes from a mapping of coordinates.
Return those indexes and updated coordinates.
Expand Down Expand Up @@ -605,7 +612,11 @@ def _create_indexes_from_coords(coords, data_vars=None):
return indexes, updated_coords


def assert_valid_explicit_coords(variables, dims, explicit_coords):
def assert_valid_explicit_coords(
variables: Mapping[Any, Any],
dims: Mapping[Any, int],
explicit_coords: Iterable[Hashable],
) -> None:
"""Validate explicit coordinate names/dims.
Raise a MergeError if an explicit coord shares a name with a dimension
Expand Down Expand Up @@ -688,7 +699,7 @@ def merge_core(
join: JoinOptions = "outer",
combine_attrs: CombineAttrsOptions = "override",
priority_arg: int | None = None,
explicit_coords: Sequence | None = None,
explicit_coords: Iterable[Hashable] | None = None,
indexes: Mapping[Any, Any] | None = None,
fill_value: object = dtypes.NA,
) -> _MergeResult:
Expand Down Expand Up @@ -1035,7 +1046,7 @@ def dataset_merge_method(
# method due for backwards compatibility
# TODO: consider deprecating it?

if isinstance(overwrite_vars, Iterable) and not isinstance(overwrite_vars, str):
if not isinstance(overwrite_vars, str) and isinstance(overwrite_vars, Iterable):
overwrite_vars = set(overwrite_vars)
else:
overwrite_vars = {overwrite_vars}
Expand Down
16 changes: 14 additions & 2 deletions xarray/core/pycompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,26 @@ def __init__(self, mod: ModType) -> None:
self.available = duck_array_module is not None


_cached_duck_array_modules: dict[ModType, DuckArrayModule] = {}


def _get_cached_duck_array_module(mod: ModType) -> DuckArrayModule:
if mod not in _cached_duck_array_modules:
duckmod = DuckArrayModule(mod)
_cached_duck_array_modules[mod] = duckmod
return duckmod
else:
return _cached_duck_array_modules[mod]


def array_type(mod: ModType) -> DuckArrayTypes:
"""Quick wrapper to get the array class of the module."""
return DuckArrayModule(mod).type
return _get_cached_duck_array_module(mod).type


def mod_version(mod: ModType) -> Version:
"""Quick wrapper to get the version of the module."""
return DuckArrayModule(mod).version
return _get_cached_duck_array_module(mod).version


def is_dask_collection(x):
Expand Down
Loading

0 comments on commit c9d89e2

Please sign in to comment.