diff --git a/src/array_api_stubs/_draft/linalg.py b/src/array_api_stubs/_draft/linalg.py index d05b53a9f..1e7efa95e 100644 --- a/src/array_api_stubs/_draft/linalg.py +++ b/src/array_api_stubs/_draft/linalg.py @@ -83,15 +83,15 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: Parameters ---------- x1: array - first input array. Must have a numeric data type. + first input array. Must have a numeric data type. The size of the axis over which the cross product is to be computed must be equal to 3. x2: array - second input array. Must be compatible with ``x1`` for all non-compute axes (see :ref:`broadcasting`). The size of the axis over which to compute the cross product must be the same size as the respective axis in ``x1``. Must have a numeric data type. + second input array. Must be broadcast compatible with ``x1`` along all axes other than the axis along which the cross-product is computed (see :ref:`broadcasting`). The size of the axis over which the cross product is to be computed must be equal to 3. Must have a numeric data type. .. note:: The compute axis (dimension) must not be broadcasted. axis: int - the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the cross product. Must be an integer on the interval ``[-N, N)``, where ``N`` is the rank (number of dimensions) of the shape determined according to :ref:`broadcasting`. If specified as a negative integer, the function must determine the axis along which to compute the cross product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the cross product over the last axis. Default: ``-1``. + the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the cross product. Should be an integer on the interval ``[-N, -1]``, where ``N`` is ``min(x1.ndim, x2.ndim)``. The function must determine the axis along which to compute the cross product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the cross product over the last axis. Default: ``-1``. Returns ------- @@ -110,8 +110,7 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: **Raises** - - if the size of the axis over which to compute the cross product is not equal to ``3``. - - if the size of the axis over which to compute the cross product is not the same (before broadcasting) for both ``x1`` and ``x2``. + - if the size of the axis over which to compute the cross product is not equal to ``3`` (before broadcasting) for both ``x1`` and ``x2``. """ diff --git a/src/array_api_stubs/_draft/linear_algebra_functions.py b/src/array_api_stubs/_draft/linear_algebra_functions.py index 96f082bd5..eea898a6b 100644 --- a/src/array_api_stubs/_draft/linear_algebra_functions.py +++ b/src/array_api_stubs/_draft/linear_algebra_functions.py @@ -141,7 +141,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: The contracted axis (dimension) must not be broadcasted. axis: int - axis over which to compute the dot product. Must be an integer on the interval ``[-N, N)``, where ``N`` is the rank (number of dimensions) of the shape determined according to :ref:`broadcasting`. If specified as a negative integer, the function must determine the axis along which to compute the dot product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the dot product over the last axis. Default: ``-1``. + the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the dot product. Should be an integer on the interval ``[-N, -1]``, where ``N`` is ``min(x1.ndim, x2.ndim)``. The function must determine the axis along which to compute the dot product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the dot product over the last axis. Default: ``-1``. Returns -------