Batched TriangularSolve of singular matrix returns incorrect results #3891

Closed
jakevdp opened this issue Jun 28, 2023 · 2 comments
Labels: bug (Something isn't working), CPU (Related to XLA on CPU)

Comments

@jakevdp (Contributor) commented Jun 28, 2023

Here's a short repro in JAX that passes the inputs more or less directly to TriangularSolveOp:

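```python
import jax
import jax.numpy as jnp

def solve(x, y):
  # Solve x @ result = y; x is treated as upper triangular (lower=False is the default).
  return jax.lax.linalg.triangular_solve(x, y, left_side=True)

# Singular upper-triangular matrix: the diagonal entry x[1, 1] is zero.
x = jnp.array([[1., 1.], [0., 0.]])
y = jnp.array([[1], [1.]])

print(solve(x, y))
# [[-inf]
#  [ inf]]
print(solve(x[None], y[None])[0])  # same inputs with a leading batch dimension of size 1
# [[0.]
#  [1.]]
```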

I would expect the second (batched) output to be identical to the first. This appears to be the root cause of the issue reported in jax-ml/jax#15429.
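For reference, plain back substitution under IEEE 754 arithmetic reproduces the unbatched output. With `lower=False` (the default) `x` is upper triangular, so the last unknown is 1/0 = inf and the first is (1 - 1*inf)/1 = -inf. The NumPy sketch below is illustrative only: `back_substitute` and `batched_back_substitute` are hypothetical helpers (not part of JAX or XLA), and it assumes a left-side, upper-triangular, non-transposed solve as in the repro.

```python
import numpy as np

def back_substitute(a, b):
    """Hypothetical reference: solve a @ x = b for upper-triangular a.

    Uses IEEE 754 semantics, so a zero pivot yields inf/-inf (or nan)
    instead of raising an error.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    n = a.shape[-1]
    x = np.zeros_like(b)
    # Silence the divide-by-zero warning; the inf result is intended.
    with np.errstate(divide="ignore", invalid="ignore"):
        for i in range(n - 1, -1, -1):
            # Subtract contributions of already-solved unknowns, then divide by the pivot.
            rhs = b[i] - a[i, i + 1:] @ x[i + 1:]
            x[i] = rhs / a[i, i]
    return x

def batched_back_substitute(a, b):
    """Hypothetical batched reference: solve each batch element independently."""
    return np.stack([back_substitute(ai, bi) for ai, bi in zip(a, b)])

x = np.array([[1., 1.], [0., 0.]])
y = np.array([[1.], [1.]])
print(back_substitute(x, y))
print(batched_back_substitute(x[None], y[None])[0])
```

Because the batched reference just loops over the unbatched one, its result for a batch of size 1 is identical by construction, which is the behavior the issue expects from the batched op.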

@Qiustander

Is this issue coming from the batching rule of triangular_solve? Has it been solved?
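One way to probe whether a discrepancy like this comes from JAX's `vmap` batching rule or from the batched op itself is to compare three paths on the same inputs: explicit leading batch dimensions (as in the repro), `jax.vmap`, and a Python loop over unbatched calls. This is a sketch assuming a recent JAX install; `compare_batching_paths` is a hypothetical helper, and which paths disagree on any particular JAX/XLA version is not asserted here.

```python
import jax
import jax.numpy as jnp
import numpy as np

def solve(a, b):
    # Same call as in the repro: solve a @ x = b with upper-triangular a.
    return jax.lax.linalg.triangular_solve(a, b, left_side=True)

def compare_batching_paths(a, b):
    """Hypothetical diagnostic: a has shape (batch, n, n), b has shape (batch, n, k)."""
    # Path 1: pass leading batch dimensions straight to the op.
    direct = np.asarray(solve(a, b))
    # Path 2: let jax.vmap apply the batching rule to the unbatched function.
    vmapped = np.asarray(jax.vmap(solve)(a, b))
    # Path 3: reference result from one unbatched call per batch element.
    looped = np.stack([np.asarray(solve(a[i], b[i])) for i in range(a.shape[0])])
    # equal_nan=True so nan entries from singular inputs compare equal.
    return {
        "direct_matches_loop": bool(np.array_equal(direct, looped, equal_nan=True)),
        "vmap_matches_loop": bool(np.array_equal(vmapped, looped, equal_nan=True)),
    }

# Usage with the singular matrix from the repro.
x = jnp.array([[1., 1.], [0., 0.]])
y = jnp.array([[1.], [1.]])
print(compare_batching_paths(x[None], y[None]))
```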

@hawkinsp (Member)

Fixed by #15026
