Skip to content

Commit

Permalink
now only torch backend of jax.numpy.linalg.svd is failing due to "Run…
Browse files Browse the repository at this point in the history
…timeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead." when the test function calls np.asarray to the returned value
  • Loading branch information
Jin Wang committed Jul 15, 2024
1 parent b9cf1cd commit 3bb5b66
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions ivy/functional/frontends/jax/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,11 @@ def solve(a, b):

@to_ivy_arrays_and_back
def svd(a, /, *, full_matrices=True, compute_uv=True, hermitian=None):
# TODO: handle hermitian
if not compute_uv:
return ivy.svdvals(a)
return ivy.svd(a, full_matrices=full_matrices)
return ivy.svdvals(a).astype(ivy.float64)
ret = ivy.svd(a, full_matrices=full_matrices)
return tuple([ x.astype(ivy.float64) for x in ret])


@to_ivy_arrays_and_back
Expand Down

0 comments on commit 3bb5b66

Please sign in to comment.