vmap behaving unexpected with function that inverts singular matrix #15429
Closed
bjeurissen opened this issue
Apr 6, 2023
· 4 comments
· Fixed by #22494, openxla/xla#15026 or tensorflow/tensorflow#72030
Comments
I can reproduce this on CPU. On GPU it works correctly.
GPU:
Shorter repro:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1, 2, 2))
print(jnp.linalg.inv(x))
# [[[ inf -inf]
#   [-inf  inf]]]
print(jax.vmap(jnp.linalg.inv)(x))
# [[[ 2. -1.]
#   [-1.  1.]]]
```

It probably has something to do with the
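One way to spot this class of bug is to compare `jax.vmap(f)` against `f` applied slice by slice in a plain Python loop, which does not go through any batching rule. The sketch below uses a hypothetical helper, `batched_matches_loop` (not a JAX API). `np.allclose` treats infs of the same sign as equal and, with `equal_nan=True`, treats NaNs as equal too, so a singular input must produce inf/NaN on both paths for the check to pass.

```python
import jax
import jax.numpy as jnp
import numpy as np


def batched_matches_loop(f, xs):
    """Hypothetical helper: compare jax.vmap(f) with f applied to each slice of xs."""
    batched = np.asarray(jax.vmap(f)(xs))
    # Reference path: call f once per slice, so no batching rule is involved.
    looped = np.stack([np.asarray(f(x)) for x in xs])
    # Same-sign infs compare equal; equal_nan=True also makes NaN == NaN.
    return bool(np.allclose(batched, looped, equal_nan=True))


# Expected to be False on the affected CPU builds described in this issue
# (the two outputs above differ), and it should be True once the fix is in.
print(batched_matches_loop(jnp.linalg.inv, jnp.ones((1, 2, 2))))
```

A Python loop is used as the reference instead of another JAX transformation so that the comparison does not depend on any batching rule.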
I think I'm getting closer – this looks like it somehow comes from the batching rule of

```python
import jax
import jax.numpy as jnp

def solve(x, y):
    # Solve x @ z = y for z; with the default lower=False only the
    # upper triangle of x is read.
    return jax.lax.linalg.triangular_solve(x, y, left_side=True)

x = jnp.array([[1., 1.], [1., 0.]])
y = jnp.array([[1.], [0.]])
print(solve(x, y))
# [[nan]
#  [nan]]
print(jax.vmap(solve)(x[None], y[None])[0])
# [[1.]
#  [0.]]
```
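Why `[[nan], [nan]]` is the correct answer here: with `lower=False` (the default for `jax.lax.linalg.triangular_solve`), only the upper triangle of `x` is used, which is `[[1, 1], [0, 0]]`. That matrix has a zero on the diagonal. The hand-written back substitution below is a NumPy sketch for illustration, not JAX's implementation. It shows that the last step divides 0 by 0, and the resulting NaN then propagates to every row.

```python
import numpy as np


def back_substitute(u, b):
    """Solve u @ z = b for upper-triangular u by plain back substitution.

    A zero on the diagonal is deliberately not masked, so it yields inf/NaN.
    """
    n = u.shape[0]
    z = np.zeros(b.shape, dtype=np.float64)
    with np.errstate(divide="ignore", invalid="ignore"):
        for i in range(n - 1, -1, -1):
            # Subtract the contribution of the already-solved rows, then divide by the pivot.
            z[i] = (b[i] - u[i, i + 1:] @ z[i + 1:]) / u[i, i]
    return z


x = np.array([[1., 1.], [1., 0.]])
y = np.array([[1.], [0.]])
# Only the upper triangle counts: [[1, 1], [0, 0]] has a zero pivot in the last row.
print(back_substitute(np.triu(x), y))
```

The batched result `[[1.], [0.]]` therefore cannot be right: no finite vector solves a system whose triangular factor has a zero pivot with this right-hand side.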
This is still an issue as of JAX v0.4.26.
copybara-service bot pushed a commit that referenced this issue on Jul 17, 2024:

…ting triangular matrices. The intent of this code seems to have been to mask out zeros that were part of padding on the diagonal. However, this isn't correct: if there is a zero on the diagonal, we very much want to get an inf or nan! We also appear to now pad with the identity matrix.
Fixes #3589
Fixes #15429
PiperOrigin-RevId: 653274967
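The commit message says the old code masked out diagonal zeros that came from padding, and that the padding now appears to use the identity matrix. The NumPy sketch below illustrates that idea under stated assumptions. `pad_with_identity` is a hypothetical helper, not XLA's implementation, and `block` stands in for whatever size the real solver pads to. Because the padded system is `[[a, 0], [0, I]]` with zero right-hand-side rows, every padded diagonal entry is 1. Any zero left on the diagonal therefore belongs to the real matrix and should surface as inf or NaN instead of being masked.

```python
import numpy as np


def pad_with_identity(a, b, block):
    """Hypothetical helper: pad an n x n system a @ z = b up to a multiple of `block`.

    The padded matrix is block-diagonal [[a, 0], [0, I]], so padding never
    introduces a zero on the diagonal and nothing has to be masked.
    """
    n = a.shape[0]
    m = -(-n // block) * block  # round n up to the next multiple of block
    a_pad = np.eye(m, dtype=a.dtype)
    a_pad[:n, :n] = a
    b_pad = np.zeros((m,) + b.shape[1:], dtype=b.dtype)
    b_pad[:n] = b
    return a_pad, b_pad


a = np.array([[2., 1., 0.], [0., 1., 3.], [0., 0., 4.]])
b = np.array([[1.], [2.], [4.]])
a_pad, b_pad = pad_with_identity(a, b, block=4)
z = np.linalg.solve(a_pad, b_pad)
# The first n rows reproduce the unpadded solution; the padded rows are zero.
print(np.allclose(z[:3], np.linalg.solve(a, b)), np.allclose(z[3:], 0.0))
```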
The same commit was pushed again on Jul 18, 2024, and mirrored to openxla/xla and tensorflow/tensorflow on Jul 17 and Jul 18, 2024; the last mirrored pushes carry PiperOrigin-RevId: 653562611.
Description
I ran into unexpected behavior when using vmap:
which outputs:
The first test shows that vmap correctly vectorizes myfun when x.T @ x is invertible.
However, when x.T @ x is singular, myfun "correctly" returns infs, whereas vmap(myfun) returns different values.
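The reporter's original `myfun` and test inputs are not preserved in this copy of the issue. As a hypothetical stand-in, the sketch below assumes `myfun` inverts the Gram matrix `x.T @ x`. `x_ok` and `x_bad` are made-up inputs. The second one has identical columns, so `x.T @ x = [[3, 3], [3, 3]]` has rank 1 and cannot be inverted.

```python
import jax
import jax.numpy as jnp
import numpy as np


def myfun(x):
    # Hypothetical stand-in for the reporter's function: invert x.T @ x.
    return jnp.linalg.inv(x.T @ x)


# Linearly independent columns: x.T @ x = [[2, 1], [1, 2]] is invertible.
x_ok = jnp.array([[1., 0.], [0., 1.], [1., 1.]])
# Identical columns: x.T @ x = [[3, 3], [3, 3]] has rank 1 and is singular.
x_bad = jnp.ones((3, 2))

print(np.linalg.matrix_rank(np.asarray(x_ok.T @ x_ok)))
print(np.linalg.matrix_rank(np.asarray(x_bad.T @ x_bad)))

# Unbatched, the singular case yields non-finite values (inf/NaN).
print(myfun(x_bad))
# On the affected builds (the report used jax 0.4.8 on CPU), the second entry
# of the vmapped result may not match the unbatched call above.
print(jax.vmap(myfun)(jnp.stack([x_ok, x_bad])))
```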
What jax/jaxlib version are you using?
jax 0.4.8
Which accelerator(s) are you using?
CPU
Additional system info
Python 3.11.1, macOS-13.3-arm64-arm-64bit
NVIDIA GPU info
No response