Skip to content

Commit

Permalink
Only require axis to be negative in vecdot and cross
Browse files Browse the repository at this point in the history
Nonnegative axes and negative axes less than the smaller of the two arrays are
unspecified.

This is because it is ambiguous in these cases whether the
dimension should refer to the axis before or after broadcasting. Preciously,
the spec stated it should refer to the dimension before broadcasting, but this
deviates from NumPy gufunc behavior, and results in ambiguous and confusing
situations, where, for instance, the result of a the function is different
when the inputs are manually broadcasted together.

Also clean up some of the cross text a little bit since the computed dimension
must be exactly size 3.

Fixes data-apis#724
Fixes data-apis#617

See the discussion in those issues for more details.
  • Loading branch information
asmeurer committed Feb 2, 2024
1 parent 95332bb commit 7728c98
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
9 changes: 4 additions & 5 deletions src/array_api_stubs/_draft/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,15 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
Parameters
----------
x1: array
first input array. Must have a numeric data type.
first input array. Must have a numeric data type. The size of the axis over which the cross product is to be computed must be equal to 3.
x2: array
second input array. Must be compatible with ``x1`` for all non-compute axes (see :ref:`broadcasting`). The size of the axis over which to compute the cross product must be the same size as the respective axis in ``x1``. Must have a numeric data type.
second input array. Must be broadcast compatible with ``x1`` along all axes other than the axis along which the cross-product is computed (see :ref:`broadcasting`). The size of the axis over which the cross product is to be computed must be equal to 3. Must have a numeric data type.
.. note::
The compute axis (dimension) must not be broadcasted.
axis: int
the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the cross product. Must be an integer on the interval ``[-N, N)``, where ``N`` is the rank (number of dimensions) of the shape determined according to :ref:`broadcasting`. If specified as a negative integer, the function must determine the axis along which to compute the cross product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the cross product over the last axis. Default: ``-1``.
the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the cross product. Should be an integer on the interval ``[-N, -1]``, where ``N`` is ``min(x1.ndim, x2.ndim)``. The function must determine the axis along which to compute the cross product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the cross product over the last axis. Default: ``-1``.
Returns
-------
Expand All @@ -110,8 +110,7 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
**Raises**
- if the size of the axis over which to compute the cross product is not equal to ``3``.
- if the size of the axis over which to compute the cross product is not the same (before broadcasting) for both ``x1`` and ``x2``.
- if the size of the axis over which to compute the cross product is not equal to ``3`` (before broadcasting) for both ``x1`` and ``x2``.
"""


Expand Down
2 changes: 1 addition & 1 deletion src/array_api_stubs/_draft/linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
The contracted axis (dimension) must not be broadcasted.
axis: int
axis over which to compute the dot product. Must be an integer on the interval ``[-N, N)``, where ``N`` is the rank (number of dimensions) of the shape determined according to :ref:`broadcasting`. If specified as a negative integer, the function must determine the axis along which to compute the dot product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the dot product over the last axis. Default: ``-1``.
the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the dot product. Should be an integer on the interval ``[-N, -1]``, where ``N`` is ``min(x1.ndim, x2.ndim)``. The function must determine the axis along which to compute the dot product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the dot product over the last axis. Default: ``-1``.
Returns
-------
Expand Down

0 comments on commit 7728c98

Please sign in to comment.