Skip to content

Commit

Permalink
BUG: Fix the implementation of numpy.array_api.vecdot (#21928)
Browse files Browse the repository at this point in the history
* Fix the implementation of numpy.array_api.vecdot

See https://data-apis.org/array-api/latest/API_specification/generated/signatures.linear_algebra_functions.vecdot.html

* Use moveaxis + matmul instead of einsum in vecdot
  • Loading branch information
asmeurer authored and charris committed Sep 7, 2022
1 parent e18dd98 commit 6cc4183
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion numpy/array_api/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,18 @@ def trace(x: Array, /, *, offset: int = 0) -> Array:
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError('Only numeric dtypes are allowed in vecdot')
return tensordot(x1, x2, axes=((axis,), (axis,)))
ndim = max(x1.ndim, x2.ndim)
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
if x1_shape[axis] != x2_shape[axis]:
raise ValueError("x1 and x2 must have the same size along the given axis")

x1_, x2_ = np.broadcast_arrays(x1._array, x2._array)
x1_ = np.moveaxis(x1_, axis, -1)
x2_ = np.moveaxis(x2_, axis, -1)

res = x1_[..., None, :] @ x2_[..., None]
return Array._new(res[..., 0, 0])


# Note: the name here is different from norm(). The array API norm is split
Expand Down

0 comments on commit 6cc4183

Please sign in to comment.