diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 49360c03b94..75eb0607936 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -72,6 +72,41 @@ def _expand_slice(slice_, size): return np.arange(*slice_.indices(size)) +def maybe_convert_to_slice(indexer, size): + """Convert an indexer into an equivalent slice object, if possible. + + Arguments + --------- + indexer : int, slice or np.ndarray + If a numpy array, must have integer dtype. + size : integer + Integer size of the dimension to be indexed. + """ + if indexer.ndim != 1 or not isinstance(indexer, np.ndarray): + return indexer + + if indexer.size == 0: + return slice(0, 0) + + if indexer.min() < -size or indexer.max() >= size: + raise IndexError( + 'indexer has elements out of bounds for axis of size {}: {}' + .format(size, indexer)) + + indexer = np.where(indexer < 0, indexer + size, indexer) + if indexer.size == 1: + i = int(indexer[0]) + return slice(i, i + 1) + + start = int(indexer[0]) + step = int(indexer[1] - start) + stop = start + step * indexer.size + guess = slice(start, stop, step) + if np.array_equal(_expand_slice(guess, size), indexer): + return guess + return indexer + + def orthogonal_indexer(key, shape): """Given a key for orthogonal array indexing, returns an equivalent key suitable for indexing a numpy.ndarray with fancy indexing. @@ -473,24 +508,23 @@ def __init__(self, array): self.array = array def __getitem__(self, key): - """ key: tuple of Variable, slice, integer """ - # basic or orthogonal indexing - if all(isinstance(k, (integer_types, slice)) or k.squeeze().ndim <= 1 - for k in key): - value = self.array - for axis, subkey in reversed(list(enumerate(key))): - if hasattr(subkey, 'squeeze'): - subkey = subkey.squeeze() - if subkey.ndim == 0: # make at least 1-d array - subkey = subkey.flatten() - value = value[(slice(None),) * axis + (subkey,)] - return value + """ key: tuple of ndarray, slice, integer """ + if all(isinstance(k, integer_types + (slice,)) for k in key): + # basic indexing + return self.array[key] + elif all(k.shape == (1,) * (i - 1) + (max(k.shape),) + (1,) * (i - 1) + for i, k in enumerate(key) + if isinstance(k, np.ndarray)): + # orthogonal indexing + # dask only supports one list in an indexer, so convert to slice if + # possible + key = tuple(maybe_convert_to_slice(np.ravel(k), size) + for k, size in zip(key, self.shape)) + return self.array[key] + # TODO: handle point-wise indexing with vindex else: - # TODO Dask does not support nd-array indexing. - # flatten() -> .vindex[] -> reshape() should be used - # instead of `.load()` - value = np.asarray(self.array)[key] - return value + raise IndexError( + 'dask does not support fancy indexing with key: {}'.format(key)) class PandasIndexAdapter(utils.NDArrayMixin): diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 3866e0511a5..6d478a68b5f 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -29,6 +29,30 @@ def test_expanded_indexer(self): with self.assertRaisesRegexp(IndexError, 'too many indices'): indexing.expanded_indexer(I[1, 2, 3], 2) + def test_maybe_convert_to_slice(self): + + cases = [ + (1,), + (1, 1), + (1, 2), + (10,), + (0, 10), + (5, 10), + (5, 8), + (None, 5), + (None, -3), + (0, 10, 2), + (10, None, -1), + (7, 3, -2), + ] + for case in cases: + slice_obj = slice(*case) + base_array = np.arange(*slice_obj.indices(10)) + for array in [base_array, base_array - 10]: + actual = indexing.maybe_convert_to_slice(array, 10) + self.assertArrayEqual(np.arange(10)[actual], + np.arange(10)[slice_obj]) + def test_orthogonal_indexer(self): x = np.random.randn(10, 11, 12, 13, 14) y = np.arange(5)