From ea96e9b2dc186eacaec33df67c68d716f17bfb18 Mon Sep 17 00:00:00 2001
From: Evgeni Burovski <evgeny.burovskiy@gmail.com>
Date: Fri, 22 Nov 2024 11:07:40 +0200
Subject: [PATCH] ENH: allow python scalars in binary elementwise functions

Allow func(array, scalar) and func(scalar, array), raise on
func(scalar, scalar) if API_VERSION>=2024.12

cross-ref https://github.com/data-apis/array-api/issues/807

To make sure it is all uniform,

1. Generate all binary "ufuncs" in a uniform way, with a decorator
2. Make binary "ufuncs" follow the same logic of the binary operators
3. Reuse the test loop of Array.__binop__ for binary "ufuncs"
4. (minor) in tests, reuse canonical names for dtype categories
   ("integer or boolean" vs "integer_or_boolean")
---
 array_api_strict/_array_object.py             |   2 +
 array_api_strict/_elementwise_functions.py    | 584 ++++--------------
 array_api_strict/_helpers.py                  |  37 ++
 array_api_strict/tests/test_array_object.py   | 101 +--
 .../tests/test_elementwise_functions.py       |  54 +-
 5 files changed, 266 insertions(+), 512 deletions(-)
 create mode 100644 array_api_strict/_helpers.py

diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py
index a917441..47153e5 100644
--- a/array_api_strict/_array_object.py
+++ b/array_api_strict/_array_object.py
@@ -230,6 +230,8 @@ def _check_device(self, other):
         elif isinstance(other, Array):
             if self.device != other.device:
                 raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.")
+        else:
+            raise TypeError(f"Cannot combine an Array with {type(other)}.")
 
     # Helper function to match the type promotion rules in the spec
     def _promote_scalar(self, scalar):
diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py
index 7c64f67..3c4b3d8 100644
--- a/array_api_strict/_elementwise_functions.py
+++ b/array_api_strict/_elementwise_functions.py
@@ -10,17 +10,133 @@
     _real_numeric_dtypes,
     _numeric_dtypes,
     _result_type,
+    _dtype_categories as _dtype_dtype_categories,
 )
 from ._array_object import Array
 from ._flags import requires_api_version
 from ._creation_functions import asarray
 from ._data_type_functions import broadcast_to, iinfo
+from ._helpers import _maybe_normalize_py_scalars
 
 from typing import Optional, Union
 
 import numpy as np
 
 
