
solve and triangular_solve fail to return Inf for batches of singular matrices on CPU #3589

Closed
dpfau opened this issue Jun 28, 2020 · 7 comments · Fixed by #22494, openxla/xla#15026 or tensorflow/tensorflow#72030

Comments

@dpfau
Contributor

dpfau commented Jun 28, 2020

This is an issue with the CPU backend when running triangular_solve. If I provide a single low-rank matrix to triangular_solve (i.e. an upper triangular matrix with zeros on the diagonal), it solves correctly by back-substitution until it reaches the zero, and returns Inf or NaN for all columns after that. If I run a batch of matrices through triangular_solve, however, it returns zero in the later columns (and then a nonsense result in the last column) instead of Inf/NaN. This error propagates through to jnp.linalg.solve, as the following example shows:

```python
from jax import lax_linalg
import numpy as np
import jax.numpy as jnp

n = 12
k = [0, 3, 5, 0, 1]  # corank of each batch element
u = np.triu(np.random.rand(len(k), n, n) + 1)
for i in range(len(k)):
  if k[i] > 0:
    u[i, -k[i]:] = 0

foo = lax_linalg.triangular_solve(
    u, np.ones((len(k), 1, n)),
    left_side=False, transpose_a=False, lower=False)
bar = jnp.linalg.solve(u.transpose((0, 2, 1)), np.ones((len(k), n, 1))).transpose((0, 2, 1))

for i in range(len(k)):
  print(lax_linalg.triangular_solve(
      u[i], np.ones((1, n)),
      left_side=False, transpose_a=False, lower=False))
  print('')
  print(foo[i])
  print('')
  print(bar[i])
  print('\n\n')
```

The first result (triangular_solve on a single matrix) is correct, while the following results (triangular_solve and solve on a batch of matrices) are incorrect. On GPU, this bug is not present.
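For reference, the single-matrix behavior described above follows directly from back-substitution: once a zero pivot is reached, every subsequent entry divides by zero. A minimal NumPy sketch (an illustration of plain back-substitution, not JAX's or LAPACK's implementation) shows this:

```python
import numpy as np

def back_substitute(u, b):
    """Solve u @ x = b for upper-triangular u by plain back-substitution.

    Divisions by a zero diagonal entry produce inf/nan rather than raising,
    mirroring the single-matrix behavior described in the report.
    """
    n = u.shape[0]
    x = np.zeros(n)
    with np.errstate(divide='ignore', invalid='ignore'):
        for i in range(n - 1, -1, -1):
            x[i] = (b[i] - u[i, i + 1:] @ x[i + 1:]) / u[i, i]
    return x

u = np.triu(np.ones((4, 4)) + np.eye(4))
u[-1, -1] = 0.0  # zero pivot: u is singular
x = back_substitute(u, np.ones(4))
print(x)  # the zero pivot yields inf, which contaminates earlier entries as nan
```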

@dpfau dpfau changed the title solve and triangular_solve fail to return inf for batches of singular matrices on CPU solve and triangular_solve fail to return Inf for batches of singular matrices on CPU Jun 28, 2020
@hawkinsp
Collaborator

For the batched case the implementation switches to a completely different algorithm (actually, the same implementation used on TPU): https://github.com/google/jax/blob/7b57dc8c8043163a5e649ba66143ccef880d7d58/jax/lax_linalg.py#L436
For the batch-size-1 case, we call LAPACK's TRSM, which behaves as you describe.

The batched case calls into XLA, which uses an algorithm inspired by MAGMA that inverts diagonal blocks: https://github.com/tensorflow/tensorflow/blob/bd006c354f11f9045d344f3e48b47be9f8368dac/tensorflow/compiler/xla/service/triangular_solve_expander.cc#L439
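The block-inversion idea can be sketched in NumPy. This is a simplified illustration of the structure (partition the triangular matrix into diagonal blocks, invert each block, and propagate partial solutions by matrix multiplication), not the XLA expander itself; the function name and block size are made up for the example:

```python
import numpy as np

def blocked_lower_solve(a, b, bs):
    """Solve a @ x = b for lower-triangular a by inverting bs x bs
    diagonal blocks, roughly the shape of the MAGMA-style algorithm."""
    n = a.shape[0]
    x = np.zeros_like(b)
    for i in range(0, n, bs):
        s = slice(i, i + bs)
        rhs = b[s] - a[s, :i] @ x[:i]        # subtract already-solved blocks
        x[s] = np.linalg.inv(a[s, s]) @ rhs  # apply the inverted diagonal block
    return x

rng = np.random.default_rng(0)
a = np.tril(rng.random((6, 6))) + 6 * np.eye(6)  # well-conditioned lower-triangular
b = rng.random((6, 2))
x = blocked_lower_solve(a, b, bs=2)
print(np.allclose(a @ x, b))  # True
```

In a scheme like this, a singular diagonal block makes the block inverse non-finite (here `np.linalg.inv` would raise instead), and the subsequent matrix multiplications then propagate those non-finite values through the rest of the output, rather than leaving a clean Inf/NaN tail as back-substitution does.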

@hawkinsp
Collaborator

I'm wondering if you care about the values returned if the matrix is singular, or whether you would be happy to get, say, a matrix full of NaNs out for that batch element. Note that scipy.linalg.solve_triangular would raise a singular-matrix exception in the corresponding situation.
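For comparison, SciPy's wrapper checks LAPACK's info flag and raises on an exactly singular triangular matrix:

```python
import numpy as np
import scipy.linalg

u = np.triu(np.ones((3, 3)))
u[-1, -1] = 0.0  # exact zero on the diagonal: u is singular

try:
    scipy.linalg.solve_triangular(u, np.ones(3), lower=False)
    raised = False
except scipy.linalg.LinAlgError as e:
    raised = True
    print("raised:", e)
```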

@dpfau
Contributor Author

dpfau commented Jun 29, 2020 via email

@dpfau
Contributor Author

dpfau commented Jun 29, 2020 via email

@hawkinsp
Collaborator

Wouldn't it suffice to look for the first 0 on the diagonal of the input matrix?

(I'm also just trying to understand what API contract you expect, because the usual contract says "this is an illegal input". It's possible we can make the XLA algorithm mimic the behavior of the usual TRSM algorithm in this case, but it's not clear to me the behavior in the singular case is actually well defined without also fixing the choice of algorithm.)
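If the goal is only to know which batch elements are singular, scanning the diagonal directly is cheap and backend-independent. A sketch (the helper name and tolerance parameter are made up for illustration):

```python
import numpy as np

def singular_mask(u, tol=0.0):
    """Return a boolean mask over the batch: True where the triangular
    matrix has a (near-)zero diagonal entry, i.e. is singular."""
    diag = np.diagonal(u, axis1=-2, axis2=-1)
    return np.any(np.abs(diag) <= tol, axis=-1)

u = np.triu(np.ones((3, 4, 4)) + np.eye(4))
u[1, 2, 2] = 0.0  # make the second batch element singular
print(singular_mask(u))  # [False  True False]
```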

@dpfau
Contributor Author

dpfau commented Jun 29, 2020 via email

copybara-service bot pushed a commit that referenced this issue Jul 17, 2024
…ting triangular matrices.

The intent of this code seems to have been to mask out zeros that were part of padding on the diagonal. However, this isn't correct: if there is a zero on the diagonal, we very much want to get an inf or nan! We also appear to now pad with the identity matrix.

Fixes #3589
Fixes #15429

PiperOrigin-RevId: 653274967
copybara-service bot pushed a commit to openxla/xla that referenced this issue Jul 18, 2024
…ting triangular matrices.

The intent of this code seems to have been to mask out zeros that were part of padding on the diagonal. However, this isn't correct: if there is a zero on the diagonal, we very much want to get an inf or nan! We also appear to now pad with the identity matrix.

Fixes jax-ml/jax#3589
Fixes jax-ml/jax#15429

PiperOrigin-RevId: 653562611
@hawkinsp
Collaborator

hawkinsp commented Jul 18, 2024

As things now stand, you'll get an output containing NaNs if you pass a singular matrix. They will not necessarily appear starting at the relevant column: where they appear depends on the algorithm choice, and because some of the algorithms involve matrix multiplication, any NaNs present are smeared across the output.
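The smearing mentioned above is ordinary NaN propagation through matrix multiplication: a single NaN in an operand contaminates every output entry whose dot product touches it. A small NumPy illustration:

```python
import numpy as np

a = np.eye(3)
a[1, 1] = np.nan  # one NaN entry in the matrix
b = np.ones((3, 3))
c = a @ b
print(np.isnan(c))  # the entire second row of the product is NaN
```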
