Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional numpy __array_function__ implementations #6

Closed
wants to merge 10 commits into from
8 changes: 3 additions & 5 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ branches:
env:
# Should pandas tests be removed or replaced wih import checks?
#- UNCERTAINTIES="N" PYTHON="3.6" NUMPY_VERSION=1.14 PANDAS=1
- UNCERTAINTIES="N" PYTHON="2.7" NUMPY_VERSION=1.17 PANDAS=0
- UNCERTAINTIES="N" PYTHON="3.6" NUMPY_VERSION=1.17 PANDAS=0
- UNCERTAINTIES="N" PYTHON="3.7" NUMPY_VERSION=1.17 PANDAS=0
- UNCERTAINTIES="N" PYTHON="3.6" NUMPY_VERSION=1.16 PANDAS=0
- UNCERTAINTIES="N" PYTHON="3.6" NUMPY_VERSION=1.11.2 PANDAS=0
# - UNCERTAINTIES="N" PYTHON="3.3" NUMPY_VERSION=1.9.2 PANDAS=0
# - UNCERTAINTIES="N" PYTHON="3.4" NUMPY_VERSION=1.11.2 PANDAS=0
# - UNCERTAINTIES="N" PYTHON="3.5" NUMPY_VERSION=1.11.2 PANDAS=0
Expand Down Expand Up @@ -42,11 +45,6 @@ before_install:
fi
# Useful for debugging any issues with conda
- conda info -a

# The next couple lines fix a crash with multiprocessing on Travis and are not specific to using Miniconda
- sudo rm -rf /dev/shm
- sudo ln -s /run/shm /dev/shm

- export ENV_NAME=travis