+def _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func):
+    """Base implementation of a binary function, `func_name`, defined for
+       dtypes from `dtype_category`
+    """
+    x1, x2 = _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name)
+
+    if x1.device != x2.device:
+        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
+    # Call result type here just to raise on disallowed type combinations
+    _result_type(x1.dtype, x2.dtype)
+    x1, x2 = Array._normalize_two_args(x1, x2)
+    return Array._new(np_func(x1._array, x2._array), device=x1.device)
+
+
+_binary_docstring_template=\
+"""
+Array API compatible wrapper for :py:func:`np.%s <numpy.%s>`.
+
+See its docstring for more information.
+"""
+
+
+def create_binary_func(func_name, dtype_category, np_func):
+    def inner(x1: Array, x2: Array, /) -> Array:
+        return _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func)
+    return inner
+
+
+# func_name: dtype_category (must match that from _dtypes.py)
+_binary_funcs = {
+    "add": "numeric",
+    "atan2": "real floating-point",
+    "bitwise_and": "integer or boolean",
+    "bitwise_or": "integer or boolean",
+    "bitwise_xor": "integer or boolean",
+    "_bitwise_left_shift": "integer",  # leading underscore deliberate
+    "_bitwise_right_shift": "integer",
+    # XXX: copysign: real fp or numeric? 
+    "copysign": "real floating-point",
+    "divide": "floating-point",
+    "equal": "all",
+    "greater": "real numeric",
+    "greater_equal": "real numeric",
+    "less": "real numeric",
+    "less_equal": "real numeric",
+    "not_equal": "all",
+    "floor_divide": "real numeric",
+    "hypot": "real floating-point",
+    "logaddexp": "real floating-point",
+    "logical_and": "boolean",
+    "logical_or": "boolean",
+    "logical_xor": "boolean",
+    "maximum": "real numeric",
+    "minimum": "real numeric",
+    "multiply": "numeric",
+    "nextafter": "real floating-point",
+    "pow": "numeric",
+    "remainder": "real numeric",
+    "subtract": "numeric",
+}
+
+
+# map array-api-name : numpy-name
+_numpy_renames = {
+    "atan2": "arctan2",
+    "_bitwise_left_shift": "left_shift",
+    "_bitwise_right_shift": "right_shift",
+    "pow": "power"
+}
+
+
+# create and attach functions to the module
+for func_name, dtype_category in _binary_funcs.items():
+    # sanity check
+    assert dtype_category in _dtype_dtype_categories
+
+    numpy_name = _numpy_renames.get(func_name, func_name)
+    np_func = getattr(np, numpy_name)
+
+    func = create_binary_func(func_name, dtype_category, np_func)
+    func.__name__ = func_name
+
+    func.__doc__ = _binary_docstring_template % (numpy_name, numpy_name)
+
+    vars()[func_name] = func
+
+
+copysign = requires_api_version('2023.12')(copysign)  # noqa: F821
+hypot = requires_api_version('2023.12')(hypot)  # noqa: F821
+maximum = requires_api_version('2023.12')(maximum)  # noqa: F821
+minimum = requires_api_version('2023.12')(minimum)  # noqa: F821
+nextafter = requires_api_version('2024.12')(nextafter)  # noqa: F821
+
+
+def bitwise_left_shift(x1: Array, x2: Array, /) -> Array:
+    is_negative = np.any(x2._array < 0) if isinstance(x2, Array) else x2 < 0
+    if is_negative:
+        raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0")
+    return _bitwise_left_shift(x1, x2)   # noqa: F821
+bitwise_left_shift.__doc__ = _bitwise_left_shift.__doc__   # noqa: F821
+
+
+def bitwise_right_shift(x1: Array, x2: Array, /) -> Array:
+    is_negative = np.any(x2._array < 0) if isinstance(x2, Array) else x2 < 0
+    if is_negative:
+        raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0")
+    return _bitwise_right_shift(x1, x2)   # noqa: F821
+bitwise_right_shift.__doc__ = _bitwise_right_shift.__doc__   # noqa: F821
+
+
+# clean up to not pollute the namespace
+del func, create_binary_func
+
+
 def abs(x: Array, /) -> Array:
     """
     Array API compatible wrapper for :py:func:`np.abs <numpy.abs>`.
@@ -56,23 +172,6 @@ def acosh(x: Array, /) -> Array:
     return Array._new(np.arccosh(x._array), device=x.device)
 
 
-def add(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.add <numpy.add>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-
-    if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
-        raise TypeError("Only numeric dtypes are allowed in add")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.add(x1._array, x2._array), device=x1.device)
-
-
 # Note: the function name is different here
 def asin(x: Array, /) -> Array:
     """
@@ -109,23 +208,6 @@ def atan(x: Array, /) -> Array:
     return Array._new(np.arctan(x._array), device=x.device)
 
 
-# Note: the function name is different here
-def atan2(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.arctan2 <numpy.arctan2>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
-        raise TypeError("Only real floating-point dtypes are allowed in atan2")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.arctan2(x1._array, x2._array), device=x1.device)
-
-
 # Note: the function name is different here
 def atanh(x: Array, /) -> Array:
     """
@@ -138,47 +220,6 @@ def atanh(x: Array, /) -> Array:
     return Array._new(np.arctanh(x._array), device=x.device)
 
 
-def bitwise_and(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.bitwise_and <numpy.bitwise_and>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-
-    if (
-        x1.dtype not in _integer_or_boolean_dtypes
-        or x2.dtype not in _integer_or_boolean_dtypes
-    ):
-        raise TypeError("Only integer or boolean dtypes are allowed in bitwise_and")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.bitwise_and(x1._array, x2._array), device=x1.device)
-
-
-# Note: the function name is different here
-def bitwise_left_shift(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.left_shift <numpy.left_shift>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-
-    if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes:
-        raise TypeError("Only integer dtypes are allowed in bitwise_left_shift")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    # Note: bitwise_left_shift is only defined for x2 nonnegative.
-    if np.any(x2._array < 0):
-        raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0")
-    return Array._new(np.left_shift(x1._array, x2._array), device=x1.device)
-
-
 # Note: the function name is different here
 def bitwise_invert(x: Array, /) -> Array:
     """
