From 87a25b64898c94ea1e2a2e7a06d31ef602b116bf Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Thu, 5 Dec 2019 16:39:39 +0000
Subject: [PATCH] 2x~5x speed up for isel() in most cases (#3533)

* Speed up isel in most cases

* What's New

* Trivial

* Use _replace

* isort

* Code review

* What's New

* mypy
---
 doc/whats-new.rst               |  8 +++---
 xarray/coding/cftime_offsets.py |  2 +-
 xarray/core/dataarray.py        | 26 ++++++++++++++++---
 xarray/core/dataset.py          | 45 ++++++++++++++++++++++++++++++++-
 xarray/core/formatting_html.py  |  4 +--
 xarray/core/indexing.py         | 13 ++++++++++
 xarray/core/variable.py         |  5 +++-
 xarray/tests/test_dask.py       |  2 +-
 xarray/tests/test_dataarray.py  |  1 -
 xarray/tests/test_missing.py    |  2 +-
 xarray/tests/test_variable.py   | 22 ++++++++++++++++
 11 files changed, 116 insertions(+), 14 deletions(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 8930947a2a6..d4d8ab8f3e5 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -37,7 +37,6 @@ Bug fixes
 - Fix plotting with transposed 2D non-dimensional coordinates.
   (:issue:`3138`, :pull:`3441`)
   By `Deepak Cherian `_.
-
 Documentation
 ~~~~~~~~~~~~~
 - Switch doc examples to use nbsphinx and replace sphinx_gallery with
@@ -58,8 +57,10 @@ Documentation
 
 Internal Changes
 ~~~~~~~~~~~~~~~~
-
-
+- 2x to 5x speed boost (on small arrays) for :py:meth:`Dataset.isel`,
+  :py:meth:`DataArray.isel`, and :py:meth:`DataArray.__getitem__` when indexing by int,
+  slice, list of int, scalar ndarray, or 1-dimensional ndarray.
+  (:pull:`3533`) by `Guido Imperiale `_.
 - Removed internal method ``Dataset._from_vars_and_coord_names``,
   which was dominated by ``Dataset._construct_direct``. (:pull:`3565`)
   By `Maximilian Roos `_
@@ -190,6 +191,7 @@ Documentation
 
 Internal Changes
 ~~~~~~~~~~~~~~~~
+
 - Added integration tests against `pint `_.
   (:pull:`3238`, :pull:`3447`, :pull:`3493`, :pull:`3508`)
   by `Justus Magin `_.
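The whats-new entry above names the indexer forms that take the new fast path. A minimal usage sketch of the distinction follows; the array, dimension names, and values are illustrative only and are not taken from the patch:

    import numpy as np
    import xarray as xr

    da = xr.DataArray(
        np.arange(12).reshape(3, 4),
        dims=("x", "y"),
        coords={"x": [10, 20, 30], "y": list("abcd")},
    )

    # Fast path: int, slice, list of int, scalar ndarray, 1-dimensional ndarray
    da.isel(x=1)
    da.isel(x=slice(0, 2))
    da.isel(y=[0, 2])
    da.isel(y=np.array(1))
    da.isel(y=np.array([0, 2]))
    da[0, :2]  # DataArray.__getitem__ benefits in the same way

    # Indexing by a DataArray (vectorized indexing) is a "fancy" indexer and
    # still goes through the general machinery.
    points = xr.DataArray([0, 2], dims="points")
    da.isel(y=points)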
diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py
index 8471ed1a558..eeb68508527 100644
--- a/xarray/coding/cftime_offsets.py
+++ b/xarray/coding/cftime_offsets.py
@@ -42,6 +42,7 @@
 
 import re
 from datetime import timedelta
+from distutils.version import LooseVersion
 from functools import partial
 from typing import ClassVar, Optional
 
@@ -50,7 +51,6 @@
 from ..core.pdcompat import count_not_none
 from .cftimeindex import CFTimeIndex, _parse_iso8601_with_reso
 from .times import format_cftime_datetime
-from distutils.version import LooseVersion
 
 
 def get_date_type(calendar):
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 64f21b0eb01..20de0cffbc2 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -50,7 +50,8 @@
 )
 from .dataset import Dataset, split_indexes
 from .formatting import format_item
-from .indexes import Indexes, propagate_indexes, default_indexes
+from .indexes import Indexes, default_indexes, propagate_indexes
+from .indexing import is_fancy_indexer
 from .merge import PANDAS_TYPES, _extract_indexes_from_coords
 from .options import OPTIONS
 from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs
@@ -1027,8 +1028,27 @@ def isel(
         DataArray.sel
         """
         indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
-        ds = self._to_temp_dataset().isel(drop=drop, indexers=indexers)
-        return self._from_temp_dataset(ds)
+        if any(is_fancy_indexer(idx) for idx in indexers.values()):
+            ds = self._to_temp_dataset()._isel_fancy(indexers, drop=drop)
+            return self._from_temp_dataset(ds)
+
+        # Much faster algorithm for when all indexers are ints, slices, one-dimensional
+        # lists, or zero or one-dimensional np.ndarray's
+
+        variable = self._variable.isel(indexers)
+
+        coords = {}
+        for coord_name, coord_value in self._coords.items():
+            coord_indexers = {
+                k: v for k, v in indexers.items() if k in coord_value.dims
+            }
+            if coord_indexers:
+                coord_value = coord_value.isel(coord_indexers)
+                if drop and coord_value.ndim == 0:
+                    continue
+            coords[coord_name] = coord_value
+
+        return self._replace(variable=variable, coords=coords)
 
     def sel(
         self,
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 61dde6a393b..5926fd4ff36 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -66,6 +66,7 @@
     propagate_indexes,
     roll_index,
 )
+from .indexing import is_fancy_indexer
 from .merge import (
     dataset_merge_method,
     dataset_update_method,
@@ -78,8 +79,8 @@
     Default,
     Frozen,
     SortedKeysDict,
-    _default,
     _check_inplace,
+    _default,
     decode_numpy_dict_values,
     either_dict_or_kwargs,
     hashable,
@@ -1907,6 +1908,48 @@ def isel(
         DataArray.isel
         """
         indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
+        if any(is_fancy_indexer(idx) for idx in indexers.values()):
+            return self._isel_fancy(indexers, drop=drop)
+
+        # Much faster algorithm for when all indexers are ints, slices, one-dimensional
+        # lists, or zero or one-dimensional np.ndarray's
+        invalid = indexers.keys() - self.dims.keys()
+        if invalid:
+            raise ValueError("dimensions %r do not exist" % invalid)
+
+        variables = {}
+        dims: Dict[Hashable, Tuple[int, ...]] = {}
+        coord_names = self._coord_names.copy()
+        indexes = self._indexes.copy() if self._indexes is not None else None
+
+        for var_name, var_value in self._variables.items():
+            var_indexers = {k: v for k, v in indexers.items() if k in var_value.dims}
+            if var_indexers:
+                var_value = var_value.isel(var_indexers)
+                if drop and var_value.ndim == 0 and var_name in coord_names:
+                    coord_names.remove(var_name)
+                    if indexes:
+                        indexes.pop(var_name, None)
+                    continue
+                if indexes and var_name in indexes:
+                    if var_value.ndim == 1:
+                        indexes[var_name] = var_value.to_index()
+                    else:
+                        del indexes[var_name]
+            variables[var_name] = var_value
+            dims.update(zip(var_value.dims, var_value.shape))
+
+        return self._construct_direct(
+            variables=variables,
+            coord_names=coord_names,
+            dims=dims,
+            attrs=self._attrs,
+            indexes=indexes,
+            encoding=self._encoding,
+            file_obj=self._file_obj,
+        )
+
+    def _isel_fancy(self, indexers: Mapping[Hashable, Any], *, drop: bool) -> "Dataset":
         # Note: we need to preserve the original indexers variable in order to merge the
         # coords below
         indexers_list = list(self._validate_indexers(indexers))
diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py
index dbebbcf4fbe..8ceda8bfbfa 100644
--- a/xarray/core/formatting_html.py
+++ b/xarray/core/formatting_html.py
@@ -1,11 +1,11 @@
 import uuid
-import pkg_resources
 from collections import OrderedDict
 from functools import partial
 from html import escape
 
-from .formatting import inline_variable_array_repr, short_data_repr
+import pkg_resources
 
+from .formatting import inline_variable_array_repr, short_data_repr
 
 CSS_FILE_PATH = "/".join(("static", "css", "style.css"))
 CSS_STYLE = pkg_resources.resource_string("xarray", CSS_FILE_PATH).decode("utf8")
diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py
index f48c9e72af1..8e851b39c3e 100644
--- a/xarray/core/indexing.py
+++ b/xarray/core/indexing.py
@@ -1213,6 +1213,19 @@ def posify_mask_indexer(indexer):
     return type(indexer)(key)
 
 
+def is_fancy_indexer(indexer: Any) -> bool:
+    """Return False if indexer is an int, slice, a 1-dimensional list, or a 0 or
+    1-dimensional ndarray; in all other cases return True
+    """
+    if isinstance(indexer, (int, slice)):
+        return False
+    if isinstance(indexer, np.ndarray):
+        return indexer.ndim > 1
+    if isinstance(indexer, list):
+        return bool(indexer) and not isinstance(indexer[0], int)
+    return True
+
+
 class NumpyIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
     """Wrap a NumPy array to use explicit indexing."""
 
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index 773dcef0aa1..aa04cffb5ea 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -617,7 +617,10 @@ def _broadcast_indexes_outer(self, key):
                 k = k.data
             if not isinstance(k, BASIC_INDEXING_TYPES):
                 k = np.asarray(k)
-                if k.dtype.kind == "b":
+                if k.size == 0:
+                    # Slice by empty list; numpy could not infer the dtype
+                    k = k.astype(int)
+                elif k.dtype.kind == "b":
                     (k,) = np.nonzero(k)
             new_key.append(k)
 
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index 4c1f317342f..f3b10e3370c 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -16,6 +16,7 @@
 from xarray.testing import assert_chunks_equal
 from xarray.tests import mock
 
+from ..core.duck_array_ops import lazy_array_equiv
 from . import (
     assert_allclose,
     assert_array_equal,
@@ -25,7 +26,6 @@
     raises_regex,
     requires_scipy_or_netCDF4,
 )
-from ..core.duck_array_ops import lazy_array_equiv
 from .test_backends import create_tmp_file
 
 dask = pytest.importorskip("dask")
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index a1e34abd0d5..f957316d8ac 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -16,7 +16,6 @@
 from xarray.core.common import full_like
 from xarray.core.indexes import propagate_indexes
 from xarray.core.utils import is_scalar
-
 from xarray.tests import (
     LooseVersion,
     ReturnItem,
diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py
index 0b410383a34..1cd0319a9a5 100644
--- a/xarray/tests/test_missing.py
+++ b/xarray/tests/test_missing.py
@@ -9,8 +9,8 @@
     NumpyInterpolator,
     ScipyInterpolator,
     SplineInterpolator,
-    get_clean_interp_index,
     _get_nan_block_lengths,
+    get_clean_interp_index,
 )
 from xarray.core.pycompat import dask_array_type
 from xarray.tests import (
diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py
index 5b5aa1a523f..1d83e16a5bd 100644
--- a/xarray/tests/test_variable.py
+++ b/xarray/tests/test_variable.py
@@ -1156,6 +1156,26 @@ def test_items(self):
     def test_getitem_basic(self):
         v = self.cls(["x", "y"], [[0, 1, 2], [3, 4, 5]])
 
+        # int argument
+        v_new = v[0]
+        assert v_new.dims == ("y",)
+        assert_array_equal(v_new, v._data[0])
+
+        # slice argument
+        v_new = v[:2]
+        assert v_new.dims == ("x", "y")
+        assert_array_equal(v_new, v._data[:2])
+
+        # list arguments
+        v_new = v[[0]]
+        assert v_new.dims == ("x", "y")
+        assert_array_equal(v_new, v._data[[0]])
+
+        v_new = v[[]]
+        assert v_new.dims == ("x", "y")
+        assert_array_equal(v_new, v._data[[]])
+
+        # dict arguments
         v_new = v[dict(x=0)]
         assert v_new.dims == ("y",)
         assert_array_equal(v_new, v._data[0])
@@ -1196,6 +1216,8 @@ def test_isel(self):
         assert_identical(v.isel(time=0), v[0])
         assert_identical(v.isel(time=slice(0, 3)), v[:3])
         assert_identical(v.isel(x=0), v[:, 0])
+        assert_identical(v.isel(x=[0, 2]), v[:, [0, 2]])
+        assert_identical(v.isel(time=[]), v[[]])
         with raises_regex(ValueError, "do not exist"):
             v.isel(not_a_dim=0)
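The new tests above pin down how list indexers behave after this change, including the empty-list case handled by the variable.py hunk. A small runnable sketch of the same behaviour, with data chosen here only for illustration:

    import xarray as xr

    v = xr.Variable(["x", "y"], [[0, 1, 2], [3, 4, 5]])

    # A list of ints keeps the indexed dimension.
    assert v[[0]].dims == ("x", "y")
    assert v.isel(x=[0]).shape == (1, 3)

    # An empty list keeps the dimension with length 0; the k.astype(int)
    # branch added in variable.py lets NumPy accept the empty indexer.
    assert v[[]].shape == (0, 3)
    assert v.isel(x=[]).shape == (0, 3)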