Support __array_ufunc__ for xarray objects. #1962

Merged
merged 11 commits into from
Mar 12, 2018
1 change: 0 additions & 1 deletion .stickler.yml
@@ -6,7 +6,6 @@ linters:
# stickler doesn't support 'exclude' for flake8 properly, so we disable it
# below with files.ignore:
# https://github.com/markstory/lint-review/issues/184
py3k:
files:
ignore:
- doc/**/*.py
5 changes: 2 additions & 3 deletions asv_bench/benchmarks/rolling.py
@@ -1,9 +1,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import, division, print_function

import numpy as np
import pandas as pd

import xarray as xr

from . import parameterized, randn, requires_dask
7 changes: 7 additions & 0 deletions doc/api.rst
@@ -357,6 +357,13 @@ Reshaping and reorganizing
Universal functions
===================

.. warning::

   With recent versions of numpy, dask and xarray, NumPy ufuncs are now
   supported directly on all xarray and dask objects. This obviates the need
   for the ``xarray.ufuncs`` module, which should not be used for new code
   unless compatibility with versions of NumPy prior to v1.13 is required.

These functions are copied from NumPy, but extended to work on NumPy arrays,
dask arrays and all xarray objects. You can find them in the ``xarray.ufuncs``
module:
18 changes: 4 additions & 14 deletions doc/computation.rst
@@ -341,21 +341,15 @@ Datasets support most of the same methods found on data arrays:
ds.mean(dim='x')
abs(ds)

Unfortunately, we currently do not support NumPy ufuncs for datasets [1]_.
:py:meth:`~xarray.Dataset.apply` works around this
limitation, by applying the given function to each variable in the dataset:
Datasets also support NumPy ufuncs (requires NumPy v1.13 or newer), or
alternatively you can use :py:meth:`~xarray.Dataset.apply` to apply a function
to each variable in a dataset:

.. ipython:: python

np.sin(ds)
ds.apply(np.sin)

You can also use the wrapped functions in the ``xarray.ufuncs`` module:

.. ipython:: python

import xarray.ufuncs as xu
xu.sin(ds)

Datasets also use looping over variables for *broadcasting* in binary
arithmetic. You can do arithmetic between any ``DataArray`` and a dataset:

@@ -373,10 +367,6 @@ Arithmetic between two datasets matches data variables of the same name:
Similarly to index based alignment, the result has the intersection of all
matching data variables.

.. [1] This was previously due to a limitation of NumPy, but with NumPy 1.13
we should be able to support this by leveraging ``__array_ufunc__``
(:issue:`1617`).
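
As a quick illustration of the documentation change in this file, the sketch below applies a NumPy ufunc directly to a small dataset. It assumes an xarray version that includes this change and NumPy 1.13 or newer; the variable name `a` and the dimension `x` are invented for the example.

```python
import numpy as np
import xarray as xr

# Hypothetical example data: one variable 'a' along a dimension 'x'.
ds = xr.Dataset({'a': ('x', [0.0, 1.0, 2.0])})

# With __array_ufunc__ implemented, the ufunc is applied to each data
# variable and the result is again a Dataset.
result = np.sin(ds)

print(type(result).__name__)
print(result['a'].values)
```

On older NumPy versions, `ds.apply(np.sin)` from the same documentation section remains the fallback.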

.. _comput.wrapping-custom:

Wrapping custom computation
3 changes: 2 additions & 1 deletion doc/gallery/control_plot_colorbar.py
@@ -7,9 +7,10 @@
Use ``cbar_kwargs`` keyword to specify the number of ticks.
The ``spacing`` kwarg can be used to draw proportional ticks.
"""
import xarray as xr
import matplotlib.pyplot as plt

import xarray as xr

# Load the data
air_temp = xr.tutorial.load_dataset('air_temperature')
air2d = air_temp.air.isel(time=500)
11 changes: 11 additions & 0 deletions doc/whats-new.rst
@@ -42,7 +42,18 @@ Enhancements
as orthogonal/vectorized indexing, becomes possible for all the backend
arrays. Also, lazy ``transpose`` is now also supported. (:issue:`1897`)
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
- Implemented NumPy's ``__array_ufunc__`` protocol for all xarray objects
  (:issue:`1617`). This enables using NumPy ufuncs directly on
  ``xarray.Dataset`` objects with recent versions of NumPy (v1.13 and newer):

  .. ipython:: python

    ds = xr.Dataset({'a': 1})
    np.sin(ds)

  This obviates the need for the ``xarray.ufuncs`` module, which will be
  deprecated in the future when xarray drops support for older versions of
  NumPy. By `Stephan Hoyer <https://github.com/shoyer>`_.
- Improve :py:func:`~xarray.DataArray.rolling` logic.
:py:func:`~xarray.DataArrayRolling` object now supports
:py:func:`~xarray.DataArrayRolling.construct` method that returns a view
71 changes: 71 additions & 0 deletions xarray/core/arithmetic.py
@@ -0,0 +1,71 @@
"""Base classes implementing arithmetic for xarray objects."""
from __future__ import absolute_import, division, print_function

import numbers

import numpy as np

from .options import OPTIONS
from .pycompat import bytes_type, dask_array_type, unicode_type
from .utils import not_implemented


class SupportsArithmetic(object):
    """Base class for xarray types that support arithmetic.

    Used by Dataset, DataArray, Variable and GroupBy.
    """

    # TODO: implement special methods for arithmetic here rather than injecting
    # them in xarray/core/ops.py. Ideally, do so by inheriting from
    # numpy.lib.mixins.NDArrayOperatorsMixin.

    # TODO: allow extending this with some sort of registration system
    _HANDLED_TYPES = (np.ndarray, np.generic, numbers.Number, bytes_type,
                      unicode_type) + dask_array_type

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        from .computation import apply_ufunc

        # See the docstring example for numpy.lib.mixins.NDArrayOperatorsMixin.
        out = kwargs.get('out', ())
        for x in inputs + out:
            if not isinstance(x, self._HANDLED_TYPES + (SupportsArithmetic,)):
                return NotImplemented

        if ufunc.signature is not None:
            raise NotImplementedError(
                '{} not supported: xarray objects do not directly implement '
                'generalized ufuncs. Instead, use xarray.apply_ufunc.'
                .format(ufunc))

        if method != '__call__':
            # TODO: support other methods, e.g., reduce and accumulate.
            raise NotImplementedError(
                '{} method for ufunc {} is not implemented on xarray objects, '
                'which currently only support the __call__ method.'
                .format(method, ufunc))

        if any(isinstance(o, SupportsArithmetic) for o in out):
            # TODO: implement this with logic like _inplace_binary_op. This
            # will be necessary to use NDArrayOperatorsMixin.
            raise NotImplementedError(
                'xarray objects are not yet supported in the `out` argument '
                'for ufuncs.')

        join = dataset_join = OPTIONS['arithmetic_join']

        return apply_ufunc(ufunc, *inputs,
                           input_core_dims=((),) * ufunc.nin,
                           output_core_dims=((),) * ufunc.nout,
                           join=join,
                           dataset_join=dataset_join,
                           dataset_fill_value=np.nan,
                           kwargs=kwargs,
                           dask='allowed')

    # this has no runtime function - these are listed so IDEs know these
    # methods are defined and don't warn on these operations
    __lt__ = __le__ = __ge__ = __gt__ = __add__ = __sub__ = __mul__ = \
        __truediv__ = __floordiv__ = __mod__ = __pow__ = __and__ = __xor__ = \
        __or__ = __div__ = __eq__ = __ne__ = not_implemented
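
To make the dispatch pattern in `__array_ufunc__` easier to follow, here is a minimal, self-contained sketch using only NumPy (1.13 or newer). It is not xarray's implementation: the `Labeled` class and its `data`/`name` attributes are hypothetical, and where the real code delegates to `apply_ufunc`, this toy simply unwraps the inputs, calls the ufunc, and re-wraps the result.

```python
import numbers

import numpy as np


class Labeled(object):
    """Hypothetical toy container: a NumPy array plus a name."""

    _HANDLED_TYPES = (np.ndarray, np.generic, numbers.Number)

    def __init__(self, data, name=None):
        self.data = np.asarray(data)
        self.name = name

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        out = kwargs.get('out', ())
        for x in inputs + out:
            # Defer to other types that this class does not know about.
            if not isinstance(x, self._HANDLED_TYPES + (Labeled,)):
                return NotImplemented

        if method != '__call__':
            # Mirrors the restriction in the diff: no reduce/accumulate.
            raise NotImplementedError(
                '{} method for ufunc {} is not supported'
                .format(method, ufunc))

        if out:
            raise NotImplementedError('`out` is not supported in this sketch')

        # Unwrap Labeled inputs, apply the ufunc, then re-wrap the result.
        arrays = tuple(x.data if isinstance(x, Labeled) else x
                       for x in inputs)
        result = ufunc(*arrays, **kwargs)
        if ufunc.nout > 1:
            return tuple(Labeled(r, self.name) for r in result)
        return Labeled(result, self.name)


x = Labeled([0.0, 1.0], name='x')
y = np.sin(x)                        # dispatches to Labeled.__array_ufunc__
z = np.add(np.array([1.0, 2.0]), x)  # also works with an ndarray on the left
```

Returning `NotImplemented` for unknown input types (rather than raising) lets NumPy try the other operands' `__array_ufunc__` implementations before giving up with a `TypeError`.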
15 changes: 7 additions & 8 deletions xarray/core/common.py
@@ -1,12 +1,17 @@
from __future__ import absolute_import, division, print_function

import numbers
import warnings

import numpy as np
import pandas as pd

from . import dtypes, formatting, ops
from .pycompat import OrderedDict, basestring, dask_array_type, suppress
from .arithmetic import SupportsArithmetic
from .options import OPTIONS
Member

is OPTIONS ever used or is it needed in this scope for some reason?

Member Author

Oops we don't need it here.

from .pycompat import (
OrderedDict, basestring, bytes_type, dask_array_type, suppress,
unicode_type)
from .utils import Frozen, SortedKeysDict, not_implemented


@@ -235,7 +240,7 @@ def get_squeeze_dims(xarray_obj, dim, axis=None):
return dim


class BaseDataObject(AttrAccessMixin):
class DataWithCoords(SupportsArithmetic, AttrAccessMixin):
Contributor

W1641 Implementing eq without also implementing hash
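
For context, the lint warning refers to a general Python behaviour, sketched below with hypothetical classes (not xarray code). In Python 3, a class that defines `__eq__` without also defining `__hash__` has its `__hash__` set to `None`, so its instances cannot be placed in sets or used as dictionary keys.

```python
class EqOnly(object):
    """Hypothetical class that defines __eq__ but not __hash__."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return isinstance(other, EqOnly) and self.value == other.value


class EqAndHash(EqOnly):
    """Same equality, plus a hash that is consistent with it."""

    def __hash__(self):
        return hash(self.value)


def is_hashable(obj):
    """Return True if hash(obj) succeeds."""
    try:
        hash(obj)
    except TypeError:
        return False
    return True


print(is_hashable(EqOnly(1)))     # unhashable on Python 3
print(is_hashable(EqAndHash(1)))  # hashable
```

Whether the warning matters here depends on whether the objects are meant to be hashable at all; for mutable containers, leaving them unhashable is often the intended outcome.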

"""Shared base class for Dataset and DataArray."""

def squeeze(self, dim=None, drop=False, axis=None):
@@ -749,12 +754,6 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, traceback):
self.close()

# this has no runtime function - these are listed so IDEs know these
# methods are defined and don't warn on these operations
__lt__ = __le__ = __ge__ = __gt__ = __add__ = __sub__ = __mul__ = \
__truediv__ = __floordiv__ = __mod__ = __pow__ = __and__ = __xor__ = \
__or__ = __div__ = __eq__ = __ne__ = not_implemented


def full_like(other, fill_value, dtype=None):
"""Return a new object with the same shape and type as a given object.
5 changes: 2 additions & 3 deletions xarray/core/dask_array_ops.py
@@ -1,8 +1,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import, division, print_function

import numpy as np

from . import nputils

try:
4 changes: 2 additions & 2 deletions xarray/core/dataarray.py
@@ -10,7 +10,7 @@
from ..plot.plot import _PlotMethods
from .accessors import DatetimeAccessor
from .alignment import align, reindex_like_indexers
from .common import AbstractArray, BaseDataObject
from .common import AbstractArray, DataWithCoords
from .coordinates import (
DataArrayCoordinates, Indexes, LevelCoordinatesSource,
assert_coordinate_consistent, remap_label_indexers)
@@ -117,7 +117,7 @@ def __setitem__(self, key, value):
_THIS_ARRAY = utils.ReprObject('<this-array>')


class DataArray(AbstractArray, BaseDataObject):
class DataArray(AbstractArray, DataWithCoords):
"""N-dimensional array with labeled coordinates and dimensions.

DataArray provides a wrapper around numpy ndarrays that uses labeled
6 changes: 3 additions & 3 deletions xarray/core/dataset.py
@@ -17,7 +17,7 @@
rolling, utils)
from .. import conventions
from .alignment import align
from .common import BaseDataObject, ImplementsDatasetReduce
from .common import DataWithCoords, ImplementsDatasetReduce
from .coordinates import (
DatasetCoordinates, Indexes, LevelCoordinatesSource,
assert_coordinate_consistent, remap_label_indexers)
@@ -298,7 +298,7 @@ def __getitem__(self, key):
return self.dataset.sel(**key)


class Dataset(Mapping, ImplementsDatasetReduce, BaseDataObject,
class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords,
formatting.ReprMixin):
"""A multi-dimensional, in memory, array database.

@@ -2362,7 +2362,7 @@ def dropna(self, dim, how='any', thresh=None, subset=None):
array = self._variables[k]
if dim in array.dims:
dims = [d for d in array.dims if d != dim]
count += array.count(dims)
count += np.asarray(array.count(dims))
size += np.prod([self.dims[d] for d in dims])

if thresh is not None:
3 changes: 2 additions & 1 deletion xarray/core/groupby.py
@@ -6,6 +6,7 @@
import pandas as pd

from . import dtypes, duck_array_ops, nputils, ops
from .arithmetic import SupportsArithmetic
from .combine import concat
from .common import ImplementsArrayReduce, ImplementsDatasetReduce
from .pycompat import integer_types, range, zip
@@ -151,7 +152,7 @@ def _unique_and_monotonic(group):
return index.is_unique and index.is_monotonic


class GroupBy(object):
class GroupBy(SupportsArithmetic):
"""An object that implements the split-apply-combine pattern.

Modeled after `pandas.GroupBy`. The `GroupBy` object can be iterated over
2 changes: 1 addition & 1 deletion xarray/core/npcompat.py
@@ -1,8 +1,8 @@
from __future__ import absolute_import, division, print_function

import numpy as np
from distutils.version import LooseVersion

import numpy as np

if LooseVersion(np.__version__) >= LooseVersion('1.12'):
as_strided = np.lib.stride_tricks.as_strided
7 changes: 4 additions & 3 deletions xarray/core/variable.py
@@ -10,7 +10,8 @@

import xarray as xr # only for Dataset and DataArray

from . import common, dtypes, duck_array_ops, indexing, nputils, ops, utils
from . import (
arithmetic, common, dtypes, duck_array_ops, indexing, nputils, ops, utils,)
from .indexing import (
BasicIndexer, OuterIndexer, PandasIndexAdapter, VectorizedIndexer,
as_indexable)
@@ -216,8 +217,8 @@ def _as_array_or_item(data):
return data


class Variable(common.AbstractArray, utils.NdimSizeLenMixin):

class Variable(common.AbstractArray, arithmetic.SupportsArithmetic,
utils.NdimSizeLenMixin):
"""A netcdf-like variable consisting of dimensions, data and attributes
which describe a single Array. A single Variable object is not fully
described outside the context of its parent Dataset (if you want such a
4 changes: 2 additions & 2 deletions xarray/tests/test_nputils.py
@@ -1,8 +1,8 @@
import numpy as np
from numpy.testing import assert_array_equal

from xarray.core.nputils import (NumpyVIndexAdapter, _is_contiguous,
rolling_window)
from xarray.core.nputils import (
NumpyVIndexAdapter, _is_contiguous, rolling_window)


def test_is_contiguous():