@@ -191,67 +232,6 @@ def bitwise_invert(x: Array, /) -> Array:
     return Array._new(np.invert(x._array), device=x.device)
 
 
-def bitwise_or(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.bitwise_or <numpy.bitwise_or>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-
-    if (
-        x1.dtype not in _integer_or_boolean_dtypes
-        or x2.dtype not in _integer_or_boolean_dtypes
-    ):
-        raise TypeError("Only integer or boolean dtypes are allowed in bitwise_or")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.bitwise_or(x1._array, x2._array), device=x1.device)
-
-
-# Note: the function name is different here
-def bitwise_right_shift(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.right_shift <numpy.right_shift>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-
-    if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes:
-        raise TypeError("Only integer dtypes are allowed in bitwise_right_shift")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    # Note: bitwise_right_shift is only defined for x2 nonnegative.
-    if np.any(x2._array < 0):
-        raise ValueError("bitwise_right_shift(x1, x2) is only defined for x2 >= 0")
-    return Array._new(np.right_shift(x1._array, x2._array), device=x1.device)
-
-
-def bitwise_xor(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.bitwise_xor <numpy.bitwise_xor>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-
-    if (
-        x1.dtype not in _integer_or_boolean_dtypes
-        or x2.dtype not in _integer_or_boolean_dtypes
-    ):
-        raise TypeError("Only integer or boolean dtypes are allowed in bitwise_xor")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.bitwise_xor(x1._array, x2._array), device=x1.device)
-
-
 def ceil(x: Array, /) -> Array:
     """
     Array API compatible wrapper for :py:func:`np.ceil <numpy.ceil>`.
@@ -372,6 +352,7 @@ def _isscalar(a):
         out[ib] = b[ib]
     return Array._new(out, device=device)
 
+
 def conj(x: Array, /) -> Array:
     """
     Array API compatible wrapper for :py:func:`np.conj <numpy.conj>`.
@@ -382,22 +363,6 @@ def conj(x: Array, /) -> Array:
         raise TypeError("Only complex floating-point dtypes are allowed in conj")
     return Array._new(np.conj(x._array), device=x.device)
 
-@requires_api_version('2023.12')
-def copysign(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.copysign <numpy.copysign>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-
-    if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
-        raise TypeError("Only real numeric dtypes are allowed in copysign")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.copysign(x1._array, x2._array), device=x1.device)
 
 def cos(x: Array, /) -> Array:
     """
@@ -421,36 +386,6 @@ def cosh(x: Array, /) -> Array:
     return Array._new(np.cosh(x._array), device=x.device)
 
 
-def divide(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.divide <numpy.divide>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
-        raise TypeError("Only floating-point dtypes are allowed in divide")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.divide(x1._array, x2._array), device=x1.device)
-
-
-def equal(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.equal <numpy.equal>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.equal(x1._array, x2._array), device=x1.device)
-
-
 def exp(x: Array, /) -> Array:
     """
     Array API compatible wrapper for :py:func:`np.exp <numpy.exp>`.
@@ -487,69 +422,6 @@ def floor(x: Array, /) -> Array:
     return Array._new(np.floor(x._array), device=x.device)
 
 
-def floor_divide(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.floor_divide <numpy.floor_divide>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
-        raise TypeError("Only real numeric dtypes are allowed in floor_divide")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.floor_divide(x1._array, x2._array), device=x1.device)
-
-
-def greater(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.greater <numpy.greater>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
-        raise TypeError("Only real numeric dtypes are allowed in greater")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.greater(x1._array, x2._array), device=x1.device)
-
-
-def greater_equal(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.greater_equal <numpy.greater_equal>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
-        raise TypeError("Only real numeric dtypes are allowed in greater_equal")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.greater_equal(x1._array, x2._array), device=x1.device)
-
-@requires_api_version('2023.12')
-def hypot(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.hypot <numpy.hypot>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
-        raise TypeError("Only real floating-point dtypes are allowed in hypot")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.hypot(x1._array, x2._array), device=x1.device)
-
 def imag(x: Array, /) -> Array:
     """
     Array API compatible wrapper for :py:func:`np.imag <numpy.imag>`.
@@ -594,38 +466,6 @@ def isnan(x: Array, /) -> Array:
     return Array._new(np.isnan(x._array), device=x.device)
 
 
-def less(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.less <numpy.less>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
-        raise TypeError("Only real numeric dtypes are allowed in less")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.less(x1._array, x2._array), device=x1.device)
-
-
-def less_equal(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.less_equal <numpy.less_equal>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
-        raise TypeError("Only real numeric dtypes are allowed in less_equal")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.less_equal(x1._array, x2._array), device=x1.device)
-
-
 def log(x: Array, /) -> Array:
     """
     Array API compatible wrapper for :py:func:`np.log <numpy.log>`.
@@ -670,38 +510,6 @@ def log10(x: Array, /) -> Array:
     return Array._new(np.log10(x._array), device=x.device)
 
 
-def logaddexp(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.logaddexp <numpy.logaddexp>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
-        raise TypeError("Only real floating-point dtypes are allowed in logaddexp")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.logaddexp(x1._array, x2._array), device=x1.device)
-
-
-def logical_and(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.logical_and <numpy.logical_and>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
-        raise TypeError("Only boolean dtypes are allowed in logical_and")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.logical_and(x1._array, x2._array), device=x1.device)
-
-
 def logical_not(x: Array, /) -> Array:
     """
     Array API compatible wrapper for :py:func:`np.logical_not <numpy.logical_not>`.
@@ -713,87 +521,6 @@ def logical_not(x: Array, /) -> Array:
     return Array._new(np.logical_not(x._array), device=x.device)
 
 
-def logical_or(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.logical_or <numpy.logical_or>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
-        raise TypeError("Only boolean dtypes are allowed in logical_or")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.logical_or(x1._array, x2._array), device=x1.device)
-
-
-def logical_xor(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.logical_xor <numpy.logical_xor>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
-        raise TypeError("Only boolean dtypes are allowed in logical_xor")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.logical_xor(x1._array, x2._array), device=x1.device)
-
-@requires_api_version('2023.12')
-def maximum(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.maximum <numpy.maximum>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
-        raise TypeError("Only real numeric dtypes are allowed in maximum")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    # TODO: maximum(-0., 0.) is unspecified. Should we issue a warning/error
-    # in that case?
-    return Array._new(np.maximum(x1._array, x2._array), device=x1.device)
-
-@requires_api_version('2023.12')
-def minimum(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.minimum <numpy.minimum>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
-        raise TypeError("Only real numeric dtypes are allowed in minimum")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.minimum(x1._array, x2._array), device=x1.device)
-
-def multiply(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.multiply <numpy.multiply>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
-        raise TypeError("Only numeric dtypes are allowed in multiply")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.multiply(x1._array, x2._array), device=x1.device)
-
-
 def negative(x: Array, /) -> Array:
     """
     Array API compatible wrapper for :py:func:`np.negative <numpy.negative>`.
@@ -805,34 +532,6 @@ def negative(x: Array, /) -> Array:
     return Array._new(np.negative(x._array), device=x.device)
 
 
-@requires_api_version('2024.12')
-def nextafter(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.nextafter <numpy.nextafter>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
-        raise TypeError("Only real floating-point dtypes are allowed in nextafter")
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.nextafter(x1._array, x2._array), device=x1.device)
-
-def not_equal(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.not_equal <numpy.not_equal>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.not_equal(x1._array, x2._array), device=x1.device)
-
-
 def positive(x: Array, /) -> Array:
     """
     Array API compatible wrapper for :py:func:`np.positive <numpy.positive>`.
@@ -844,23 +543,6 @@ def positive(x: Array, /) -> Array:
     return Array._new(np.positive(x._array), device=x.device)
 
 
-# Note: the function name is different here
-def pow(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.power <numpy.power>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
-        raise TypeError("Only numeric dtypes are allowed in pow")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.power(x1._array, x2._array), device=x1.device)
-
-
 def real(x: Array, /) -> Array:
     """
     Array API compatible wrapper for :py:func:`np.real <numpy.real>`.
@@ -883,22 +565,6 @@ def reciprocal(x: Array, /) -> Array:
         raise TypeError("Only floating-point dtypes are allowed in reciprocal")
     return Array._new(np.reciprocal(x._array), device=x.device)
 
-def remainder(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.remainder <numpy.remainder>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
-        raise TypeError("Only real numeric dtypes are allowed in remainder")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.remainder(x1._array, x2._array), device=x1.device)
-
-
 def round(x: Array, /) -> Array:
     """
     Array API compatible wrapper for :py:func:`np.round <numpy.round>`.
@@ -979,22 +645,6 @@ def sqrt(x: Array, /) -> Array:
     return Array._new(np.sqrt(x._array), device=x.device)
 
 
-def subtract(x1: Array, x2: Array, /) -> Array:
-    """
-    Array API compatible wrapper for :py:func:`np.subtract <numpy.subtract>`.
-
-    See its docstring for more information.
-    """
-    if x1.device != x2.device:
-        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
-    if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
-        raise TypeError("Only numeric dtypes are allowed in subtract")
-    # Call result type here just to raise on disallowed type combinations
-    _result_type(x1.dtype, x2.dtype)
-    x1, x2 = Array._normalize_two_args(x1, x2)
-    return Array._new(np.subtract(x1._array, x2._array), device=x1.device)
-
-
 def tan(x: Array, /) -> Array:
     """
     Array API compatible wrapper for :py:func:`np.tan <numpy.tan>`.
diff --git a/array_api_strict/_helpers.py b/array_api_strict/_helpers.py
new file mode 100644
index 0000000..2258d29
--- /dev/null
+++ b/array_api_strict/_helpers.py
@@ -0,0 +1,37 @@
+"""Private helper routines.
+"""
+
+from ._flags import get_array_api_strict_flags
+from ._dtypes import _dtype_categories
+
+_py_scalars = (bool, int, float, complex)
+
+
+def _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name):
+
+    flags = get_array_api_strict_flags()
+    if flags["api_version"] < "2024.12":
+        # scalars will fail at the call site
+        return x1, x2
+
+    _allowed_dtypes = _dtype_categories[dtype_category]
+
+    if isinstance(x1, _py_scalars):
+        if isinstance(x2, _py_scalars):
+            raise TypeError(f"Two scalars not allowed, got {type(x1) = } and {type(x2) =}")
+        # x2 must be an array
+        if x2.dtype not in _allowed_dtypes:
+            raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. Got {x2.dtype}.")
+        x1 = x2._promote_scalar(x1)
+
+    elif isinstance(x2, _py_scalars):
+        # x1 must be an array
+        if x1.dtype not in _allowed_dtypes:
+            raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. Got {x1.dtype}.")
+        x2 = x1._promote_scalar(x2)
+    else:
+        if x1.dtype not in _allowed_dtypes or x2.dtype not in _allowed_dtypes:
+            raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. "
+                            f"Got {x1.dtype} and {x2.dtype}.")
+    return x1, x2
+
diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py
index 8f185f0..4535d99 100644
--- a/array_api_strict/tests/test_array_object.py
+++ b/array_api_strict/tests/test_array_object.py
@@ -96,12 +96,60 @@ def test_promoted_scalar_inherits_device():
 
     assert y.device == device1
 
