Skip to content

Commit

Permalink
orthogonal indexing for dask.
Browse files Browse the repository at this point in the history
  • Loading branch information
fujiisoup committed Jul 16, 2017
1 parent 17b6465 commit 33c51d3
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 63 deletions.
44 changes: 19 additions & 25 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def shape(self):
return tuple(shape)

def __array__(self, dtype=None):
    """Realize the lazily indexed array as a NumPy ndarray.

    Wraps the underlying array so broadcasted indexing applies the
    stored key, then coerces the result with ``np.asarray``.
    """
    array = broadcasted_indexable(self.array)
    # Forward the caller's dtype request; previously a hard-coded
    # ``dtype=None`` was passed, silently ignoring the argument.
    return np.asarray(array[self.key], dtype=dtype)

def __getitem__(self, key):
Expand Down Expand Up @@ -434,7 +434,7 @@ def __setitem__(self, key, value):
self.array[key] = value


def orthogonally_indexable(array):
def broadcasted_indexable(array):
if isinstance(array, np.ndarray):
return NumpyIndexingAdapter(array)
if isinstance(array, pd.Index):
Expand All @@ -445,24 +445,10 @@ def orthogonally_indexable(array):


class NumpyIndexingAdapter(utils.NDArrayMixin):
"""Wrap a NumPy array to use orthogonal indexing (array indexing
accesses different dimensions independently, like netCDF4-python variables)
"""Wrap a NumPy array to use broadcasted indexing
"""
# note: this object is somewhat similar to biggus.NumpyArrayAdapter in that
# it implements orthogonal indexing, except it casts to a numpy array,
# isn't lazy and supports writing values.
def __init__(self, array):
self.array = np.asarray(array)

def __array__(self, dtype=None):
return np.asarray(self.array, dtype=dtype)

def _convert_key(self, key):
key = expanded_indexer(key, self.ndim)
if any(not isinstance(k, integer_types + (slice,)) for k in key):
# key would trigger fancy indexing
key = orthogonal_indexer(key, self.shape)
return key
self.array = array

def _ensure_ndarray(self, value):
# We always want the result of indexing to be a NumPy array. If it's
Expand All @@ -474,29 +460,37 @@ def _ensure_ndarray(self, value):
return value

def __getitem__(self, key):
    """Index the wrapped NumPy array with *key*.

    NumPy natively supports broadcasted (fancy) indexing, so the key is
    passed straight through; ``_convert_key`` (removed in this change)
    is no longer called. The result is coerced to an ndarray because a
    scalar index would otherwise return a bare numpy scalar.
    """
    return self._ensure_ndarray(self.array[key])

def __setitem__(self, key, value):
    """Assign *value* at *key* in the wrapped NumPy array.

    Relies on NumPy's native broadcasted indexing; the stale
    ``_convert_key`` call (the helper is deleted in this change) is
    dropped.
    """
    self.array[key] = value


class DaskIndexingAdapter(utils.NDArrayMixin):
    """Wrap a dask array to support broadcasted indexing."""

    def __init__(self, array):
        # `array` is expected to be a dask array; stored as-is (lazy).
        self.array = array

    def __getitem__(self, key):
        """Index the wrapped dask array.

        Parameters
        ----------
        key : tuple of Variable, slice or integer entries
        """
        # Basic or orthogonal indexing: every entry is an integer, a
        # slice, or an array-like that is at most 1-d after squeezing.
        # Use ``integer_types + (slice,)`` for consistency with the
        # rest of this module (rather than a nested tuple).
        if all(isinstance(k, integer_types + (slice,)) or
               k.squeeze().ndim <= 1 for k in key):
            value = self.array
            # Apply one indexer per axis, last axis first, so earlier
            # indexing cannot shift the axis numbers still to be used.
            for axis, subkey in reversed(list(enumerate(key))):
                if hasattr(subkey, 'squeeze'):
                    subkey = subkey.squeeze()
                    if subkey.ndim == 0:  # make at least 1-d array
                        subkey = subkey.flatten()
                value = value[(slice(None),) * axis + (subkey,)]
            return value
        else:
            # TODO Dask does not support nd-array indexing.
            # flatten() -> .vindex[] -> reshape() should be used
            # instead of `.load()`
            value = np.asarray(self.array)[key]
            return value


class PandasIndexAdapter(utils.NDArrayMixin):
Expand Down
38 changes: 9 additions & 29 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from . import utils
from .pycompat import (basestring, OrderedDict, zip, integer_types,
dask_array_type)
from .indexing import (PandasIndexAdapter, orthogonally_indexable)
from .indexing import (DaskIndexingAdapter, PandasIndexAdapter,
broadcasted_indexable)

import xarray as xr # only for Dataset and DataArray

Expand Down Expand Up @@ -297,7 +298,7 @@ def data(self, data):

@property
def _indexable_data(self):
return orthogonally_indexable(self._data)
return broadcasted_indexable(self._data)

def load(self):
"""Manually trigger loading of this variable's data from disk or a
Expand Down Expand Up @@ -417,7 +418,7 @@ def nonzero(self):
return tuple([as_variable(nz, name=dim) for nz, dim
in zip(nonzeros, self.dims)])

def _isbool(self):
def _isbool_type(self):
        """ Return whether the variable is of boolean dtype or not """
if isinstance(self._data, (np.ndarray, PandasIndexAdapter, pd.Index)):
return self._data.dtype is np.dtype('bool')
Expand All @@ -439,7 +440,7 @@ def _broadcast_indexes_advanced(self, key):
"cannot be used for indexing.")
else:
raise e
if variable._isbool(): # boolean indexing case
if variable._isbool_type(): # boolean indexing case
variables.extend(list(variable.nonzero()))
else:
variables.append(variable)
Expand All @@ -448,21 +449,6 @@ def _broadcast_indexes_advanced(self, key):
key = tuple(variable.data for variable in variables)
return dims, key

