Skip to content

Commit

Permalink
Merge pull request #8421 from hawkinsp:svd
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 406820268
  • Loading branch information
jax authors committed Nov 1, 2021
2 parents 32319e1 + 05e6f84 commit 5ae0795
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
21 changes: 19 additions & 2 deletions jax/_src/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,26 @@ def cholesky(a):


@_wraps(np.linalg.svd)
@partial(jit, static_argnames=('full_matrices', 'compute_uv'))
def svd(a, full_matrices=True, compute_uv=True):
@partial(jit, static_argnames=('full_matrices', 'compute_uv', 'hermitian'))
def svd(a, full_matrices: bool = True, compute_uv: bool = True,
hermitian: bool = False):
a = _promote_arg_dtypes(jnp.asarray(a))
if hermitian:
w, v = lax_linalg.eigh(a)
s = lax.abs(v)
if compute_uv:
sign = lax.sign(v)
idxs = lax.broadcasted_iota(np.int64, s.shape, dimension=s.ndim - 1)
s, idxs, sign = lax.sort((s, idxs, sign), dimension=-1, num_keys=1)
s = lax.rev(s, dimensions=[s.ndim - 1])
idxs = lax.rev(idxs, dimensions=[s.ndim - 1])
sign = lax.rev(sign, dimensions=[s.ndim - 1])
u = jnp.take_along_axis(w, idxs[..., None, :], axis=-1)
vh = _H(u * sign[..., None, :])
return u, s, vh
else:
return lax.rev(lax.sort(s, dimension=-1), dimensions=[s.ndim-1])

return lax_linalg.svd(a, full_matrices, compute_uv)


Expand Down
16 changes: 10 additions & 6 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,18 +527,19 @@ def testNorm(self, shape, dtype, ord, axis, keepdims):
self._CompileAndCheck(jnp_fn, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_n={}_full_matrices={}_compute_uv={}".format(
{"testcase_name": "_n={}_full_matrices={}_compute_uv={}_hermitian={}".format(
jtu.format_shape_dtype_string(b + (m, n), dtype), full_matrices,
compute_uv),
compute_uv, hermitian),
"b": b, "m": m, "n": n, "dtype": dtype, "full_matrices": full_matrices,
"compute_uv": compute_uv}
"compute_uv": compute_uv, "hermitian": hermitian}
for b in [(), (3,), (2, 3)]
for m in [0, 2, 7, 29, 53]
for n in [0, 2, 7, 29, 53]
for dtype in float_types + complex_types
for full_matrices in [False, True]
for compute_uv in [False, True]))
def testSVD(self, b, m, n, dtype, full_matrices, compute_uv):
for compute_uv in [False, True]
for hermitian in ([False, True] if m == n else [False])))
def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian):
if (jnp.issubdtype(dtype, np.complexfloating) and
jtu.device_under_test() == "tpu"):
raise unittest.SkipTest("No complex SVD implementation")
Expand All @@ -551,7 +552,10 @@ def norm(x):
return norm / (max(1, m, n) * jnp.finfo(dtype).eps)

a, = args_maker()
out = jnp.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
if hermitian:
a = a + np.conj(T(a))
out = jnp.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv,
hermitian=hermitian)
if compute_uv:
# Check the reconstructed matrices
if full_matrices:
Expand Down

0 comments on commit 5ae0795

Please sign in to comment.