+
+BIG_INT = int(1e30)
+
+def _check_op_array_scalar(dtypes, a, s, func, func_name, BIG_INT=BIG_INT):
+    # Test array op scalar. From the spec, the following combinations
+    # are supported:
+
+    # - Python bool for a bool array dtype,
+    # - a Python int within the bounds of the given dtype for integer array dtypes,
+    # - a Python int or float for real floating-point array dtypes
+    # - a Python int, float, or complex for complex floating-point array dtypes
+
+    if ((dtypes == "all"
+         or dtypes == "numeric" and a.dtype in _numeric_dtypes
+         or dtypes == "real numeric" and a.dtype in _real_numeric_dtypes
+         or dtypes == "integer" and a.dtype in _integer_dtypes
+         or dtypes == "integer or boolean" and a.dtype in _integer_or_boolean_dtypes
+         or dtypes == "boolean" and a.dtype in _boolean_dtypes
+         or dtypes == "floating-point" and a.dtype in _floating_dtypes
+         or dtypes == "real floating-point" and a.dtype in _real_floating_dtypes
+        )
+        # bool is a subtype of int, which is why we avoid
+        # isinstance here.
+        and (a.dtype in _boolean_dtypes and type(s) == bool
+             or a.dtype in _integer_dtypes and type(s) == int
+             or a.dtype in _real_floating_dtypes and type(s) in [float, int]
+             or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int]
+        )):
+        if a.dtype in _integer_dtypes and s == BIG_INT:
+            with assert_raises(OverflowError):
+                func(s)
+            return False
+
+        else:
+            # Only test for no error
+            with suppress_warnings() as sup:
+                # ignore warnings from pow(BIG_INT)
+                sup.filter(RuntimeWarning,
+                           "invalid value encountered in power")
+                func(s)
+            return True
+
+    else:
+        with assert_raises(TypeError):
+            func(s)
+        return False
+
+
 def test_operators():
     # For every operator, we test that it works for the required type
     # combinations and raises TypeError otherwise
     binary_op_dtypes = {
         "__add__": "numeric",
-        "__and__": "integer_or_boolean",
+        "__and__": "integer or boolean",
         "__eq__": "all",
         "__floordiv__": "real numeric",
         "__ge__": "real numeric",
@@ -112,12 +160,12 @@ def test_operators():
         "__mod__": "real numeric",
         "__mul__": "numeric",
         "__ne__": "all",
-        "__or__": "integer_or_boolean",
+        "__or__": "integer or boolean",
         "__pow__": "numeric",
         "__rshift__": "integer",
         "__sub__": "numeric",
-        "__truediv__": "floating",
-        "__xor__": "integer_or_boolean",
+        "__truediv__": "floating-point",
+        "__xor__": "integer or boolean",
     }
     # Recompute each time because of in-place ops
     def _array_vals():
