diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index f015ae15bca5..b1482103129d 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -53,9 +53,26 @@ def cholesky(a): @_wraps(np.linalg.svd) -@partial(jit, static_argnames=('full_matrices', 'compute_uv')) -def svd(a, full_matrices=True, compute_uv=True): +@partial(jit, static_argnames=('full_matrices', 'compute_uv', 'hermitian')) +def svd(a, full_matrices: bool = True, compute_uv: bool = True, + hermitian: bool = False): a = _promote_arg_dtypes(jnp.asarray(a)) + if hermitian: + w, v = lax_linalg.eigh(a) + s = lax.abs(v) + if compute_uv: + sign = lax.sign(v) + idxs = lax.broadcasted_iota(np.int64, s.shape, dimension=s.ndim - 1) + s, idxs, sign = lax.sort((s, idxs, sign), dimension=-1, num_keys=1) + s = lax.rev(s, dimensions=[s.ndim - 1]) + idxs = lax.rev(idxs, dimensions=[s.ndim - 1]) + sign = lax.rev(sign, dimensions=[s.ndim - 1]) + u = jnp.take_along_axis(w, idxs[..., None, :], axis=-1) + vh = _H(u * sign[..., None, :]) + return u, s, vh + else: + return lax.rev(lax.sort(s, dimension=-1), dimensions=[s.ndim-1]) + return lax_linalg.svd(a, full_matrices, compute_uv) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 92b2dbe284c4..b447f0b9093b 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -527,18 +527,19 @@ def testNorm(self, shape, dtype, ord, axis, keepdims): self._CompileAndCheck(jnp_fn, args_maker) @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_n={}_full_matrices={}_compute_uv={}".format( + {"testcase_name": "_n={}_full_matrices={}_compute_uv={}_hermitian={}".format( jtu.format_shape_dtype_string(b + (m, n), dtype), full_matrices, - compute_uv), + compute_uv, hermitian), "b": b, "m": m, "n": n, "dtype": dtype, "full_matrices": full_matrices, - "compute_uv": compute_uv} + "compute_uv": compute_uv, "hermitian": hermitian} for b in [(), (3,), (2, 3)] for m in [0, 2, 7, 29, 53] for n in [0, 2, 7, 29, 53] for dtype in float_types + complex_types for full_matrices in [False, True] - for compute_uv in [False, True])) - def testSVD(self, b, m, n, dtype, full_matrices, compute_uv): + for compute_uv in [False, True] + for hermitian in ([False, True] if m == n else [False]))) + def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian): if (jnp.issubdtype(dtype, np.complexfloating) and jtu.device_under_test() == "tpu"): raise unittest.SkipTest("No complex SVD implementation") @@ -551,7 +552,10 @@ def norm(x): return norm / (max(1, m, n) * jnp.finfo(dtype).eps) a, = args_maker() - out = jnp.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv) + if hermitian: + a = a + np.conj(T(a)) + out = jnp.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv, + hermitian=hermitian) if compute_uv: # Check the reconstructed matrices if full_matrices: