solve and triangular_solve fail to return Inf for batches of singular matrices on CPU #3589
Comments
For the batched case the implementation switches to a completely different algorithm (actually, the same implementation used on TPU): https://github.com/google/jax/blob/7b57dc8c8043163a5e649ba66143ccef880d7d58/jax/lax_linalg.py#L436
The batched case calls into XLA, which uses an algorithm inspired by MAGMA that inverts diagonal blocks: https://github.com/tensorflow/tensorflow/blob/bd006c354f11f9045d344f3e48b47be9f8368dac/tensorflow/compiler/xla/service/triangular_solve_expander.cc#L439
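I'm wondering if you care about the values returned if the matrix is singular, or whether you would be happy to get, say, a matrix full of NaNs out for that batch element. Note that, say, scipy.linalg.solve_triangular would raise a singular matrix exception in the corresponding situation.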
I am working on a fix so that the gradient of the LU decomposition returns
the correct value even if the input matrix is singular, and I use the
columns of Infs/NaNs to identify what the rank of the matrix is, so in this
case it actually is important that it returns Inf/NaN instead of zeros or
an error.
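To illustrate why the Inf/NaN columns carry this information, here is a minimal pure-Python sketch of a column-by-column solve of x a = b for an upper-triangular a. It is not JAX's or XLA's implementation. The names ieee_div, solve_upper_right and first_nonfinite_column are hypothetical helpers introduced for this illustration; ieee_div emulates IEEE 754 division because plain Python raises ZeroDivisionError on float division by zero.

```python
import math

def ieee_div(num, den):
    """Divide like IEEE 754 hardware: x/0 gives +-Inf, 0/0 gives NaN."""
    try:
        return num / den
    except ZeroDivisionError:
        if num == 0 or math.isnan(num):
            return math.nan
        # The sign of the result follows the signs of both operands.
        return math.copysign(math.inf, num) * math.copysign(1.0, den)

def solve_upper_right(a, b):
    """Solve x @ a = b for one row vector b, with a upper triangular."""
    n = len(a)
    x = [0.0] * n
    for j in range(n):
        # Column j of the system: sum over k <= j of x[k] * a[k][j] = b[j].
        acc = b[j] - sum(x[k] * a[k][j] for k in range(j))
        x[j] = ieee_div(acc, a[j][j])
    return x

def first_nonfinite_column(x):
    """Index of the first Inf/NaN entry, or len(x) if all are finite."""
    for j, v in enumerate(x):
        if not math.isfinite(v):
            return j
    return len(x)

# Upper-triangular matrix with a zero at diagonal position 1.
a = [[2.0, 1.0, 1.0],
     [0.0, 0.0, 1.0],
     [0.0, 0.0, 3.0]]
b = [2.0, 3.0, 4.0]

x = solve_upper_right(a, b)
print(x)                          # [1.0, inf, -inf]
print(first_nonfinite_column(x))  # 1
```

In this sketch every column before the zero pivot is solved exactly, and every column from the zero pivot onward is non-finite, so the index of the first non-finite column marks where the solve broke down.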
This same trick is currently used by the gradient of np.linalg.det when
dealing with singular matrices, which leads me to believe that it may also
fail in this case.
Wouldn't it suffice to look for the first 0 on the diagonal of the input matrix?
(I'm also just trying to understand what API contract you expect, because the usual contract says "this is an illegal input". It's possible we can make the XLA algorithm mimic the behavior of the usual TRSM algorithm in this case, but it's not clear to me the behavior in the singular case is actually well defined without also fixing the choice of algorithm.)
I'm worried that would miss Inf/NaN due to underflow for values that aren't
identically zero. Different backends might have different tolerances and I
want to make sure I catch everything. It's possible I'm being too cautious
though.
I definitely *don't* want it to raise an error. I want this to work on
singular matrices - triangular_solve still gives useful results for all
columns up to the rank of the matrix.
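The gap between the two detection strategies can be sketched in plain Python. This is only one way the concern could show up, assuming float64 arithmetic: a diagonal entry that is tiny but not identically zero passes an exact-zero check, yet the division still overflows to Inf. The helper names first_zero_on_diagonal and first_nonfinite are hypothetical and exist only for this illustration.

```python
import math

def first_zero_on_diagonal(a):
    """Index of the first exactly-zero diagonal entry, or len(a) if none."""
    for j in range(len(a)):
        if a[j][j] == 0.0:
            return j
    return len(a)

def first_nonfinite(x):
    """Index of the first Inf/NaN entry of x, or len(x) if all are finite."""
    for j, v in enumerate(x):
        if not math.isfinite(v):
            return j
    return len(x)

# Diagonal system x @ a = b, so each entry is just b[j] / a[j][j].
a = [[1.0, 0.0],
     [0.0, 1e-200]]   # tiny but nonzero second pivot
b = [1.0, 1e200]
x = [b[j] / a[j][j] for j in range(len(b))]

print(first_zero_on_diagonal(a))  # 2: no exact zero on the diagonal
print(first_nonfinite(x))         # 1: the quotient 1e400 overflows float64
```

Checking the output for non-finite values catches this case without needing a backend-specific tolerance on the diagonal.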
…ting triangular matrices. The intent of this code seems to have been to mask out zeros that were part of padding on the diagonal. However, this isn't correct: if there is a zero on the diagonal, we very much want to get an inf or nan! We also appear to now pad with the identity matrix. Fixes #3589 Fixes #15429 PiperOrigin-RevId: 653274967
…ting triangular matrices. The intent of this code seems to have been to mask out zeros that were part of padding on the diagonal. However, this isn't correct: if there is a zero on the diagonal, we very much want to get an inf or nan! We also appear to now pad with the identity matrix. Fixes jax-ml/jax#3589 Fixes jax-ml/jax#15429 PiperOrigin-RevId: 653562611
As of the current state, you'll get an output containing NaNs if you pass a singular matrix. They will not necessarily start at the column where the zero occurs: that depends on the algorithm choice, and because some of the algorithms involve matrix multiplication, any NaNs that are present get smeared across the output.
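A small pure-Python sketch can show how this smearing arises when a solve goes through an explicit inverse followed by a multiplication, compared with a column-by-column solve. It is a toy model under stated assumptions, not XLA's blocked algorithm. The names ieee_div, solve_columnwise, invert_upper and vecmat are hypothetical helpers; ieee_div emulates IEEE 754 division, since plain Python raises on division by zero.

```python
import math

def ieee_div(num, den):
    """Divide like IEEE 754 hardware: x/0 gives +-Inf, 0/0 gives NaN."""
    try:
        return num / den
    except ZeroDivisionError:
        if num == 0 or math.isnan(num):
            return math.nan
        return math.copysign(math.inf, num) * math.copysign(1.0, den)

def solve_columnwise(a, b):
    """Solve x @ a = b one column at a time (a upper triangular)."""
    x = [0.0] * len(a)
    for j in range(len(a)):
        acc = b[j] - sum(x[k] * a[k][j] for k in range(j))
        x[j] = ieee_div(acc, a[j][j])
    return x

def invert_upper(a):
    """Invert upper-triangular a by back-substitution on a @ y = I."""
    n = len(a)
    y = [[0.0] * n for _ in range(n)]
    for c in range(n):
        for i in reversed(range(n)):
            rhs = (1.0 if i == c else 0.0) - sum(
                a[i][k] * y[k][c] for k in range(i + 1, n))
            y[i][c] = ieee_div(rhs, a[i][i])
    return y

def vecmat(b, m):
    """Row vector times matrix."""
    return [sum(b[k] * m[k][j] for k in range(len(b)))
            for j in range(len(m[0]))]

a = [[2.0, 1.0],
     [0.0, 0.0]]   # zero in the last diagonal position
b = [4.0, 6.0]

direct = solve_columnwise(a, b)
via_inverse = vecmat(b, invert_upper(a))
print(direct)       # [2.0, inf]
print(via_inverse)  # [nan, nan]
```

The column-by-column solve keeps the first column finite, while the inverse-then-multiply route turns both columns into NaN, because the NaN in the inverse enters every dot product that touches it.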
This is an issue with the CPU backend when running triangular_solve. If I provide a single low-rank matrix to triangular_solve (i.e. an upper triangular matrix with zeros on the diagonal) it will solve correctly by back-substitution until it reaches the zero, and return Inf or NaN for all columns after that. If I run a batch of matrices through triangular_solve, however, it will return zero in the later columns (and then a nonsense result in the last column) instead of Inf/NaN. This error is propagated through to jnp.linalg.solve, as can be seen in the following example:
The first result (triangular_solve on a single matrix) is correct, while the following results (triangular_solve and solve on a batch of matrices) are incorrect. On GPU, this bug is not present.