jax.numpy.linalg.inv returns spurious results when called with array of matrices #21785

Closed
johannahaffner opened this issue Jun 10, 2024 · 2 comments
Labels: bug (Something isn't working), duplicate (This issue or pull request already exists)

Comments

@johannahaffner

Description

I encountered unexpected behavior while using jnp.linalg.inv to invert an array of matrices (I wanted to get an array of their inverses).

Specifically,

import jax.numpy as jnp

invertible = jnp.eye(2)
not_invertible = jnp.array([[1,0], [1,0]])
array_of_matrices = jnp.array([invertible, not_invertible])

print(jnp.linalg.inv(not_invertible))  # Expected result
print(jnp.linalg.inv(array_of_matrices)[1])  # Spurious inverse of not_invertible

This isn't mentioned in the documentation, which describes the return value as

  • Array of shape (..., N, N) containing the inverse of the input.
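To check whether a given installation shows the discrepancy, the batched result can be compared against inverses computed one matrix at a time. This is only a rough sketch: the helper name `batched_inv_matches_unbatched` is made up for illustration, and the outcome depends on the backend and the JAX/XLA version in use.

```python
import jax.numpy as jnp


def batched_inv_matches_unbatched(matrices):
    """Return True if inverting the whole stack agrees with inverting each matrix alone.

    Hypothetical helper for illustration; not part of JAX.
    """
    batched = jnp.linalg.inv(matrices)
    # Invert each matrix separately, then restack for comparison.
    one_by_one = jnp.stack([jnp.linalg.inv(m) for m in matrices])
    # equal_nan=True so that matching NaN entries count as agreement.
    return bool(jnp.allclose(batched, one_by_one, equal_nan=True))


invertible = jnp.eye(2)
not_invertible = jnp.array([[1.0, 0.0], [1.0, 0.0]])
array_of_matrices = jnp.stack([invertible, not_invertible])

print(batched_inv_matches_unbatched(array_of_matrices))
```

A `False` here would indicate that the batched call and the per-matrix calls disagree for this input, as described above.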

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.28
jaxlib: 0.4.28
numpy: 1.26.4
python: 3.11.8 (v3.11.8:db85d51d3e, Feb 6 2024, 18:02:37) [Clang 13.0.0 (clang-1300.0.29.30)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='bs-mbpas-0019', release='23.3.0', version='Darwin Kernel Version 23.3.0: Wed Dec 20 21:33:31 PST 2023; root:xnu-10002.81.5~7/RELEASE_ARM64_T8112', machine='arm64')

@johannahaffner added the bug label on Jun 10, 2024

@jakevdp (Collaborator) commented Jun 10, 2024

Thanks for the report! This is a known XLA CPU issue, first reported in #15429 and tracked here: openxla/xla#3891

Unfortunately, there hasn't been much traction on the XLA side to get this fixed.

@jakevdp self-assigned this on Jun 10, 2024
@jakevdp added the duplicate label on Jun 10, 2024

@johannahaffner (Author)

Thank you for the lightning fast reply! Good to know, I'm working around it for now. And sorry for the duplicate, I hadn't spotted that.
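The thread does not say which workaround was used. One possible approach, sketched here under assumptions, is to flag matrices whose determinant is close to zero and overwrite their entries in the batched result with NaN. The helper name `inv_with_singular_mask` and the tolerance `tol` are hypothetical; a determinant threshold is scale-dependent, and the sketch assumes the batched determinant itself is computed reliably.

```python
import jax.numpy as jnp


def inv_with_singular_mask(matrices, tol=1e-6):
    """Batched inverse with NaN in place of inverses of (near-)singular matrices.

    Hypothetical helper for illustration; not part of JAX.
    """
    inverses = jnp.linalg.inv(matrices)
    # |det| below tol is treated as singular; tol is an arbitrary, scale-dependent choice.
    singular = jnp.abs(jnp.linalg.det(matrices)) < tol
    # Broadcast the per-matrix flag over the trailing (N, N) dimensions.
    return jnp.where(singular[..., None, None], jnp.nan, inverses)


array_of_matrices = jnp.stack(
    [jnp.eye(2), jnp.array([[1.0, 0.0], [1.0, 0.0]])]
)
masked = inv_with_singular_mask(array_of_matrices)
print(masked)
```

Downstream code can then use `jnp.isnan` to find the entries that correspond to non-invertible inputs, rather than trusting whatever values the batched call produced for them.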

@superbobry closed this as not planned (Won't fix, can't repro, duplicate, stale) on Jul 17, 2024