Skip to content

Commit

Permalink
Support orthogonal indexing in MemoryCachedArray (Fix for #1429) (#1676)
Browse files Browse the repository at this point in the history
* Added as_indexable

* Make NioArrayWrapper indexable.

* Made ArrayWrapper raise NotImplementedError for VecotorizedIndexer.
Added test for DaskIndexingAdaptor

* Support pep8

* whats new

* Adopt assert_identical than self.assertDatasetIdentical

* Explicitly test vindex support of backend array wrappers.

* Change NDArrayIndexable to Mixin

* Add a validation in `as_compatible_data`

* Make sanity check clean. Fix misspellings.

* Revert an unintended change test_backends.py

* Rename SUPPORT_ARRAY_TYPES -> NON_NUMPY_SUPPORTED_ARRAY_TYPES
  • Loading branch information
fujiisoup authored and shoyer committed Nov 6, 2017
1 parent 6b16e6e commit 2a1d392
Show file tree
Hide file tree
Showing 13 changed files with 215 additions and 57 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ Bug fixes
of ``.T`` attributes (:issue:`1675`).
By `Keisuke Fujii <https://github.com/fujiisoup>`_

- (Internal bug) MemoryCachedArray now supports the orthogonal indexing.
Also made some internal cleanups around array wrappers (:issue:`1429`).

- Fix two bugs that were preventing dask arrays from being specified as
coordinates in the DataArray constructor (:issue:`1684`).
By `Joe Hamman <https://github.com/jhamman>`_
Expand Down
4 changes: 4 additions & 0 deletions xarray/backends/h5netcdf_.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@

class H5NetCDFArrayWrapper(BaseNetCDF4Array):
def __getitem__(self, key):
if isinstance(key, indexing.VectorizedIndexer):
raise NotImplementedError(
'Vectorized indexing for {} is not implemented. Load your '
'data first with .load() or .compute().'.format(type(self)))
key = indexing.to_tuple(key)
with self.datastore.ensure_open(autoclose=True):
return self.get_array()[key]
Expand Down
8 changes: 7 additions & 1 deletion xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
'|': 'native'}


class BaseNetCDF4Array(NdimSizeLenMixin, DunderArrayMixin):
class BaseNetCDF4Array(NdimSizeLenMixin, DunderArrayMixin,
indexing.NDArrayIndexable):
def __init__(self, variable_name, datastore):
self.datastore = datastore
self.variable_name = variable_name
Expand All @@ -50,6 +51,11 @@ def get_array(self):

class NetCDF4ArrayWrapper(BaseNetCDF4Array):
def __getitem__(self, key):
if isinstance(key, indexing.VectorizedIndexer):
raise NotImplementedError(
'Vectorized indexing for {} is not implemented. Load your '
'data first with .load() or .compute().'.format(type(self)))

key = indexing.to_tuple(key)

if self.datastore.is_remote: # pragma: no cover
Expand Down
6 changes: 5 additions & 1 deletion xarray/backends/pydap_.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .common import AbstractDataStore, robust_getitem


class PydapArrayWrapper(NDArrayMixin):
class PydapArrayWrapper(NDArrayMixin, indexing.NDArrayIndexable):
def __init__(self, array):
self.array = array

Expand All @@ -27,6 +27,10 @@ def dtype(self):
return np.dtype(t.typecode + str(t.size))

def __getitem__(self, key):
if isinstance(key, indexing.VectorizedIndexer):
raise NotImplementedError(
'Vectorized indexing for {} is not implemented. Load your '
'data first with .load() or .compute().'.format(type(self)))
key = indexing.to_tuple(key)
if not isinstance(key, tuple):
key = (key,)
Expand Down
7 changes: 6 additions & 1 deletion xarray/backends/pynio_.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from .common import AbstractDataStore, DataStorePickleMixin


class NioArrayWrapper(NdimSizeLenMixin, DunderArrayMixin):
class NioArrayWrapper(NdimSizeLenMixin, DunderArrayMixin,
indexing.NDArrayIndexable):

