Initial implementation of np.linalg.lstsq() via SVD #2744

Merged · 7 commits · May 11, 2020
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
@@ -313,6 +313,7 @@ jax.numpy.linalg
eigvals
eigvalsh
inv
lstsq
matrix_power
matrix_rank
multi_dot
53 changes: 53 additions & 0 deletions jax/numpy/linalg.py
@@ -483,3 +483,56 @@ def solve(a, b):
for func in get_module_functions(np.linalg):
  if func.__name__ not in globals():
    globals()[func.__name__] = _not_implemented(func)


@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\
    It has two important differences from `numpy.linalg.lstsq`:

    1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, with a warning that the
       default will change to `None` in the future. Here, the default `rcond` is `None`.
    2. In `numpy.linalg.lstsq`, the returned residuals are empty for rank-deficient or
       under-determined systems. Here, the residuals are returned in all cases, to make
       the function compatible with jit. The non-jit-compatible numpy behavior can be
       recovered by passing `numpy_resid=True`.

    The lstsq function does not currently have a custom JVP rule, so the gradient is
    poorly behaved for some inputs, particularly for low-rank `a`.
    """))
def lstsq(a, b, rcond=None, *, numpy_resid=False):
  # TODO: add lstsq to lax_linalg and implement this function via those wrappers.
  # TODO: add custom jvp rule for more robust lstsq differentiation
  a, b = _promote_arg_dtypes(a, b)
  if a.shape[0] != b.shape[0]:
    raise ValueError("Leading dimensions of input arrays must match")
  b_orig_ndim = b.ndim
  if b_orig_ndim == 1:
    b = b[:, None]
  if a.ndim != 2:
    raise TypeError(
      f"{a.ndim}-dimensional array given. Array must be two-dimensional")
  if b.ndim != 2:
    raise TypeError(
      f"{b_orig_ndim}-dimensional array given. Array must be one or two-dimensional")
  m, n = a.shape
  dtype = a.dtype
  if rcond is None:
    rcond = jnp.finfo(dtype).eps * max(n, m)
  elif rcond < 0:
    rcond = jnp.finfo(dtype).eps
  u, s, vt = svd(a, full_matrices=False)
  mask = s >= rcond * s[0]
  rank = mask.sum()
  safe_s = jnp.where(mask, s, 1)
  s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
  uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
  x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
  # Numpy returns empty residuals in some cases. To allow compilation, we
  # default to returning full residuals in all cases.
  if numpy_resid and (rank < n or m <= n):
    resid = jnp.asarray([])
  else:
    b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST)
    resid = norm(b - b_estimate, axis=0) ** 2
  if b_orig_ndim == 1:
    x = x.ravel()
  return x, resid, rank, s
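
A minimal usage sketch of the new function, assuming small illustrative inputs (the arrays `a` and `b` below are not from the PR): it shows the always-returned residuals in the default mode, the NumPy-compatible `numpy_resid=True` mode, and that the default mode can be wrapped in jit because its output shapes are static.

import jax
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])   # m=3 equations, n=2 unknowns (over-determined)
b = jnp.array([1.0, 2.0, 3.0])

# Default: residuals are always returned, so the output shapes are static.
x, resid, rank, s = jnp.linalg.lstsq(a, b, rcond=None)

# NumPy-compatible behavior: residuals come back empty when rank < n or m <= n.
x_np, resid_np, rank_np, s_np = jnp.linalg.lstsq(a, b, rcond=None, numpy_resid=True)

# The default path can be jit-compiled; here we only keep the solution x.
solve = jax.jit(lambda a, b: jnp.linalg.lstsq(a, b)[0])
x_jit = solve(a, b)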
40 changes: 40 additions & 0 deletions tests/linalg_test.py
@@ -842,6 +842,46 @@ def testMultiDot(self, shapes, dtype, rng_factory):
    self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True,
                          atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
           "_lhs={}_rhs={}_lowrank={}_rcond={}".format(
               jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               lowrank, rcond),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "lowrank": lowrank, "rcond": rcond, "rng_factory": rng_factory}
      for lhs_shape, rhs_shape in [
          ((1, 1), (1, 1)),
          ((4, 6), (4,)),
          ((6, 6), (6, 1)),
          ((8, 6), (8, 4)),
      ]
      for lowrank in [True, False]
      for rcond in [-1, None, 0.5]
      for dtype in float_types + complex_types
      for rng_factory in [jtu.rand_default]))
  @jtu.skip_on_devices("tpu")  # SVD not implemented on TPU.
  def testLstsq(self, lhs_shape, rhs_shape, dtype, lowrank, rcond, rng_factory):
    rng = rng_factory(self.rng())
    _skip_if_unsupported_type(dtype)
    onp_fun = partial(np.linalg.lstsq, rcond=rcond)
    jnp_fun = partial(jnp.linalg.lstsq, rcond=rcond)
    jnp_fun_numpy_resid = partial(jnp.linalg.lstsq, rcond=rcond, numpy_resid=True)
    tol = {np.float32: 1e-6, np.float64: 1e-12,
           np.complex64: 1e-6, np.complex128: 1e-12}
    def args_maker():
      lhs = rng(lhs_shape, dtype)
      if lowrank and lhs_shape[1] > 1:
        lhs[:, -1] = lhs[:, :-1].mean(1)
      return [lhs, rng(rhs_shape, dtype)]

    self._CheckAgainstNumpy(onp_fun, jnp_fun_numpy_resid, args_maker, check_dtypes=False, tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol)

    # Disabled because grad is flaky for low-rank inputs.
    # TODO:
    # jtu.check_grads(lambda *args: jnp_fun(*args)[0], args_maker(), order=2, atol=1e-2, rtol=1e-2)

  # Regression test for incorrect type for eigenvalues of a complex matrix.
  @jtu.skip_on_devices("tpu")  # TODO(phawkins): No complex eigh implementation on TPU.
  def testIssue669(self):