@@ -128,8 +176,6 @@ def _array_vals():
         for d in _floating_dtypes:
             yield asarray(1.0, dtype=d)
 
-
-    BIG_INT = int(1e30)
     for op, dtypes in binary_op_dtypes.items():
         ops = [op]
         if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]:
@@ -139,40 +185,7 @@ def _array_vals():
         for s in [1, 1.0, 1j, BIG_INT, False]:
             for _op in ops:
                 for a in _array_vals():
-                    # Test array op scalar. From the spec, the following combinations
-                    # are supported:
-
-                    # - Python bool for a bool array dtype,
-                    # - a Python int within the bounds of the given dtype for integer array dtypes,
-                    # - a Python int or float for real floating-point array dtypes
-                    # - a Python int, float, or complex for complex floating-point array dtypes
-
-                    if ((dtypes == "all"
-                         or dtypes == "numeric" and a.dtype in _numeric_dtypes
-                         or dtypes == "real numeric" and a.dtype in _real_numeric_dtypes
-                         or dtypes == "integer" and a.dtype in _integer_dtypes
-                         or dtypes == "integer_or_boolean" and a.dtype in _integer_or_boolean_dtypes
-                         or dtypes == "boolean" and a.dtype in _boolean_dtypes
-                         or dtypes == "floating" and a.dtype in _floating_dtypes
-                        )
-                        # bool is a subtype of int, which is why we avoid
-                        # isinstance here.
-                        and (a.dtype in _boolean_dtypes and type(s) == bool
-                             or a.dtype in _integer_dtypes and type(s) == int
-                             or a.dtype in _real_floating_dtypes and type(s) in [float, int]
-                             or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int]
-                        )):
-                        if a.dtype in _integer_dtypes and s == BIG_INT:
-                            assert_raises(OverflowError, lambda: getattr(a, _op)(s))
-                        else:
-                            # Only test for no error
-                            with suppress_warnings() as sup:
-                                # ignore warnings from pow(BIG_INT)
-                                sup.filter(RuntimeWarning,
-                                           "invalid value encountered in power")
-                                getattr(a, _op)(s)
-                    else:
-                        assert_raises(TypeError, lambda: getattr(a, _op)(s))
+                    _check_op_array_scalar(dtypes, a, s, getattr(a, _op), _op)
 
                 # Test array op array.
                 for _op in ops:
