Here's a short repro in JAX that more or less passes the inputs directly to `TriangularSolveOp`:
```python
import jax
import jax.numpy as jnp

def solve(x, y):
  return jax.lax.linalg.triangular_solve(x, y, left_side=True)

x = jnp.array([[1., 1.], [0., 0.]])
y = jnp.array([[1], [1.]])

print(solve(x, y))
# [[-inf]
#  [ inf]]

print(solve(x[None], y[None])[0])
# [[0.]
#  [1.]]
```
I would expect the second (batched) output to be identical to the first (unbatched) one. This appears to be the root cause of the issue reported in jax-ml/jax#15429.
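The expected invariant can be illustrated with a plain NumPy back-substitution. This is a hypothetical reference written for this note, not XLA's implementation. It assumes `lower=False` (the `triangular_solve` default), a non-unit diagonal, and IEEE-754 division semantics for the zero pivot. Applying the same solve independently to each batch element yields exactly the unbatched result. That matches the first JAX output and not the second.

```python
import numpy as np

def back_substitute(a, b):
    """Hypothetical reference: solve a @ x = b for upper-triangular a.

    Mirrors triangular_solve(a, b, left_side=True, lower=False) under the
    assumption of IEEE-754 semantics: division by a zero pivot gives +/-inf.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    n = a.shape[-1]
    x = np.zeros_like(b)
    with np.errstate(divide="ignore", invalid="ignore"):
        for i in range(n - 1, -1, -1):
            # Subtract the contributions of the rows already solved below i.
            rhs = b[i] - a[i, i + 1:] @ x[i + 1:]
            x[i] = rhs / a[i, i]
    return x

def batched_back_substitute(a, b):
    """Apply back_substitute independently to each leading batch element."""
    return np.stack([back_substitute(ai, bi) for ai, bi in zip(a, b)])

x = np.array([[1., 1.], [0., 0.]])
y = np.array([[1.], [1.]])

unbatched = back_substitute(x, y)                       # [[-inf], [inf]]
batched = batched_back_substitute(x[None], y[None])[0]  # same values
print(np.array_equal(unbatched, batched))               # True
```

Because batching here is just a map over the leading dimension, any divergence between the two paths, as in the repro, points at the batched code path rather than at the singular input.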
Is this issue coming from the batching rule of `triangular_solve`? Has it been solved?
Fixed by #15026