Skip to content

Commit

Permalink
Refactor DaskArrayAdapter
Browse files Browse the repository at this point in the history
  • Loading branch information
shoyer committed Jul 16, 2017
1 parent 33c51d3 commit 03a336f
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 17 deletions.
68 changes: 51 additions & 17 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,41 @@ def _expand_slice(slice_, size):
return np.arange(*slice_.indices(size))


def maybe_convert_to_slice(indexer, size):
    """Convert an indexer into an equivalent slice object, if possible.

    Arguments
    ---------
    indexer : int, slice or np.ndarray
        If a numpy array, must have integer dtype.
    size : integer
        Integer size of the dimension to be indexed.

    Returns
    -------
    int, slice or np.ndarray
        Anything that is not a 1-d numpy array is returned untouched. A 1-d
        array is replaced by a slice selecting the same positions in the same
        order when such a slice exists; otherwise the array is returned with
        negative entries wrapped to non-negative positions.

    Raises
    ------
    IndexError
        If any element of a 1-d array indexer lies outside ``[-size, size)``.
    """
    # Check the type first: ints and slices have no ``ndim`` attribute, so
    # the isinstance test has to short-circuit the attribute access.
    if not isinstance(indexer, np.ndarray) or indexer.ndim != 1:
        return indexer

    if indexer.size == 0:
        return slice(0, 0)

    if indexer.min() < -size or indexer.max() >= size:
        raise IndexError(
            'indexer has elements out of bounds for axis of size {}: {}'
            .format(size, indexer))

    # Wrap negative positions so the arithmetic below works on absolute ones.
    indexer = np.where(indexer < 0, indexer + size, indexer)
    if indexer.size == 1:
        i = int(indexer[0])
        return slice(i, i + 1)

    # Guess the slice from the first two elements, then verify it against the
    # whole array.
    start = int(indexer[0])
    step = int(indexer[1] - start)
    if step == 0:
        # Repeated positions cannot be written as a slice, and a zero step
        # would make ``slice.indices`` raise ValueError.
        return indexer

    stop = start + step * indexer.size
    if stop < 0:
        # Only reachable with a negative step running down to position 0.
        # A negative stop would be interpreted relative to the end of the
        # axis, so use an open-ended stop instead.
        stop = None
    guess = slice(start, stop, step)
    if np.array_equal(_expand_slice(guess, size), indexer):
        return guess
    return indexer


def orthogonal_indexer(key, shape):
"""Given a key for orthogonal array indexing, returns an equivalent key
suitable for indexing a numpy.ndarray with fancy indexing.
Expand Down Expand Up @@ -473,24 +508,23 @@ def __init__(self, array):
self.array = array

def __getitem__(self, key):
# NOTE(review): this span is a rendered diff hunk whose +/- markers and
# indentation were stripped, so two versions of this method body are
# interleaved below; it is not runnable as shown.
#
# -- appears to be the pre-commit body (removed side of the hunk) --
""" key: tuple of Variable, slice, integer """
# basic or orthogonal indexing
if all(isinstance(k, (integer_types, slice)) or k.squeeze().ndim <= 1
for k in key):
value = self.array
for axis, subkey in reversed(list(enumerate(key))):
if hasattr(subkey, 'squeeze'):
subkey = subkey.squeeze()
if subkey.ndim == 0: # make at least 1-d array
subkey = subkey.flatten()
value = value[(slice(None),) * axis + (subkey,)]
return value
# -- appears to be the post-commit body (added side of the hunk) --
""" key: tuple of ndarray, slice, integer """
if all(isinstance(k, integer_types + (slice,)) for k in key):
# basic indexing
return self.array[key]
# NOTE(review): both padding terms below are (1,) * (i - 1); for i == 0 the
# multiplier is negative and yields an empty tuple, and the trailing padding
# does not depend on len(key) -- confirm this is the intended shape pattern.
elif all(k.shape == (1,) * (i - 1) + (max(k.shape),) + (1,) * (i - 1)
for i, k in enumerate(key)
if isinstance(k, np.ndarray)):
# orthogonal indexing
# dask only supports one list in an indexer, so convert to slice if
# possible
key = tuple(maybe_convert_to_slice(np.ravel(k), size)
for k, size in zip(key, self.shape))
return self.array[key]
# TODO: handle point-wise indexing with vindex
else:
# -- appears to be the pre-commit fallback (removed side): converts the
# wrapped array with np.asarray before indexing it with the key --
# TODO Dask does not support nd-array indexing.
# flatten() -> .vindex[] -> reshape() should be used
# instead of `.load()`
value = np.asarray(self.array)[key]
return value
# -- appears to be the post-commit fallback (added side): rejects the key
# with an IndexError instead of converting the array --
raise IndexError(
'dask does not support fancy indexing with key: {}'.format(key))


class PandasIndexAdapter(utils.NDArrayMixin):
Expand Down
24 changes: 24 additions & 0 deletions xarray/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,30 @@ def test_expanded_indexer(self):
with self.assertRaisesRegexp(IndexError, 'too many indices'):
indexing.expanded_indexer(I[1, 2, 3], 2)

def test_maybe_convert_to_slice(self):
    """Converted indexers must select the same elements as the source slice."""
    size = 10
    slice_args = [
        (1,),
        (1, 1),
        (1, 2),
        (10,),
        (0, 10),
        (5, 10),
        (5, 8),
        (None, 5),
        (None, -3),
        (0, 10, 2),
        (10, None, -1),
        (7, 3, -2),
    ]
    for args in slice_args:
        expected_slice = slice(*args)
        positions = np.arange(*expected_slice.indices(size))
        # Exercise both the non-negative positions and their negative
        # (end-relative) equivalents.
        for candidate in (positions, positions - size):
            converted = indexing.maybe_convert_to_slice(candidate, size)
            self.assertArrayEqual(np.arange(size)[converted],
                                  np.arange(size)[expected_slice])

def test_orthogonal_indexer(self):
x = np.random.randn(10, 11, 12, 13, 14)
y = np.arange(5)
Expand Down

0 comments on commit 03a336f

Please sign in to comment.