From ea96e9b2dc186eacaec33df67c68d716f17bfb18 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski <evgeny.burovskiy@gmail.com> Date: Fri, 22 Nov 2024 11:07:40 +0200 Subject: [PATCH] ENH: allow python scalars in binary elementwise functions Allow func(array, scalar) and func(scalar, array), raise on func(scalar, scalar) if API_VERSION>=2024.12 cross-ref https://github.com/data-apis/array-api/issues/807 To make sure it is all uniform, 1. Generate all binary "ufuncs" in a uniform way, with a decorator 2. Make binary "ufuncs" follow the same logic of the binary operators 3. Reuse the test loop of Array.__binop__ for binary "ufuncs" 4. (minor) in tests, reuse canonical names for dtype categories ("integer or boolean" vs "integer_or_boolean") --- array_api_strict/_array_object.py | 2 + array_api_strict/_elementwise_functions.py | 584 ++++-------------- array_api_strict/_helpers.py | 37 ++ array_api_strict/tests/test_array_object.py | 101 +-- .../tests/test_elementwise_functions.py | 54 +- 5 files changed, 266 insertions(+), 512 deletions(-) create mode 100644 array_api_strict/_helpers.py diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index a917441..47153e5 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -230,6 +230,8 @@ def _check_device(self, other): elif isinstance(other, Array): if self.device != other.device: raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.") + else: + raise TypeError(f"Cannot combine an Array with {type(other)}.") # Helper function to match the type promotion rules in the spec def _promote_scalar(self, scalar): diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 7c64f67..3c4b3d8 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -10,17 +10,133 @@ _real_numeric_dtypes, _numeric_dtypes, _result_type, + _dtype_categories as _dtype_dtype_categories, ) from ._array_object import Array from ._flags import requires_api_version from ._creation_functions import asarray from ._data_type_functions import broadcast_to, iinfo +from ._helpers import _maybe_normalize_py_scalars from typing import Optional, Union import numpy as np +def _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func): + """Base implementation of a binary function, `func_name`, defined for + dtypes from `dtype_category` + """ + x1, x2 = _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name) + + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np_func(x1._array, x2._array), device=x1.device) + + +_binary_docstring_template=\ +""" +Array API compatible wrapper for :py:func:`np.%s <numpy.%s>`. + +See its docstring for more information. +""" + + +def create_binary_func(func_name, dtype_category, np_func): + def inner(x1: Array, x2: Array, /) -> Array: + return _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func) + return inner + + +# func_name: dtype_category (must match that from _dtypes.py) +_binary_funcs = { + "add": "numeric", + "atan2": "real floating-point", + "bitwise_and": "integer or boolean", + "bitwise_or": "integer or boolean", + "bitwise_xor": "integer or boolean", + "_bitwise_left_shift": "integer", # leading underscore deliberate + "_bitwise_right_shift": "integer", + # XXX: copysign: real fp or numeric? + "copysign": "real floating-point", + "divide": "floating-point", + "equal": "all", + "greater": "real numeric", + "greater_equal": "real numeric", + "less": "real numeric", + "less_equal": "real numeric", + "not_equal": "all", + "floor_divide": "real numeric", + "hypot": "real floating-point", + "logaddexp": "real floating-point", + "logical_and": "boolean", + "logical_or": "boolean", + "logical_xor": "boolean", + "maximum": "real numeric", + "minimum": "real numeric", + "multiply": "numeric", + "nextafter": "real floating-point", + "pow": "numeric", + "remainder": "real numeric", + "subtract": "numeric", +} + + +# map array-api-name : numpy-name +_numpy_renames = { + "atan2": "arctan2", + "_bitwise_left_shift": "left_shift", + "_bitwise_right_shift": "right_shift", + "pow": "power" +} + + +# create and attach functions to the module +for func_name, dtype_category in _binary_funcs.items(): + # sanity check + assert dtype_category in _dtype_dtype_categories + + numpy_name = _numpy_renames.get(func_name, func_name) + np_func = getattr(np, numpy_name) + + func = create_binary_func(func_name, dtype_category, np_func) + func.__name__ = func_name + + func.__doc__ = _binary_docstring_template % (numpy_name, numpy_name) + + vars()[func_name] = func + + +copysign = requires_api_version('2023.12')(copysign) # noqa: F821 +hypot = requires_api_version('2023.12')(hypot) # noqa: F821 +maximum = requires_api_version('2023.12')(maximum) # noqa: F821 +minimum = requires_api_version('2023.12')(minimum) # noqa: F821 +nextafter = requires_api_version('2024.12')(nextafter) # noqa: F821 + + +def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: + is_negative = np.any(x2._array < 0) if isinstance(x2, Array) else x2 < 0 + if is_negative: + raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0") + return _bitwise_left_shift(x1, x2) # noqa: F821 +bitwise_left_shift.__doc__ = _bitwise_left_shift.__doc__ # noqa: F821 + + +def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: + is_negative = np.any(x2._array < 0) if isinstance(x2, Array) else x2 < 0 + if is_negative: + raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0") + return _bitwise_right_shift(x1, x2) # noqa: F821 +bitwise_right_shift.__doc__ = _bitwise_right_shift.__doc__ # noqa: F821 + + +# clean up to not pollute the namespace +del func, create_binary_func + + def abs(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.abs <numpy.abs>`. @@ -56,23 +172,6 @@ def acosh(x: Array, /) -> Array: return Array._new(np.arccosh(x._array), device=x.device) -def add(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.add <numpy.add>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in add") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.add(x1._array, x2._array), device=x1.device) - - # Note: the function name is different here def asin(x: Array, /) -> Array: """ @@ -109,23 +208,6 @@ def atan(x: Array, /) -> Array: return Array._new(np.arctan(x._array), device=x.device) -# Note: the function name is different here -def atan2(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arctan2 <numpy.arctan2>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: - raise TypeError("Only real floating-point dtypes are allowed in atan2") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.arctan2(x1._array, x2._array), device=x1.device) - - # Note: the function name is different here def atanh(x: Array, /) -> Array: """ @@ -138,47 +220,6 @@ def atanh(x: Array, /) -> Array: return Array._new(np.arctanh(x._array), device=x.device) -def bitwise_and(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.bitwise_and <numpy.bitwise_and>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if ( - x1.dtype not in _integer_or_boolean_dtypes - or x2.dtype not in _integer_or_boolean_dtypes - ): - raise TypeError("Only integer or boolean dtypes are allowed in bitwise_and") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.bitwise_and(x1._array, x2._array), device=x1.device) - - -# Note: the function name is different here -def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.left_shift <numpy.left_shift>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: - raise TypeError("Only integer dtypes are allowed in bitwise_left_shift") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - # Note: bitwise_left_shift is only defined for x2 nonnegative. - if np.any(x2._array < 0): - raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0") - return Array._new(np.left_shift(x1._array, x2._array), device=x1.device) - - # Note: the function name is different here def bitwise_invert(x: Array, /) -> Array: """ @@ -191,67 +232,6 @@ def bitwise_invert(x: Array, /) -> Array: return Array._new(np.invert(x._array), device=x.device) -def bitwise_or(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.bitwise_or <numpy.bitwise_or>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if ( - x1.dtype not in _integer_or_boolean_dtypes - or x2.dtype not in _integer_or_boolean_dtypes - ): - raise TypeError("Only integer or boolean dtypes are allowed in bitwise_or") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.bitwise_or(x1._array, x2._array), device=x1.device) - - -# Note: the function name is different here -def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.right_shift <numpy.right_shift>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: - raise TypeError("Only integer dtypes are allowed in bitwise_right_shift") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - # Note: bitwise_right_shift is only defined for x2 nonnegative. - if np.any(x2._array < 0): - raise ValueError("bitwise_right_shift(x1, x2) is only defined for x2 >= 0") - return Array._new(np.right_shift(x1._array, x2._array), device=x1.device) - - -def bitwise_xor(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.bitwise_xor <numpy.bitwise_xor>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if ( - x1.dtype not in _integer_or_boolean_dtypes - or x2.dtype not in _integer_or_boolean_dtypes - ): - raise TypeError("Only integer or boolean dtypes are allowed in bitwise_xor") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.bitwise_xor(x1._array, x2._array), device=x1.device) - - def ceil(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.ceil <numpy.ceil>`. @@ -372,6 +352,7 @@ def _isscalar(a): out[ib] = b[ib] return Array._new(out, device=device) + def conj(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.conj <numpy.conj>`. @@ -382,22 +363,6 @@ def conj(x: Array, /) -> Array: raise TypeError("Only complex floating-point dtypes are allowed in conj") return Array._new(np.conj(x._array), device=x.device) -@requires_api_version('2023.12') -def copysign(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.copysign <numpy.copysign>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: - raise TypeError("Only real numeric dtypes are allowed in copysign") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.copysign(x1._array, x2._array), device=x1.device) def cos(x: Array, /) -> Array: """ @@ -421,36 +386,6 @@ def cosh(x: Array, /) -> Array: return Array._new(np.cosh(x._array), device=x.device) -def divide(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.divide <numpy.divide>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in divide") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.divide(x1._array, x2._array), device=x1.device) - - -def equal(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.equal <numpy.equal>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.equal(x1._array, x2._array), device=x1.device) - - def exp(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.exp <numpy.exp>`. @@ -487,69 +422,6 @@ def floor(x: Array, /) -> Array: return Array._new(np.floor(x._array), device=x.device) -def floor_divide(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.floor_divide <numpy.floor_divide>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in floor_divide") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.floor_divide(x1._array, x2._array), device=x1.device) - - -def greater(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.greater <numpy.greater>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in greater") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.greater(x1._array, x2._array), device=x1.device) - - -def greater_equal(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.greater_equal <numpy.greater_equal>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in greater_equal") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.greater_equal(x1._array, x2._array), device=x1.device) - -@requires_api_version('2023.12') -def hypot(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.hypot <numpy.hypot>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: - raise TypeError("Only real floating-point dtypes are allowed in hypot") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.hypot(x1._array, x2._array), device=x1.device) - def imag(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.imag <numpy.imag>`. @@ -594,38 +466,6 @@ def isnan(x: Array, /) -> Array: return Array._new(np.isnan(x._array), device=x.device) -def less(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.less <numpy.less>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in less") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.less(x1._array, x2._array), device=x1.device) - - -def less_equal(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.less_equal <numpy.less_equal>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in less_equal") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.less_equal(x1._array, x2._array), device=x1.device) - - def log(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.log <numpy.log>`. @@ -670,38 +510,6 @@ def log10(x: Array, /) -> Array: return Array._new(np.log10(x._array), device=x.device) -def logaddexp(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.logaddexp <numpy.logaddexp>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: - raise TypeError("Only real floating-point dtypes are allowed in logaddexp") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.logaddexp(x1._array, x2._array), device=x1.device) - - -def logical_and(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.logical_and <numpy.logical_and>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: - raise TypeError("Only boolean dtypes are allowed in logical_and") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.logical_and(x1._array, x2._array), device=x1.device) - - def logical_not(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.logical_not <numpy.logical_not>`. @@ -713,87 +521,6 @@ def logical_not(x: Array, /) -> Array: return Array._new(np.logical_not(x._array), device=x.device) -def logical_or(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.logical_or <numpy.logical_or>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: - raise TypeError("Only boolean dtypes are allowed in logical_or") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.logical_or(x1._array, x2._array), device=x1.device) - - -def logical_xor(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.logical_xor <numpy.logical_xor>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: - raise TypeError("Only boolean dtypes are allowed in logical_xor") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.logical_xor(x1._array, x2._array), device=x1.device) - -@requires_api_version('2023.12') -def maximum(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.maximum <numpy.maximum>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in maximum") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - # TODO: maximum(-0., 0.) is unspecified. Should we issue a warning/error - # in that case? - return Array._new(np.maximum(x1._array, x2._array), device=x1.device) - -@requires_api_version('2023.12') -def minimum(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.minimum <numpy.minimum>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in minimum") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.minimum(x1._array, x2._array), device=x1.device) - -def multiply(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.multiply <numpy.multiply>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in multiply") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.multiply(x1._array, x2._array), device=x1.device) - - def negative(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.negative <numpy.negative>`. @@ -805,34 +532,6 @@ def negative(x: Array, /) -> Array: return Array._new(np.negative(x._array), device=x.device) -@requires_api_version('2024.12') -def nextafter(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.nextafter <numpy.nextafter>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: - raise TypeError("Only real floating-point dtypes are allowed in nextafter") - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.nextafter(x1._array, x2._array), device=x1.device) - -def not_equal(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.not_equal <numpy.not_equal>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.not_equal(x1._array, x2._array), device=x1.device) - - def positive(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.positive <numpy.positive>`. @@ -844,23 +543,6 @@ def positive(x: Array, /) -> Array: return Array._new(np.positive(x._array), device=x.device) -# Note: the function name is different here -def pow(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.power <numpy.power>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in pow") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.power(x1._array, x2._array), device=x1.device) - - def real(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.real <numpy.real>`. @@ -883,22 +565,6 @@ def reciprocal(x: Array, /) -> Array: raise TypeError("Only floating-point dtypes are allowed in reciprocal") return Array._new(np.reciprocal(x._array), device=x.device) -def remainder(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.remainder <numpy.remainder>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in remainder") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.remainder(x1._array, x2._array), device=x1.device) - - def round(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.round <numpy.round>`. @@ -979,22 +645,6 @@ def sqrt(x: Array, /) -> Array: return Array._new(np.sqrt(x._array), device=x.device) -def subtract(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.subtract <numpy.subtract>`. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in subtract") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.subtract(x1._array, x2._array), device=x1.device) - - def tan(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.tan <numpy.tan>`. diff --git a/array_api_strict/_helpers.py b/array_api_strict/_helpers.py new file mode 100644 index 0000000..2258d29 --- /dev/null +++ b/array_api_strict/_helpers.py @@ -0,0 +1,37 @@ +"""Private helper routines. +""" + +from ._flags import get_array_api_strict_flags +from ._dtypes import _dtype_categories + +_py_scalars = (bool, int, float, complex) + + +def _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name): + + flags = get_array_api_strict_flags() + if flags["api_version"] < "2024.12": + # scalars will fail at the call site + return x1, x2 + + _allowed_dtypes = _dtype_categories[dtype_category] + + if isinstance(x1, _py_scalars): + if isinstance(x2, _py_scalars): + raise TypeError(f"Two scalars not allowed, got {type(x1) = } and {type(x2) =}") + # x2 must be an array + if x2.dtype not in _allowed_dtypes: + raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. Got {x2.dtype}.") + x1 = x2._promote_scalar(x1) + + elif isinstance(x2, _py_scalars): + # x1 must be an array + if x1.dtype not in _allowed_dtypes: + raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. Got {x1.dtype}.") + x2 = x1._promote_scalar(x2) + else: + if x1.dtype not in _allowed_dtypes or x2.dtype not in _allowed_dtypes: + raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. " + f"Got {x1.dtype} and {x2.dtype}.") + return x1, x2 + diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 8f185f0..4535d99 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -96,12 +96,60 @@ def test_promoted_scalar_inherits_device(): assert y.device == device1 + +BIG_INT = int(1e30) + +def _check_op_array_scalar(dtypes, a, s, func, func_name, BIG_INT=BIG_INT): + # Test array op scalar. From the spec, the following combinations + # are supported: + + # - Python bool for a bool array dtype, + # - a Python int within the bounds of the given dtype for integer array dtypes, + # - a Python int or float for real floating-point array dtypes + # - a Python int, float, or complex for complex floating-point array dtypes + + if ((dtypes == "all" + or dtypes == "numeric" and a.dtype in _numeric_dtypes + or dtypes == "real numeric" and a.dtype in _real_numeric_dtypes + or dtypes == "integer" and a.dtype in _integer_dtypes + or dtypes == "integer or boolean" and a.dtype in _integer_or_boolean_dtypes + or dtypes == "boolean" and a.dtype in _boolean_dtypes + or dtypes == "floating-point" and a.dtype in _floating_dtypes + or dtypes == "real floating-point" and a.dtype in _real_floating_dtypes + ) + # bool is a subtype of int, which is why we avoid + # isinstance here. + and (a.dtype in _boolean_dtypes and type(s) == bool + or a.dtype in _integer_dtypes and type(s) == int + or a.dtype in _real_floating_dtypes and type(s) in [float, int] + or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int] + )): + if a.dtype in _integer_dtypes and s == BIG_INT: + with assert_raises(OverflowError): + func(s) + return False + + else: + # Only test for no error + with suppress_warnings() as sup: + # ignore warnings from pow(BIG_INT) + sup.filter(RuntimeWarning, + "invalid value encountered in power") + func(s) + return True + + else: + with assert_raises(TypeError): + func(s) + return False + + def test_operators(): # For every operator, we test that it works for the required type # combinations and raises TypeError otherwise binary_op_dtypes = { "__add__": "numeric", - "__and__": "integer_or_boolean", + "__and__": "integer or boolean", "__eq__": "all", "__floordiv__": "real numeric", "__ge__": "real numeric", @@ -112,12 +160,12 @@ def test_operators(): "__mod__": "real numeric", "__mul__": "numeric", "__ne__": "all", - "__or__": "integer_or_boolean", + "__or__": "integer or boolean", "__pow__": "numeric", "__rshift__": "integer", "__sub__": "numeric", - "__truediv__": "floating", - "__xor__": "integer_or_boolean", + "__truediv__": "floating-point", + "__xor__": "integer or boolean", } # Recompute each time because of in-place ops def _array_vals(): @@ -128,8 +176,6 @@ def _array_vals(): for d in _floating_dtypes: yield asarray(1.0, dtype=d) - - BIG_INT = int(1e30) for op, dtypes in binary_op_dtypes.items(): ops = [op] if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]: @@ -139,40 +185,7 @@ def _array_vals(): for s in [1, 1.0, 1j, BIG_INT, False]: for _op in ops: for a in _array_vals(): - # Test array op scalar. From the spec, the following combinations - # are supported: - - # - Python bool for a bool array dtype, - # - a Python int within the bounds of the given dtype for integer array dtypes, - # - a Python int or float for real floating-point array dtypes - # - a Python int, float, or complex for complex floating-point array dtypes - - if ((dtypes == "all" - or dtypes == "numeric" and a.dtype in _numeric_dtypes - or dtypes == "real numeric" and a.dtype in _real_numeric_dtypes - or dtypes == "integer" and a.dtype in _integer_dtypes - or dtypes == "integer_or_boolean" and a.dtype in _integer_or_boolean_dtypes - or dtypes == "boolean" and a.dtype in _boolean_dtypes - or dtypes == "floating" and a.dtype in _floating_dtypes - ) - # bool is a subtype of int, which is why we avoid - # isinstance here. - and (a.dtype in _boolean_dtypes and type(s) == bool - or a.dtype in _integer_dtypes and type(s) == int - or a.dtype in _real_floating_dtypes and type(s) in [float, int] - or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int] - )): - if a.dtype in _integer_dtypes and s == BIG_INT: - assert_raises(OverflowError, lambda: getattr(a, _op)(s)) - else: - # Only test for no error - with suppress_warnings() as sup: - # ignore warnings from pow(BIG_INT) - sup.filter(RuntimeWarning, - "invalid value encountered in power") - getattr(a, _op)(s) - else: - assert_raises(TypeError, lambda: getattr(a, _op)(s)) + _check_op_array_scalar(dtypes, a, s, getattr(a, _op), _op) # Test array op array. for _op in ops: @@ -203,10 +216,10 @@ def _array_vals(): or (dtypes == "real numeric" and x.dtype in _real_numeric_dtypes and y.dtype in _real_numeric_dtypes) or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _integer_dtypes - or dtypes == "integer_or_boolean" and (x.dtype in _integer_dtypes and y.dtype in _integer_dtypes + or dtypes == "integer or boolean" and (x.dtype in _integer_dtypes and y.dtype in _integer_dtypes or x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes) or dtypes == "boolean" and x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes - or dtypes == "floating" and x.dtype in _floating_dtypes and y.dtype in _floating_dtypes + or dtypes == "floating-point" and x.dtype in _floating_dtypes and y.dtype in _floating_dtypes ): getattr(x, _op)(y) else: @@ -214,7 +227,7 @@ def _array_vals(): unary_op_dtypes = { "__abs__": "numeric", - "__invert__": "integer_or_boolean", + "__invert__": "integer or boolean", "__neg__": "numeric", "__pos__": "numeric", } @@ -223,7 +236,7 @@ def _array_vals(): if ( dtypes == "numeric" and a.dtype in _numeric_dtypes - or dtypes == "integer_or_boolean" + or dtypes == "integer or boolean" and a.dtype in _integer_or_boolean_dtypes ): # Only test for no error diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 4e1b9cc..0b90f0b 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -1,6 +1,7 @@ from inspect import signature, getmodule -from numpy.testing import assert_raises +from pytest import raises as assert_raises +from numpy.testing import suppress_warnings import pytest @@ -19,6 +20,8 @@ ) from .._flags import set_array_api_strict_flags +from .test_array_object import _check_op_array_scalar, BIG_INT + import array_api_strict @@ -120,6 +123,7 @@ def test_missing_functions(): # Ensure the above dictionary is complete. import array_api_strict._elementwise_functions as mod mod_funcs = [n for n in dir(mod) if getmodule(getattr(mod, n)) is mod] + mod_funcs = [n for n in mod_funcs if not n.startswith("_")] assert set(mod_funcs) == set(elementwise_function_input_types) @@ -202,3 +206,51 @@ def test_bitwise_shift_error(): assert_raises( ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1])) ) + + + +def test_scalars(): + # mirror test_array_object.py::test_operators() + # + # Also check that binary functions accept (array, scalar) and (scalar, array) + # arguments, and reject (scalar, scalar) arguments. + + # Use the latest version of the standard so that scalars are actually allowed + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version="2024.12") + + def _array_vals(): + for d in _integer_dtypes: + yield asarray(1, dtype=d) + for d in _boolean_dtypes: + yield asarray(False, dtype=d) + for d in _floating_dtypes: + yield asarray(1.0, dtype=d) + + + for func_name, dtypes in elementwise_function_input_types.items(): + func = getattr(_elementwise_functions, func_name) + if nargs(func) != 2: + continue + + for s in [1, 1.0, 1j, BIG_INT, False]: + for a in _array_vals(): + for func1 in [lambda s: func(a, s), lambda s: func(s, a)]: + allowed = _check_op_array_scalar(dtypes, a, s, func1, func_name) + + # only check `func(array, scalar) == `func(array, array)` if + # the former is legal under the promotion rules + if allowed: + conv_scalar = asarray(s, dtype=a.dtype) + + with suppress_warnings() as sup: + # ignore warnings from pow(BIG_INT) + sup.filter(RuntimeWarning, + "invalid value encountered in power") + assert func(s, a) == func(conv_scalar, a) + assert func(a, s) == func(a, conv_scalar) + + with assert_raises(TypeError): + func(s, s) + +