
vmap behaving unexpectedly with function that inverts singular matrix #15429

Closed
bjeurissen opened this issue Apr 6, 2023 · 4 comments · Fixed by #22494, openxla/xla#15026 or tensorflow/tensorflow#72030
Assignees
Labels
bug Something isn't working XLA

Comments

@bjeurissen

Description

I ran into unexpected behavior when using vmap:

import jax.numpy as jnp
from jax import vmap, __version__
import platform
print(platform.platform())
print(platform.python_version())
print(__version__)

def myfun(x):
    return jnp.linalg.inv(x.T@x)

myfun_many = vmap(myfun, in_axes=2, out_axes=2)

print('test1')
x1 = jnp.array([[1.0,3.0],[-5.0,1.0]])
y1 = myfun(x1)
print(y1)

x2 = jnp.stack((x1,x1,x1),2)
assert (x2[:,:,0]==x1).all()
assert (x2[:,:,1]==x1).all()
assert (x2[:,:,2]==x1).all()
y2 = myfun_many(x2)
print(y2[:,:,0]) # this is the same as y1
print(y2[:,:,1]) # this is the same as y1
print(y2[:,:,2]) # this is the same as y1


print('test2')
x1 = jnp.array([[1.0,1.0],[1.0,1.0]])
y1 = myfun(x1)
print(y1)

x2 = jnp.stack((x1,x1,x1),2)
assert (x2[:,:,0]==x1).all()
assert (x2[:,:,1]==x1).all()
assert (x2[:,:,2]==x1).all()
y2 = myfun_many(x2)
print(y2[:,:,0]) # this is not the same as y1!
print(y2[:,:,1]) # this is not the same as y1!
print(y2[:,:,2]) # this is not the same as y1!

which outputs:

macOS-13.3-arm64-arm-64bit
3.11.1
0.4.8
test1
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
test2
[[ inf -inf]
 [-inf  inf]]
[[ 1.5 -1. ]
 [-1.   1. ]]
[[ 1.5 -1. ]
 [-1.   1. ]]
[[ 1.5 -1. ]
 [-1.   1. ]]

In the first test, vmap correctly vectorizes myfun: x.T@x is invertible, and each slice of y2 matches y1.

However, in the second test x.T@x is singular. Here myfun "correctly" returns Infs, whereas the vmapped myfun returns finite, incorrect values.
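For comparison (an aside, not part of the original report): plain NumPy refuses to invert the same singular product outright, raising LinAlgError instead of returning Infs:

```python
import numpy as np

x = np.array([[1.0, 1.0], [1.0, 1.0]])
try:
    np.linalg.inv(x.T @ x)  # x.T @ x has rank 1, so it is singular
    msg = None
except np.linalg.LinAlgError as err:
    msg = str(err)
print(msg)
```

So JAX's inf/nan result for the unbatched call is itself a deliberate design difference from NumPy; the bug is only that vmap changes it.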

What jax/jaxlib version are you using?

jax 0.4.8

Which accelerator(s) are you using?

CPU

Additional system info

Python 3.11.1, macOS-13.3-arm64-arm-64bit

NVIDIA GPU info

No response

@bjeurissen bjeurissen added the bug Something isn't working label Apr 6, 2023
@alonfnt
Contributor

alonfnt commented May 13, 2023

I can reproduce this on CPU. On GPU it works correctly.
CPU

Linux-5.19.0-41-generic-x86_64-with-glibc2.35
3.10.6
0.4.8
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]
test1
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
test2
[[ inf -inf]
 [-inf  inf]]
[[ 1.5 -1. ]
 [-1.   1. ]]
[[ 1.5 -1. ]
 [-1.   1. ]]
[[ 1.5 -1. ]
 [-1.   1. ]]

GPU:

Linux-5.19.0-41-generic-x86_64-with-glibc2.35
3.10.6
0.4.8
[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
test1
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
test2
[[ inf -inf]
 [-inf  inf]]
[[ inf -inf]
 [-inf  inf]]
[[ inf -inf]
 [-inf  inf]]
[[ inf -inf]
 [-inf  inf]]

@jakevdp jakevdp self-assigned this Jun 21, 2023
@jakevdp
Collaborator

jakevdp commented Jun 21, 2023

Shorter repro:

import jax
import jax.numpy as jnp

x = jnp.ones((1, 2, 2))
print(jnp.linalg.inv(x))
# [[[ inf -inf]
#   [-inf  inf]]]

print(jax.vmap(jnp.linalg.inv)(x))
# [[[ 2. -1.]
#   [-1.  1.]]]

It probably has something to do with the custom_linear_solve batching rule. I'm going to dig into it.
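For context (an aside, not part of the original comment): `jnp.linalg.inv` works by LU-factoring the matrix and then performing triangular solves. For a singular input the U factor carries an exact zero on its diagonal, which is precisely what should propagate to inf/nan in the subsequent solve:

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu

# LU-factor the singular matrix from the short repro above
x = jnp.ones((2, 2))
p, l, u = lu(x)
print(u)  # the last diagonal entry of U is exactly zero
```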

@jakevdp
Collaborator

jakevdp commented Jun 28, 2023

I think I'm getting closer – this looks like it somehow comes from the batching rule of jax.lax.triangular_solve:

import jax
import jax.numpy as jnp

def solve(x, y):
  return jax.lax.linalg.triangular_solve(x, y, left_side=True)

x = jnp.array([[1., 1.], [1., 0.]])
y = jnp.array([[1.], [0.]])

print(solve(x, y))
# [[nan]
#  [nan]]

print(jax.vmap(solve)(x[None], y[None])[0])
# [[1.]
#  [0.]]
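A plain-NumPy sketch of back substitution (a hypothetical helper, just to show the mechanism) makes the unbatched nan unsurprising: with `left_side=True` and the default `lower=False`, triangular_solve reads only the upper triangle of x, i.e. [[1., 1.], [0., 0.]], and dividing by the zero diagonal entry yields nan:

```python
import numpy as np

# hypothetical back-substitution helper for an upper-triangular system U @ x = b
def back_substitute(U, b):
    n = len(b)
    x = np.zeros(n)
    with np.errstate(divide="ignore", invalid="ignore"):
        for i in range(n - 1, -1, -1):
            # division by U[i, i] is where a zero diagonal turns into nan
            x[i] = (b[i] - U[i, i + 1:] @ x[i + 1:]) / U[i, i]
    return x

# the upper triangle of the matrix from the repro above
U = np.array([[1.0, 1.0], [0.0, 0.0]])
b = np.array([1.0, 0.0])
print(back_substitute(U, b))  # [nan nan]
```

The question is then why the batched path avoids that division.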

@jakevdp
Collaborator

jakevdp commented Apr 10, 2024

This is still an issue as of JAX v0.4.26.

@jakevdp jakevdp assigned hawkinsp and unassigned jakevdp Jun 11, 2024
copybara-service bot pushed a commit that referenced this issue Jul 17, 2024
…ting triangular matrices.

The intent of this code seems to have been to mask out zeros that were part of padding on the diagonal. However, this isn't correct: if there is a zero on the diagonal, we very much want to get an inf or nan! We also appear to now pad with the identity matrix.

Fixes #3589
Fixes #15429

PiperOrigin-RevId: 653274967
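To illustrate the commit message (a hedged reading, not the actual XLA code): if the batching rule effectively replaces a zero diagonal entry with 1 before solving, the singular system from the triangular_solve repro gets exactly the finite answer observed under vmap:

```python
import numpy as np

# upper triangle of the matrix in the triangular_solve repro
U_singular = np.array([[1.0, 1.0], [0.0, 0.0]])
# hypothetical effect of masking the zero diagonal entry to 1
U_masked = np.array([[1.0, 1.0], [0.0, 1.0]])
b = np.array([1.0, 0.0])

print(np.linalg.solve(U_masked, b))  # finite answer instead of nans
```

This matches the vmapped output [[1.], [0.]] above, consistent with the diagonal masking the commit removes.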
copybara-service bot pushed a commit to openxla/xla that referenced this issue Jul 17, 2024
…ting triangular matrices.

The intent of this code seems to have been to mask out zeros that were part of padding on the diagonal. However, this isn't correct: if there is a zero on the diagonal, we very much want to get an inf or nan! We also appear to now pad with the identity matrix.

Fixes jax-ml/jax#3589
Fixes jax-ml/jax#15429

PiperOrigin-RevId: 653274967
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Jul 17, 2024
…ting triangular matrices.

The intent of this code seems to have been to mask out zeros that were part of padding on the diagonal. However, this isn't correct: if there is a zero on the diagonal, we very much want to get an inf or nan! We also appear to now pad with the identity matrix.

Fixes jax-ml/jax#3589
Fixes jax-ml/jax#15429

PiperOrigin-RevId: 653274967