def __init__(self, variable_name, datastore):
self.datastore = datastore
Expand All @@ -28,6 +29,10 @@ def get_array(self):
return self.datastore.ds.variables[self.variable_name]

def __getitem__(self, key):
if isinstance(key, indexing.VectorizedIndexer):
raise NotImplementedError(
'Vectorized indexing for {} is not implemented. Load your '
'data first with .load() or .compute().'.format(type(self)))
key = indexing.to_tuple(key)
with self.datastore.ensure_open(autoclose=True):
array = self.get_array()
Expand Down
9 changes: 6 additions & 3 deletions xarray/backends/rasterio_.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
from collections import OrderedDict
from distutils.version import LooseVersion
import numpy as np

from .. import DataArray
Expand All @@ -18,7 +17,8 @@
'first.')


class RasterioArrayWrapper(NdimSizeLenMixin, DunderArrayMixin):
class RasterioArrayWrapper(NdimSizeLenMixin, DunderArrayMixin,
indexing.NDArrayIndexable):
"""A wrapper around rasterio dataset objects"""
def __init__(self, rasterio_ds):
self.rasterio_ds = rasterio_ds
Expand All @@ -38,8 +38,11 @@ def shape(self):
return self._shape

def __getitem__(self, key):
if isinstance(key, indexing.VectorizedIndexer):
raise NotImplementedError(
'Vectorized indexing for {} is not implemented. Load your '
'data first with .load() or .compute().'.format(type(self)))
key = indexing.to_tuple(key)

# bands cannot be windowed but they can be listed
band_key = key[0]
n_bands = self.shape[0]
Expand Down
4 changes: 2 additions & 2 deletions xarray/backends/scipy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ..core.pycompat import iteritems, OrderedDict, basestring
from ..core.utils import (Frozen, FrozenOrderedDict, NdimSizeLenMixin,
DunderArrayMixin)
from ..core.indexing import NumpyIndexingAdapter
from ..core.indexing import NumpyIndexingAdapter, NDArrayIndexable

from .common import WritableCFDataStore, DataStorePickleMixin
from .netcdf3 import (is_valid_nc3_name, encode_nc3_attr_value,
Expand All @@ -31,7 +31,7 @@ def _decode_attrs(d):
for (k, v) in iteritems(d))


class ScipyArrayWrapper(NdimSizeLenMixin, DunderArrayMixin):
class ScipyArrayWrapper(NdimSizeLenMixin, DunderArrayMixin, NDArrayIndexable):

def __init__(self, variable_name, datastore):
self.datastore = datastore
Expand Down
32 changes: 16 additions & 16 deletions xarray/conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def encode_cf_timedelta(timedeltas, units=None):
return (num, units)


class MaskedAndScaledArray(utils.NDArrayMixin):
class MaskedAndScaledArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
"""Wrapper around array-like objects to create a new indexable object where
values, when accessed, are automatically scaled and masked according to
CF conventions for packed and missing data values.
Expand Down Expand Up @@ -370,7 +370,7 @@ def __init__(self, array, fill_value=None, scale_factor=None,
After applying scale_factor, add this number to entries in the
original array.
"""
self.array = array
self.array = indexing.as_indexable(array)
self.fill_value = fill_value
self.scale_factor = scale_factor
self.add_offset = add_offset
Expand All @@ -391,13 +391,13 @@ def __repr__(self):
self.scale_factor, self.add_offset, self._dtype))


class DecodedCFDatetimeArray(utils.NDArrayMixin):
class DecodedCFDatetimeArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
"""Wrapper around array-like objects to create a new indexable object where
values, when accessed, are automatically converted into datetime objects
using decode_cf_datetime.
"""
def __init__(self, array, units, calendar=None):
self.array = array
self.array = indexing.as_indexable(array)
self.units = units
self.calendar = calendar