def _ensure_array(self, value):
""" For np.ndarray-based-Variable, we always want the result of
indexing to be a NumPy array. If it's not, then it really should be a
0d array. Doing the coercion here instead of inside
variable.as_compatible_data makes it less error prone."""
if isinstance(self._data, np.ndarray):
if not isinstance(value, np.ndarray):
value = utils.to_0d_array(value)
elif isinstance(self._data, dask_array_type):
print(value)
if not isinstance(value, (dask_array_type, dask_array_type)):
value = utils.to_0d_array(value)

return value

def __getitem__(self, key):
"""Return a new Array object whose contents are consistent with
getting the provided key from the underlying data.
Expand All @@ -473,13 +459,7 @@ def __getitem__(self, key):
This method will replace __getitem__ after we make sure its stability.
"""
dims, index_tuple = self._broadcast_indexes(key)
try:
values = self._ensure_array(self._data[index_tuple])
except NotImplementedError:
# TODO temporal implementation.
# Need to wait for dask's nd index support?
values = self._ensure_array(self.load()._data[index_tuple])

values = self._indexable_data[index_tuple]
if hasattr(values, 'ndim'):
assert values.ndim == len(dims), (values.ndim, len(dims))
else:
Expand All @@ -493,15 +473,15 @@ def __setitem__(self, key, value):
See __getitem__ for more details.
"""
key = self._item_key_to_tuple(key)
dims, index_tuple = self._broadcast_indexes(key)
if isinstance(self._data, dask_array_type):
raise TypeError("this variable's data is stored in a dask array, "
'which does not support item assignment. To '
'assign to this variable, you must first load it '
'into memory explicitly using the .load_data() '
'method or accessing its .values attribute.')
data = orthogonally_indexable(self._data)
data[key] = value
data = broadcasted_indexable(self._data)
data[index_tuple] = value

@property
def attrs(self):
Expand Down
19 changes: 10 additions & 9 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,46 +750,47 @@ def test_getitem_advanced(self):
ind = Variable(['a', 'b'], [[0, 1, 1], [1, 1, 0]])
v_new = v[ind]
self.assertTrue(v_new.dims == ('a', 'b', 'y'))
self.assertArrayEqual(v_new, v._data[([0, 1, 1], [1, 1, 0]), :])
self.assertArrayEqual(v_new, v.load()._data[([0, 1, 1], [1, 1, 0]), :])

ind = Variable(['a', 'b'], [[0, 1, 2], [2, 1, 0]])
v_new = v[dict(y=ind)]
self.assertTrue(v_new.dims == ('x', 'a', 'b'))
self.assertArrayEqual(v_new, v._data[:, ([0, 1, 2], [2, 1, 0])])
self.assertArrayEqual(v_new, v.load()._data[:, ([0, 1, 2], [2, 1, 0])])

# with mixed arguments
ind = Variable(['a'], [0, 1])
v_new = v[dict(x=[0, 1], y=ind)]
self.assertTrue(v_new.dims == ('x', 'a'))
self.assertArrayEqual(v_new, v._data[[0, 1]][:, [0, 1]])
self.assertArrayEqual(v_new, v.load()._data[[0, 1]][:, [0, 1]])

ind = Variable(['a', 'b'], [[0, 0], [1, 1]])
v_new = v[dict(x=[1, 0], y=ind)]
self.assertTrue(v_new.dims == ('x', 'a', 'b'))
self.assertArrayEqual(v_new, v._data[[1, 0]][:, ind])
self.assertArrayEqual(v_new, v.load()._data[[1, 0]][:, ind])

# with integer
ind = Variable(['a', 'b'], [[0, 0], [1, 1]])
v_new = v[dict(x=0, y=ind)]
self.assertTrue(v_new.dims == ('a', 'b'))
self.assertArrayEqual(v_new[0], v._data[0][[0, 0]])
self.assertArrayEqual(v_new[1], v._data[0][[1, 1]])
self.assertArrayEqual(v_new[0], v.load()._data[0][[0, 0]])
self.assertArrayEqual(v_new[1], v.load()._data[0][[1, 1]])

# with slice
ind = Variable(['a', 'b'], [[0, 0], [1, 1]])
v_new = v[dict(x=slice(None), y=ind)]
self.assertTrue(v_new.dims == ('x', 'a', 'b'))
self.assertArrayEqual(v_new, v._data[:, [[0, 0], [1, 1]]])
self.assertArrayEqual(v_new, v.load()._data[:, [[0, 0], [1, 1]]])

ind = Variable(['a', 'b'], [[0, 0], [1, 1]])
v_new = v[dict(x=ind, y=slice(None))]
self.assertTrue(v_new.dims == ('a', 'b', 'y'))
self.assertArrayEqual(v_new, v._data[[[0, 0], [1, 1]], :])
self.assertArrayEqual(v_new, v.load()._data[[[0, 0], [1, 1]], :])

ind = Variable(['a', 'b'], [[0, 0], [1, 1]])
v_new = v[dict(x=ind, y=slice(None, 1))]
self.assertTrue(v_new.dims == ('a', 'b', 'y'))
self.assertArrayEqual(v_new, v._data[[[0, 0], [1, 1]], slice(None, 1)])
self.assertArrayEqual(v_new,
v.load()._data[[[0, 0], [1, 1]], slice(None, 1)])

def test_getitem_error(self):
v = self.cls(['x', 'y'], [[0, 1, 2], [3, 4, 5]])
Expand Down

0 comments on commit 33c51d3

Please sign in to comment.