From 47e6da3332be78bdd8d058291096d0147657088f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 18 Jul 2024 04:08:55 -0700 Subject: [PATCH] Don't mask out zero elements on the diagonal of the matrix when inverting triangular matrices. The intent of this code seems to have been to mask out zeros that were part of padding on the diagonal. However, this isn't correct: if there is a zero on the diagonal, we very much want to get an inf or nan! We also appear to now pad with the identity matrix. Fixes https://github.com/google/jax/issues/3589 Fixes https://github.com/google/jax/issues/15429 PiperOrigin-RevId: 653562611 --- CHANGELOG.md | 2 ++ tests/linalg_test.py | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 83b3ab874f2a..15c5633c62b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,8 @@ Remember to align the itemized text with the first line of an item within a list * Bug fixes * Fixed a bug that meant that negative static_argnums to a jit were mishandled by the jit dispatch fast path. + * Fixed a bug that meant triangular solves of batches of singular matrices + produce nonsensical finite values, instead of inf or nan (#3589, #15429). ## jax 0.4.30 (June 18, 2024) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 2a64b95b9452..dd0ae38d9aa4 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -16,6 +16,7 @@ from functools import partial import itertools +import unittest import numpy as np import scipy @@ -33,6 +34,7 @@ from jax._src.lax import linalg as lax_linalg from jax._src import test_util as jtu from jax._src import xla_bridge +from jax._src.lib import xla_extension_version from jax._src.numpy.util import promote_dtypes_inexact config.parse_flags_with_absl() @@ -1623,6 +1625,15 @@ def testTriangularSolveGradPrecision(self): (a, b), (a, b)) + @unittest.skipIf(xla_extension_version < 277, "Requires jaxlib > 0.4.30") + def testTriangularSolveSingularBatched(self): + x = jnp.array([[1, 1], [0, 0]], dtype=np.float32) + y = jnp.array([[1], [1.]], dtype=np.float32) + out = jax.lax.linalg.triangular_solve(x[None], y[None], left_side=True) + # x is singular. The triangular solve may contain either nans or infs, but + # it should not consist of only finite values. + self.assertFalse(np.all(np.isfinite(out))) + @jtu.sample_product( n=[1, 4, 5, 20, 50, 100], batch_size=[(), (2,), (3, 4)] if scipy_version >= (1, 9, 0) else [()],