Expand Down Expand Up @@ -430,13 +430,13 @@ def __getitem__(self, key):
calendar=self.calendar)


class DecodedCFTimedeltaArray(utils.NDArrayMixin):
class DecodedCFTimedeltaArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
"""Wrapper around array-like objects to create a new indexable object where
values, when accessed, are automatically converted into timedelta objects
using decode_cf_timedelta.
"""
def __init__(self, array, units):
self.array = array
self.array = indexing.as_indexable(array)
self.units = units

@property
Expand All @@ -447,7 +447,7 @@ def __getitem__(self, key):
return decode_cf_timedelta(self.array[key], units=self.units)


class StackedBytesArray(utils.NDArrayMixin):
class StackedBytesArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
"""Wrapper around array-like objects to create a new indexable object where
values, when accessed, are automatically stacked along the last dimension.
Expand All @@ -465,7 +465,7 @@ def __init__(self, array):
if array.dtype != 'S1':
raise ValueError(
"can only use StackedBytesArray if argument has dtype='S1'")
self.array = array
self.array = indexing.as_indexable(array)

@property
def dtype(self):
Expand Down Expand Up @@ -493,7 +493,7 @@ def __getitem__(self, key):
return char_to_bytes(self.array[key])


class BytesToStringArray(utils.NDArrayMixin):
class BytesToStringArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
"""Wrapper that decodes bytes to unicode when values are read.
>>> BytesToStringArray(np.array([b'abc']))[:]
Expand All @@ -509,7 +509,7 @@ def __init__(self, array, encoding='utf-8'):
encoding : str
String encoding to use.
"""
self.array = array
self.array = indexing.as_indexable(array)
self.encoding = encoding

@property
Expand All @@ -532,7 +532,7 @@ def __getitem__(self, key):
return decode_bytes_array(self.array[key], self.encoding)


class NativeEndiannessArray(utils.NDArrayMixin):
class NativeEndiannessArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
"""Decode arrays on the fly from non-native to native endianness
This is useful for decoding arrays from netCDF3 files (which are all
Expand All @@ -551,7 +551,7 @@ class NativeEndiannessArray(utils.NDArrayMixin):
dtype('int16')
"""
def __init__(self, array):
self.array = array
self.array = indexing.as_indexable(array)

@property
def dtype(self):
Expand All @@ -561,7 +561,7 @@ def __getitem__(self, key):
return np.asarray(self.array[key], dtype=self.dtype)


class BoolTypeArray(utils.NDArrayMixin):
class BoolTypeArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
"""Decode arrays on the fly from integer to boolean datatype
This is useful for decoding boolean arrays from integer typed netCDF
Expand All @@ -579,7 +579,7 @@ class BoolTypeArray(utils.NDArrayMixin):
dtype('bool')
"""
def __init__(self, array):
self.array = array
self.array = indexing.as_indexable(array)

@property
def dtype(self):
Expand All @@ -589,7 +589,7 @@ def __getitem__(self, key):
return np.asarray(self.array[key], dtype=self.dtype)


class UnsignedIntTypeArray(utils.NDArrayMixin):
class UnsignedIntTypeArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
"""Decode arrays on the fly from signed integer to unsigned
integer. Typically used when _Unsigned is set at as a netCDF
attribute on a signed integer variable.
Expand All @@ -606,7 +606,7 @@ class UnsignedIntTypeArray(utils.NDArrayMixin):
array([ 0, 1, 127, 128, 255], dtype=uint8)
"""
def __init__(self, array):
self.array = array
self.array = indexing.as_indexable(array)
self.unsigned_dtype = np.dtype('u%s' % array.dtype.itemsize)

@property
Expand Down
39 changes: 28 additions & 11 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,11 @@ class VectorizedIndexer(IndexerTuple):
""" Tuple for vectorized indexing """


class LazilyIndexedArray(utils.NDArrayMixin):
class NDArrayIndexable(object):
""" Mixin to mark support for IndexerTuple subclasses in indexing."""


class LazilyIndexedArray(utils.NDArrayMixin, NDArrayIndexable):
"""Wrap an array that handles orthogonal indexing to make indexing lazy
"""
def __init__(self, array, key=None):
Expand Down Expand Up @@ -356,7 +360,7 @@ def shape(self):
return tuple(shape)

def __array__(self, dtype=None):
array = xarray_indexable(self.array)
array = as_indexable(self.array)
return np.asarray(array[self.key], dtype=None)

def __getitem__(self, key):
Expand All @@ -379,9 +383,9 @@ def _wrap_numpy_scalars(array):
return array


class CopyOnWriteArray(utils.NDArrayMixin):
class CopyOnWriteArray(utils.NDArrayMixin, NDArrayIndexable):
def __init__(self, array):
self.array = array
self.array = as_indexable(array)
self._copied = False

def _ensure_copied(self):
Expand All @@ -400,9 +404,9 @@ def __setitem__(self, key, value):
self.array[key] = value


class MemoryCachedArray(utils.NDArrayMixin):
class MemoryCachedArray(utils.NDArrayMixin, NDArrayIndexable):
def __init__(self, array):
self.array = _wrap_numpy_scalars(array)
self.array = _wrap_numpy_scalars(as_indexable(array))

def _ensure_cached(self):
if not isinstance(self.array, np.ndarray):
Expand All @@ -419,14 +423,21 @@ def __setitem__(self, key, value):
self.array[key] = value


def xarray_indexable(array):
def as_indexable(array):
"""
This function always returns a NDArrayIndexable subclass,
so that the vectorized indexing is always possible with the returned
object.
"""
if isinstance(array, NDArrayIndexable):
return array
if isinstance(array, np.ndarray):
return NumpyIndexingAdapter(array)
if isinstance(array, pd.Index):
return PandasIndexAdapter(array)
if isinstance(array, dask_array_type):
return DaskIndexingAdapter(array)
return array
raise TypeError('Invalid array type: {}'.format(type(array)))


def _outer_to_numpy_indexer(key, shape):
Expand Down Expand Up @@ -467,10 +478,14 @@ def _outer_to_numpy_indexer(key, shape):
return tuple(new_key)


class NumpyIndexingAdapter(utils.NDArrayMixin):
class NumpyIndexingAdapter(utils.NDArrayMixin, NDArrayIndexable):
"""Wrap a NumPy array to use broadcasted indexing
"""
def __init__(self, array):
# In NumpyIndexingAdapter we only allow to store bare np.ndarray
if not isinstance(array, np.ndarray):
raise TypeError('NumpyIndexingAdapter only wraps np.ndarray. '
'Trying to wrap {}'.format(type(array)))
self.array = array

def _ensure_ndarray(self, value):
Expand Down Expand Up @@ -502,7 +517,7 @@ def __setitem__(self, key, value):
array[key] = value


class DaskIndexingAdapter(utils.NDArrayMixin):
class DaskIndexingAdapter(utils.NDArrayMixin, NDArrayIndexable):
"""Wrap a dask array to support xarray-style indexing.
"""
def __init__(self, array):
Expand All @@ -522,6 +537,8 @@ def to_int_tuple(key):
return self.array[to_int_tuple(key)]
elif isinstance(key, VectorizedIndexer):
return self.array.vindex[to_int_tuple(tuple(key))]
elif key is Ellipsis:
return self.array
else:
assert isinstance(key, OuterIndexer)
key = to_int_tuple(tuple(key))
Expand All @@ -543,7 +560,7 @@ def __setitem__(self, key, value):
'method or accessing its .values attribute.')


class PandasIndexAdapter(utils.NDArrayMixin):
class PandasIndexAdapter(utils.NDArrayMixin, NDArrayIndexable):
"""Wrap a pandas.Index to be better about preserving dtypes and to handle
indexing by length 1 tuples like numpy
"""
Expand Down
Loading

0 comments on commit 2a1d392

Please sign in to comment.