install:
Expand Down
33 changes: 25 additions & 8 deletions pint/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,22 @@ def u(x):
# TODO: remove this warning after v0.10
class BehaviorChangeWarning(UserWarning):
pass
_msg = ('The way pint handles numpy operations has changed. '
'Unimplemented numpy operations will now fail instead '
'of making assumptions about units. Some functions, '
'eg concat, will now return Quanties with units, where '
'they returned ndarrays previously. See '
'https://github.com/hgrecco/pint/pull/764 . '
_msg = ('The way pint handles numpy operations has changed with '
'the implementation of NEP 18. Unimplemented numpy operations '
'will now fail instead of making assumptions about units. Some '
'functions, eg concat, will now return Quanties with units, '
'where they returned ndarrays previously. See '
'https://github.com/hgrecco/pint/pull/764. '
'To hide this warning use the following code to import pint:'
"""

import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore")
import pint

To disable the new behavior, see
https://www.numpy.org/neps/nep-0018-array-function-protocol.html#implementation
---
""")

Expand All @@ -107,10 +110,23 @@ def _to_magnitude(value, force_ndarray=False):
if force_ndarray:
return np.asarray(value)
return value

warnings.warn(_msg, BehaviorChangeWarning)

def _test_array_function_protocol():
# Test if the __array_function__ protocol is enabled
try:
class FakeArray:
def __array_function__(self, *args, **kwargs):
return

np.concatenate([FakeArray()])
return True
except ValueError:
return False

HAS_NUMPY_ARRAY_FUNCTION = _test_array_function_protocol()

if HAS_NUMPY_ARRAY_FUNCTION:
warnings.warn(_msg, BehaviorChangeWarning)

except ImportError:

Expand All @@ -122,6 +138,7 @@ class ndarray(object):
HAS_NUMPY = False
NUMPY_VER = '0'
NUMERIC_TYPES = (Number, Decimal)
HAS_NUMPY_ARRAY_FUNCTION = False

def _to_magnitude(value, force_ndarray=False):
if isinstance(value, (dict, bool)) or value is None:
Expand Down
78 changes: 59 additions & 19 deletions pint/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def wrapped(self, *args, **kwargs):
def check_implemented(f):
def wrapped(self, *args, **kwargs):
other=args[0]
if other.__class__.__name__ in ["PintArray", "Series"]:
if other.__class__.__name__ in ["PintArray", "Series", "DataArray"]:
return NotImplemented
# pandas often gets to arrays of quantities [ Q_(1,"m"), Q_(2,"m")]
# and expects Quantity * array[Quantity] should return NotImplemented
Expand Down Expand Up @@ -110,7 +110,6 @@ def convert_to_consistent_units(pre_calc_units=None, *args, **kwargs):
"""Takes the args for a numpy function and converts any Quantity or Sequence of Quantities
into the units of the first Quantiy/Sequence of quantities. Other args are left untouched.
"""
print(args,kwargs)
def convert_arg(arg):
if pre_calc_units is not None:
if isinstance(arg,BaseQuantity):
Expand All @@ -126,7 +125,6 @@ def convert_arg(arg):

new_args=tuple(convert_arg(arg) for arg in args)
new_kwargs = {key:convert_arg(arg) for key,arg in kwargs.items()}
print( new_args, new_kwargs)
return new_args, new_kwargs

def implement_func(func_str, pre_calc_units_, post_calc_units_, out_units_):
Expand All @@ -153,12 +151,10 @@ def implement_func(func_str, pre_calc_units_, post_calc_units_, out_units_):

"""
func = getattr(np,func_str)
print(func_str)

@implements(func)
def _(*args, **kwargs):
# TODO make work for kwargs
print("_",func_str)
args_and_kwargs = list(args)+list(kwargs.values())

(pre_calc_units, post_calc_units, out_units)=(pre_calc_units_, post_calc_units_, out_units_)
Expand All @@ -176,6 +172,8 @@ def _(*args, **kwargs):
return res
elif post_calc_units == "as_pre_calc":
post_calc_units = pre_calc_units
elif post_calc_units == "sum":
post_calc_units = (1*first_input_units + 1*first_input_units).units
elif post_calc_units == "prod":
product = 1
for x in args_and_kwargs:
Expand All @@ -193,7 +191,8 @@ def _(*args, **kwargs):
for x in args_and_kwargs[1:]:
product /= x
post_calc_units = product.units
print(post_calc_units)
elif post_calc_units == "variance":
post_calc_units = ((1*first_input_units + 1*first_input_units)**2).units
Q_ = first_input_units._REGISTRY.Quantity
post_calc_Q_= Q_(res, post_calc_units)

Expand All @@ -202,30 +201,72 @@ def _(*args, **kwargs):
elif out_units == "infer_from_input":
out_units = first_input_units
return post_calc_Q_.to(out_units)

@implements(np.power)
def _power(*args, **kwargs):
print(args)
pass
for func_str in ['linspace', 'concatenate', 'block', 'stack', 'hstack', 'vstack', 'dstack', 'atleast_1d', 'column_stack', 'atleast_2d', 'atleast_3d', 'expand_dims','squeeze', 'swapaxes', 'compress', 'searchsorted' ,'rollaxis', 'broadcast_to', 'moveaxis', 'fix']:

@implements(np.meshgrid)
def _meshgrid(*xi, **kwargs):
# Simply need to map input units to onto list of outputs
input_units = (x.units for x in xi)
res = np.meshgrid(*(x.m for x in xi), **kwargs)
return [out * unit for out, unit in zip(res, input_units)]

@implements(np.full_like)
def _full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None):
# Make full_like by multiplying with array from ones_like in a
# non-multiplicative-unit-safe way
if isinstance(fill_value, BaseQuantity):
return fill_value._REGISTRY.Quantity(
np.ones_like(a, dtype=dtype, order=order, subok=subok, shape=shape) * fill_value.m,
fill_value.units)
else:
return (np.ones_like(a, dtype=dtype, order=order, subok=subok, shape=shape)
* fill_value)

@implements(np.interp)
def _interp(x, xp, fp, left=None, right=None, period=None):
# Need to handle x and y units separately
x_unit = _get_first_input_units([x, xp, period])
y_unit = _get_first_input_units([fp, left, right])
x_args, _ = convert_to_consistent_units(x_unit, x, xp, period)
y_args, _ = convert_to_consistent_units(y_unit, fp, left, right)
x, xp, period = x_args
fp, right, left = y_args
Q_ = y_unit._REGISTRY.Quantity
return Q_(np.interp(x, xp, fp, left=left, right=right, period=period), y_unit)

for func_str in ['linspace', 'concatenate', 'block', 'stack', 'hstack', 'vstack', 'dstack', 'atleast_1d', 'column_stack', 'atleast_2d', 'atleast_3d', 'expand_dims','squeeze', 'swapaxes', 'compress', 'rollaxis', 'broadcast_to', 'moveaxis', 'fix', 'amax', 'amin', 'nanmax', 'nanmin', 'around', 'diagonal', 'mean', 'ptp', 'ravel', 'round_', 'sort', 'median', 'nanmedian', 'transpose', 'flip', 'copy', 'trim_zeros', 'append', 'clip', 'nan_to_num']:
implement_func(func_str, 'consistent_infer', 'as_pre_calc', 'as_post_calc')


for func_str in ['isclose', 'searchsorted']:
implement_func(func_str, 'consistent_infer', None, None)

for func_str in ['unwrap']:
implement_func(func_str, 'rad', 'rad', 'infer_from_input')


for func_str in ['size', 'isreal', 'iscomplex']:
for func_str in ['cumprod', 'cumproduct', 'nancumprod']:
implement_func(func_str, 'dimensionless', 'dimensionless', 'infer_from_input')

for func_str in ['size', 'isreal', 'iscomplex', 'shape', 'ones_like', 'zeros_like', 'empty_like', 'argsort', 'argmin', 'argmax', 'alen', 'ndim', 'nanargmax', 'nanargmin', 'count_nonzero', 'nonzero', 'result_type']:
implement_func(func_str, None, None, None)

for func_str in ['average', 'mean', 'std', 'nanmean', 'nanstd', 'sum', 'nansum', 'cumsum', 'nancumsum']:
implement_func(func_str, None, 'sum', None)

for func_str in ['cross', 'trapz']:
for func_str in ['cross', 'trapz', 'dot']:
implement_func(func_str, None, 'prod', None)

for func_str in ['diff', 'ediff1d',]:
implement_func(func_str, None, 'delta', None)

for func_str in ['gradient', ]:
implement_func(func_str, None, 'delta,div', None)


for func_str in ['var', 'nanvar']:
implement_func(func_str, None, 'variance', None)


@contextlib.contextmanager
def printoptions(*args, **kwargs):
Expand All @@ -252,7 +293,6 @@ class BaseQuantity(PrettyIPython, SharedRegistryObject):
:type units: UnitsContainer, str or Quantity.
"""
def __array_function__(self, func, types, args, kwargs):
print("__array_function__", func)
if func not in HANDLED_FUNCTIONS:
return NotImplemented
if not all(issubclass(t, BaseQuantity) for t in types):
Expand Down Expand Up @@ -1380,7 +1420,7 @@ def __ne__(self, other):

@check_implemented
def compare(self, other, op):
if not isinstance(other, self.__class__):
if not isinstance(other, BaseQuantity):
if self.dimensionless:
return op(self._convert_magnitude_not_inplace(UnitsContainer()), other)
elif _eq(other, 0, True):
Expand Down Expand Up @@ -1500,10 +1540,10 @@ def shape(self, value):
self._magnitude.shape = value

def searchsorted(self, v, side='left', sorter=None):
if isinstance(v, self.__class__):
if isinstance(v, BaseQuantity):
v = v.to(self).magnitude
elif self.dimensionless:
v = self.__class__(v, '').to(self)
v = Quantity(v, '').to(self)
else:
raise DimensionalityError('dimensionless', self._units)
return self.magnitude.searchsorted(v, side)
Expand Down
15 changes: 14 additions & 1 deletion pint/testsuite/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,20 @@
import re
import unittest

from pint.compat import HAS_NUMPY, HAS_PROPER_BABEL, HAS_UNCERTAINTIES, NUMPY_VER, PYTHON3
from pint.compat import HAS_NUMPY, HAS_PROPER_BABEL, HAS_UNCERTAINTIES, NUMPY_VER, PYTHON3, HAS_NUMPY_ARRAY_FUNCTION


def requires_array_function_protocol():
if not HAS_NUMPY:
return unittest.skip('Requires NumPy')
return unittest.skipUnless(HAS_NUMPY_ARRAY_FUNCTION, 'Requires __array_function__ protocol to be enabled')


def requires_not_array_function_protocol():
if not HAS_NUMPY:
return unittest.skip('Requires NumPy')
return unittest.skipIf(HAS_NUMPY_ARRAY_FUNCTION, 'Requires __array_function__ protocol to be unavailable or disabled')


def requires_numpy18():
if not HAS_NUMPY:
Expand Down
Loading