Skip to content

Commit

Permalink
Merge pull request #35 from asmeurer/torch-linalg2
Browse files Browse the repository at this point in the history
More fixes for torch linalg extension
  • Loading branch information
asmeurer authored Mar 31, 2023
2 parents b32a5b3 + 05204b3 commit 5d3a92c
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/array-api-tests-numpy-1-21.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Array API Tests (NumPy 1.21)
on: [push, pull_request]

jobs:
array-api-tests-numpy:
array-api-tests-numpy-1-21:
uses: ./.github/workflows/array-api-tests.yml
with:
package-name: numpy
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/array-api-tests-numpy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Array API Tests (NumPy Latest)
on: [push, pull_request]

jobs:
array-api-tests-numpy-1-21:
array-api-tests-numpy-latest:
uses: ./.github/workflows/array-api-tests.yml
with:
package-name: numpy
30 changes: 29 additions & 1 deletion array_api_compat/torch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,34 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return torch_linalg.cross(x1, x2, dim=axis)

__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot']
def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
    """Compute the vector dot product of ``x1`` and ``x2`` along ``axis``.

    Wraps ``torch.linalg.vecdot``, adding a manual fallback for integer
    dtypes, which that function does not accept.
    """
    from ._aliases import isdtype

    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

    # torch.linalg.vecdot doesn't support integer dtypes
    if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
        if kwargs:
            raise RuntimeError("vecdot kwargs not supported for integral dtypes")

        # Left-pad the shorter shape with 1s so the contracted axis can be
        # compared under broadcasting rules before we actually broadcast.
        rank = max(x1.ndim, x2.ndim)
        shape1 = (1,)*(rank - x1.ndim) + tuple(x1.shape)
        shape2 = (1,)*(rank - x2.ndim) + tuple(x2.shape)
        if shape1[axis] != shape2[axis]:
            raise ValueError("x1 and x2 must have the same size along the given axis")

        # Move the contracted axis last, then express the dot product as a
        # batched (1, n) @ (n, 1) matmul and strip the two singleton dims.
        a, b = torch.broadcast_tensors(x1, x2)
        a = torch.moveaxis(a, axis, -1)
        b = torch.moveaxis(b, axis, -1)
        out = a[..., None, :] @ b[..., None]
        return out[..., 0, 0]

    return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)

def solve(x1: array, x2: array, /, **kwargs) -> array:
    """Return the solution ``X`` of the linear system ``x1 @ X = x2``.

    Thin wrapper over ``torch.linalg.solve`` that first promotes both
    operands to a common dtype via ``_fix_promotion``.
    """
    promoted = _fix_promotion(x1, x2, only_scalar=False)
    return torch.linalg.solve(*promoted, **kwargs)

# Public API: every name re-exported from torch.linalg plus the wrappers
# defined in this module.
__all__ = [*linalg_all,
           'outer', 'trace', 'matrix_transpose', 'tensordot',
           'vecdot', 'solve']

# Drop the helper list so it isn't exposed as a module attribute.
del linalg_all
1 change: 1 addition & 0 deletions numpy-1-21-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[_
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_logaddexp
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x, s)]
Expand Down
4 changes: 2 additions & 2 deletions torch-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__im
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)]


# Mac-only bug (overflow near float max)
# array_api_tests/test_operators_and_elementwise_functions.py::test_log1p
# overflow near float max
array_api_tests/test_operators_and_elementwise_functions.py::test_log1p

# torch doesn't handle shifting by more than the bit size correctly
# https://github.com/pytorch/pytorch/issues/70904
Expand Down

0 comments on commit 5d3a92c

Please sign in to comment.