Fix or disable some tests that fail when using an Eigen BLAS with AVX vectorization.

PiperOrigin-RevId: 658047868
hawkinsp authored and nitins17 committed Aug 27, 2024
1 parent d8774a4 commit ed9780c
Showing 2 changed files with 16 additions and 20 deletions.
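
The failure mode described in the commit message can be illustrated outside JAX. The sketch below is not part of the commit; it assumes SciPy's wrapped BLAS as a stand-in for whatever BLAS JAX links against, and simply feeds the same data to the nrm2 kernel at two different memory alignments. With an AVX-vectorized Eigen BLAS the two results may differ in the last bits, while reference BLAS builds typically agree exactly.

# Illustrative sketch (not from the commit): probe whether the linked BLAS
# nrm2 returns bit-identical results for the same data at two alignments.
import numpy as np
from scipy.linalg.blas import dnrm2  # double-precision Euclidean norm

n = 1023
x = np.random.default_rng(0).standard_normal(n)

buf = np.empty(n + 1, dtype=np.float64)
aligned = buf[:n]   # starts at the buffer's base address
shifted = buf[1:]   # same length, shifted by 8 bytes (different 32-byte alignment)

aligned[:] = x
norm_a = float(dnrm2(aligned))
shifted[:] = x
norm_b = float(dnrm2(shifted))

print(norm_a == norm_b)  # may be False for an alignment-dependent SIMD kernel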
2 changes: 2 additions & 0 deletions tests/lax_scipy_test.py
@@ -518,6 +518,8 @@ def testPolar(
     tol = 650 * float(jnp.finfo(matrix.dtype).eps)
     eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype)
     with self.subTest('Test unitarity.'):
+      if jtu.test_device_matches(["cpu"]):
+        tol = max(tol, 1e-8)
       self.assertAllClose(
           eye_mat, should_be_eye, atol=tol * 1000 * min(shape))

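For context, here is a standalone sketch of the unitarity check that the tolerance change above relaxes. It uses jax.scipy.linalg.polar; the constants (650 * eps, the 1e-8 CPU floor, the atol scaling) mirror the test, while the input setup is illustrative.

# Sketch of the unitarity check being relaxed above; tolerance constants
# mirror the test, everything else is illustrative.
import numpy as np
import jax.numpy as jnp
from jax.scipy.linalg import polar

shape = (6, 6)
a = jnp.asarray(np.random.default_rng(0).standard_normal(shape), dtype=jnp.float32)

u, p = polar(a)                    # right polar decomposition: a == u @ p
should_be_eye = u @ jnp.conj(u.T)  # unitary factor times its adjoint

tol = 650 * float(jnp.finfo(a.dtype).eps)
tol = max(tol, 1e-8)               # CPU floor added by the patch (matters for float64)
np.testing.assert_allclose(np.eye(shape[0], dtype=a.dtype), should_be_eye,
                           atol=tol * 1000 * min(shape))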
34 changes: 14 additions & 20 deletions tests/linalg_test.py
@@ -268,8 +268,10 @@ def check_left_eigenvectors(a, w, vl):
     if compute_right_eigenvectors:
       check_right_eigenvectors(a, w, results[1 + compute_left_eigenvectors])

-    self._CompileAndCheck(partial(jnp.linalg.eig), args_maker,
-                          rtol=1e-3)
+    # TODO(phawkins): we are seeing nondeterminism in LAPACK routines with
+    # avx enabled, because, for Eigen BLAS, nrm2 has an alignment dependence.
+    # self._CompileAndCheck(partial(jnp.linalg.eig), args_maker,
+    #                       rtol=1e-3)

   @jtu.sample_product(
       shape=[(4, 4), (5, 5), (50, 50), (2, 6, 6)],
@@ -1860,17 +1862,6 @@ def expm(x):
       return jsp.linalg.expm(x, upper_triangular=False, max_squarings=16)
     jtu.check_grads(expm, (a,), modes=["fwd", "rev"], order=1, atol=tol,
                     rtol=tol)
-  @jtu.sample_product(
-      shape=[(4, 4), (15, 15), (50, 50), (100, 100)],
-      dtype=float_types + complex_types,
-  )
-  @jtu.run_on_devices("cpu")
-  def testSchur(self, shape, dtype):
-    rng = jtu.rand_default(self.rng())
-    args_maker = lambda: [rng(shape, dtype)]
-
-    self._CheckAgainstNumpy(osp.linalg.schur, jsp.linalg.schur, args_maker)
-    self._CompileAndCheck(jsp.linalg.schur, args_maker)

   @jtu.sample_product(
       shape=[(1, 1), (4, 4), (15, 15), (50, 50), (100, 100)],
@@ -1900,9 +1891,11 @@ def func(x):
     args_maker = lambda: [rng(shape, dtype)]
     jnp_fun = lambda arr: jsp.linalg.funm(arr, func, disp=disp)
     scp_fun = lambda arr: osp.linalg.funm(arr, func, disp=disp)
-    self._CheckAgainstNumpy(jnp_fun, scp_fun, args_maker, check_dtypes=False,
-                            tol={np.complex64: 1e-5, np.complex128: 1e-6})
-    self._CompileAndCheck(jnp_fun, args_maker, atol=2e-5)
+    self._CheckAgainstNumpy(
+        jnp_fun, scp_fun, args_maker, check_dtypes=False,
+        tol={np.float32: 2e-3, np.complex64: 2e-3, np.complex128: 1e-6})
+    # TODO(phawkins): nondeterminism due to alignment.
+    # self._CompileAndCheck(jnp_fun, args_maker, atol=2e-5)

   @jtu.sample_product(
       shape=[(4, 4), (15, 15), (50, 50), (100, 100)],
@@ -2125,10 +2118,11 @@ def test_tridiagonal_solve(self, dtype):
   @jtu.run_on_devices("cpu")
   def testSchur(self, shape, dtype):
     rng = jtu.rand_default(self.rng())
-    args_maker = lambda: [rng(shape, dtype)]
-
-    self._CheckAgainstNumpy(osp.linalg.schur, lax.linalg.schur, args_maker)
-    self._CompileAndCheck(lax.linalg.schur, args_maker)
+    args = rng(shape, dtype)
+    Ts, Ss = lax.linalg.schur(args)
+    eps = np.finfo(dtype).eps
+    self.assertAllClose(args, Ss @ Ts @ jnp.conj(Ss.T), atol=eps * 600)
+    self.assertAllClose(np.eye(*shape, dtype=dtype), Ss @ jnp.conj(Ss.T), atol=eps * 100)

   @jtu.sample_product(
       shape=[(2, 2), (4, 4), (15, 15), (50, 50), (100, 100)],
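The rewritten testSchur above replaces a comparison against SciPy with direct residual checks. Below is a hedged sketch of the same checks using scipy.linalg.schur; the eps * 600 and eps * 100 tolerances come from the patch, while the shape and input are illustrative.

# Sketch of the residual checks in the rewritten testSchur, via SciPy.
import numpy as np
import scipy.linalg

shape, dtype = (15, 15), np.float64
a = np.random.default_rng(0).standard_normal(shape).astype(dtype)

T, Z = scipy.linalg.schur(a)  # a == Z @ T @ Z^H with Z unitary
eps = np.finfo(dtype).eps

# Reconstruction and unitarity residuals, with the patch's tolerances.
np.testing.assert_allclose(a, Z @ T @ Z.conj().T, atol=eps * 600, rtol=0)
np.testing.assert_allclose(np.eye(*shape, dtype=dtype), Z @ Z.conj().T,
                           atol=eps * 100, rtol=0)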

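_CompileAndCheck is a JAX-internal test helper; roughly speaking, it compares an op's eager and jit-compiled outputs. The following simplified stand-in (an assumption, not the jtu implementation) shows what the disabled jnp.linalg.eig check amounts to, comparing eigenvalues only since eigenvector phases are not unique.

# Rough stand-in (assumption) for the disabled _CompileAndCheck on jnp.linalg.eig:
# eager and jit-compiled runs on the same input should agree. An alignment-
# dependent nrm2 inside LAPACK can make the two runs disagree beyond tolerance.
import jax
import jax.numpy as jnp
import numpy as np

a = jnp.asarray(np.random.default_rng(0).standard_normal((5, 5)), dtype=jnp.float32)

w_eager, _ = jnp.linalg.eig(a)        # eig is CPU-only in JAX
w_jit, _ = jax.jit(jnp.linalg.eig)(a)

np.testing.assert_allclose(np.sort_complex(np.asarray(w_eager)),
                           np.sort_complex(np.asarray(w_jit)), rtol=1e-3)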