@@ -203,10 +216,10 @@ def _array_vals():
                                 or (dtypes == "real numeric" and x.dtype in _real_numeric_dtypes and y.dtype in _real_numeric_dtypes)
                                 or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
                                 or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _integer_dtypes
-                                or dtypes == "integer_or_boolean" and (x.dtype in _integer_dtypes and y.dtype in _integer_dtypes
+                                or dtypes == "integer or boolean" and (x.dtype in _integer_dtypes and y.dtype in _integer_dtypes
                                                                        or x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes)
                                 or dtypes == "boolean" and x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes
-                                or dtypes == "floating" and x.dtype in _floating_dtypes and y.dtype in _floating_dtypes
+                                or dtypes == "floating-point" and x.dtype in _floating_dtypes and y.dtype in _floating_dtypes
                             ):
                                 getattr(x, _op)(y)
                             else:
@@ -214,7 +227,7 @@ def _array_vals():
 
     unary_op_dtypes = {
         "__abs__": "numeric",
-        "__invert__": "integer_or_boolean",
+        "__invert__": "integer or boolean",
         "__neg__": "numeric",
         "__pos__": "numeric",
     }
@@ -223,7 +236,7 @@ def _array_vals():
             if (
                 dtypes == "numeric"
                 and a.dtype in _numeric_dtypes
-                or dtypes == "integer_or_boolean"
+                or dtypes == "integer or boolean"
                 and a.dtype in _integer_or_boolean_dtypes
             ):
                 # Only test for no error
diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py
index 4e1b9cc..0b90f0b 100644
--- a/array_api_strict/tests/test_elementwise_functions.py
+++ b/array_api_strict/tests/test_elementwise_functions.py
@@ -1,6 +1,7 @@
 from inspect import signature, getmodule
 
-from numpy.testing import assert_raises
+from pytest import raises as assert_raises
+from numpy.testing import suppress_warnings
 
 import pytest
 
@@ -19,6 +20,8 @@
 )
 from .._flags import set_array_api_strict_flags
 
+from .test_array_object import _check_op_array_scalar, BIG_INT
+
 import array_api_strict
 
 
@@ -120,6 +123,7 @@ def test_missing_functions():
     # Ensure the above dictionary is complete.
     import array_api_strict._elementwise_functions as mod
     mod_funcs = [n for n in dir(mod) if getmodule(getattr(mod, n)) is mod]
+    mod_funcs = [n for n in mod_funcs if not n.startswith("_")]
     assert set(mod_funcs) == set(elementwise_function_input_types)
 
 
@@ -202,3 +206,51 @@ def test_bitwise_shift_error():
     assert_raises(
         ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1]))
     )
+
+
+
+def test_scalars():
+    # mirror test_array_object.py::test_operators()
+    #
+    # Also check that binary functions accept (array, scalar) and (scalar, array)
+    # arguments, and reject (scalar, scalar) arguments.
+
+    # Use the latest version of the standard so that scalars are actually allowed
+    with pytest.warns(UserWarning):
+        set_array_api_strict_flags(api_version="2024.12")
+
+    def _array_vals():
+        for d in _integer_dtypes:
+            yield asarray(1, dtype=d)
+        for d in _boolean_dtypes:
+            yield asarray(False, dtype=d)
+        for d in _floating_dtypes:
+            yield asarray(1.0, dtype=d)
+
+
+    for func_name, dtypes in elementwise_function_input_types.items():
+        func = getattr(_elementwise_functions, func_name)
+        if nargs(func) != 2:
+            continue
+
+        for s in [1, 1.0, 1j, BIG_INT, False]:
+            for a in _array_vals():
+                for func1 in [lambda s: func(a, s), lambda s: func(s, a)]:
+                    allowed = _check_op_array_scalar(dtypes, a, s, func1, func_name)
+
+                    # only check `func(array, scalar) == `func(array, array)` if
+                    # the former is legal under the promotion rules
+                    if allowed:
+                        conv_scalar = asarray(s, dtype=a.dtype)
+
+                        with suppress_warnings() as sup:
+                            # ignore warnings from pow(BIG_INT)
+                            sup.filter(RuntimeWarning,
+                                       "invalid value encountered in power")
+                            assert func(s, a) == func(conv_scalar, a)
+                            assert func(a, s) == func(a, conv_scalar)
+
+                        with assert_raises(TypeError):
+                            func(s, s)
+
+