Skip to content
This repository has been archived by the owner on Oct 7, 2024. It is now read-only.

Commit

Permalink
2x~5x speed up for isel() in most cases (pydata#3533)
Browse files Browse the repository at this point in the history
* Speed up isel in most cases

* What's New

* Trivial

* Use _replace

* isort

* Code review

* What's New

* mypy
  • Loading branch information
crusaderky authored Dec 5, 2019
1 parent cf17317 commit 87a25b6
Show file tree
Hide file tree
Showing 11 changed files with 116 additions and 14 deletions.
8 changes: 5 additions & 3 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ Bug fixes
- Fix plotting with transposed 2D non-dimensional coordinates. (:issue:`3138`, :pull:`3441`)
By `Deepak Cherian <https://github.com/dcherian>`_.


Documentation
~~~~~~~~~~~~~
- Switch doc examples to use nbsphinx and replace sphinx_gallery with
Expand All @@ -58,8 +57,10 @@ Documentation

Internal Changes
~~~~~~~~~~~~~~~~


- 2x to 5x speed boost (on small arrays) for :py:meth:`Dataset.isel`,
:py:meth:`DataArray.isel`, and :py:meth:`DataArray.__getitem__` when indexing by int,
slice, list of int, scalar ndarray, or 1-dimensional ndarray.
(:pull:`3533`) by `Guido Imperiale <https://github.com/crusaderky>`_.
- Removed internal method ``Dataset._from_vars_and_coord_names``,
which was dominated by ``Dataset._construct_direct``. (:pull:`3565`)
By `Maximilian Roos <https://github.com/max-sixty>`_
Expand Down Expand Up @@ -190,6 +191,7 @@ Documentation

Internal Changes
~~~~~~~~~~~~~~~~

- Added integration tests against `pint <https://pint.readthedocs.io/>`_.
(:pull:`3238`, :pull:`3447`, :pull:`3493`, :pull:`3508`)
by `Justus Magin <https://github.com/keewis>`_.
Expand Down
2 changes: 1 addition & 1 deletion xarray/coding/cftime_offsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

import re
from datetime import timedelta
from distutils.version import LooseVersion
from functools import partial
from typing import ClassVar, Optional

Expand All @@ -50,7 +51,6 @@
from ..core.pdcompat import count_not_none
from .cftimeindex import CFTimeIndex, _parse_iso8601_with_reso
from .times import format_cftime_datetime
from distutils.version import LooseVersion


def get_date_type(calendar):
Expand Down
26 changes: 23 additions & 3 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@
)
from .dataset import Dataset, split_indexes
from .formatting import format_item
from .indexes import Indexes, propagate_indexes, default_indexes
from .indexes import Indexes, default_indexes, propagate_indexes
from .indexing import is_fancy_indexer
from .merge import PANDAS_TYPES, _extract_indexes_from_coords
from .options import OPTIONS
from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs
Expand Down Expand Up @@ -1027,8 +1028,27 @@ def isel(
DataArray.sel
"""
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
ds = self._to_temp_dataset().isel(drop=drop, indexers=indexers)
return self._from_temp_dataset(ds)
if any(is_fancy_indexer(idx) for idx in indexers.values()):
ds = self._to_temp_dataset()._isel_fancy(indexers, drop=drop)
return self._from_temp_dataset(ds)

# Much faster algorithm for when all indexers are ints, slices, one-dimensional
# lists, or zero or one-dimensional np.ndarray's

variable = self._variable.isel(indexers)

coords = {}
for coord_name, coord_value in self._coords.items():
coord_indexers = {
k: v for k, v in indexers.items() if k in coord_value.dims
}
if coord_indexers:
coord_value = coord_value.isel(coord_indexers)
if drop and coord_value.ndim == 0:
continue
coords[coord_name] = coord_value

return self._replace(variable=variable, coords=coords)

def sel(
self,
Expand Down
45 changes: 44 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
propagate_indexes,
roll_index,
)
from .indexing import is_fancy_indexer
from .merge import (
dataset_merge_method,
dataset_update_method,
Expand All @@ -78,8 +79,8 @@
Default,
Frozen,
SortedKeysDict,
_default,
_check_inplace,
_default,
decode_numpy_dict_values,
either_dict_or_kwargs,
hashable,
Expand Down Expand Up @@ -1907,6 +1908,48 @@ def isel(
DataArray.isel
"""
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
if any(is_fancy_indexer(idx) for idx in indexers.values()):
return self._isel_fancy(indexers, drop=drop)

# Much faster algorithm for when all indexers are ints, slices, one-dimensional
# lists, or zero or one-dimensional np.ndarray's
invalid = indexers.keys() - self.dims.keys()
if invalid:
raise ValueError("dimensions %r do not exist" % invalid)

variables = {}
dims: Dict[Hashable, Tuple[int, ...]] = {}
coord_names = self._coord_names.copy()
indexes = self._indexes.copy() if self._indexes is not None else None

for var_name, var_value in self._variables.items():
var_indexers = {k: v for k, v in indexers.items() if k in var_value.dims}
if var_indexers:
var_value = var_value.isel(var_indexers)
if drop and var_value.ndim == 0 and var_name in coord_names:
coord_names.remove(var_name)
if indexes:
indexes.pop(var_name, None)
continue
if indexes and var_name in indexes:
if var_value.ndim == 1:
indexes[var_name] = var_value.to_index()
else:
del indexes[var_name]
variables[var_name] = var_value
dims.update(zip(var_value.dims, var_value.shape))

return self._construct_direct(
variables=variables,
coord_names=coord_names,
dims=dims,
attrs=self._attrs,
indexes=indexes,
encoding=self._encoding,
file_obj=self._file_obj,
)

def _isel_fancy(self, indexers: Mapping[Hashable, Any], *, drop: bool) -> "Dataset":
# Note: we need to preserve the original indexers variable in order to merge the
# coords below
indexers_list = list(self._validate_indexers(indexers))
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/formatting_html.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import uuid
import pkg_resources
from collections import OrderedDict
from functools import partial
from html import escape

from .formatting import inline_variable_array_repr, short_data_repr
import pkg_resources

from .formatting import inline_variable_array_repr, short_data_repr

CSS_FILE_PATH = "/".join(("static", "css", "style.css"))
CSS_STYLE = pkg_resources.resource_string("xarray", CSS_FILE_PATH).decode("utf8")
Expand Down
13 changes: 13 additions & 0 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,19 @@ def posify_mask_indexer(indexer):
return type(indexer)(key)


def is_fancy_indexer(indexer: Any) -> bool:
    """Return False if indexer is an int, a slice, a 1-dimensional list of ints,
    or a zero- or one-dimensional ndarray; in all other cases return True.

    Indexers for which this returns False are safe for the fast (non-fancy)
    indexing path; everything else must go through outer/vectorized indexing.
    """
    if isinstance(indexer, (int, slice)):
        return False
    if isinstance(indexer, np.ndarray):
        # 0-d and 1-d arrays take the fast path regardless of dtype
        return indexer.ndim > 1
    if isinstance(indexer, list):
        # Check every element, not just the first: a mixed/nested list such as
        # [0, [1]] is not a flat list of ints and must be treated as fancy.
        # Note: bool is a subclass of int, so boolean lists stay non-fancy.
        return bool(indexer) and not all(isinstance(k, int) for k in indexer)
    return True


class NumpyIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap a NumPy array to use explicit indexing."""

Expand Down
5 changes: 4 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,10 @@ def _broadcast_indexes_outer(self, key):
k = k.data
if not isinstance(k, BASIC_INDEXING_TYPES):
k = np.asarray(k)
if k.dtype.kind == "b":
if k.size == 0:
# Slice by empty list; numpy could not infer the dtype
k = k.astype(int)
elif k.dtype.kind == "b":
(k,) = np.nonzero(k)
new_key.append(k)

Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from xarray.testing import assert_chunks_equal
from xarray.tests import mock

from ..core.duck_array_ops import lazy_array_equiv
from . import (
assert_allclose,
assert_array_equal,
Expand All @@ -25,7 +26,6 @@
raises_regex,
requires_scipy_or_netCDF4,
)
from ..core.duck_array_ops import lazy_array_equiv
from .test_backends import create_tmp_file

dask = pytest.importorskip("dask")
Expand Down
1 change: 0 additions & 1 deletion xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from xarray.core.common import full_like
from xarray.core.indexes import propagate_indexes
from xarray.core.utils import is_scalar

from xarray.tests import (
LooseVersion,
ReturnItem,
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
NumpyInterpolator,
ScipyInterpolator,
SplineInterpolator,
get_clean_interp_index,
_get_nan_block_lengths,
get_clean_interp_index,
)
from xarray.core.pycompat import dask_array_type
from xarray.tests import (
Expand Down
22 changes: 22 additions & 0 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,26 @@ def test_items(self):
def test_getitem_basic(self):
v = self.cls(["x", "y"], [[0, 1, 2], [3, 4, 5]])

# int argument
v_new = v[0]
assert v_new.dims == ("y",)
assert_array_equal(v_new, v._data[0])

# slice argument
v_new = v[:2]
assert v_new.dims == ("x", "y")
assert_array_equal(v_new, v._data[:2])

# list arguments
v_new = v[[0]]
assert v_new.dims == ("x", "y")
assert_array_equal(v_new, v._data[[0]])

v_new = v[[]]
assert v_new.dims == ("x", "y")
assert_array_equal(v_new, v._data[[]])

# dict arguments
v_new = v[dict(x=0)]
assert v_new.dims == ("y",)
assert_array_equal(v_new, v._data[0])
Expand Down Expand Up @@ -1196,6 +1216,8 @@ def test_isel(self):
assert_identical(v.isel(time=0), v[0])
assert_identical(v.isel(time=slice(0, 3)), v[:3])
assert_identical(v.isel(x=0), v[:, 0])
assert_identical(v.isel(x=[0, 2]), v[:, [0, 2]])
assert_identical(v.isel(time=[]), v[[]])
with raises_regex(ValueError, "do not exist"):
v.isel(not_a_dim=0)

Expand Down

0 comments on commit 87a25b6

Please sign in to comment.