From 1c267fbc9224b4af2be953535a9ac64e7d6d9443 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 16 Apr 2020 10:21:02 -0700 Subject: [PATCH 1/7] Initial implementation of np.linalg.lstsq() via SVD --- docs/jax.numpy.rst | 1 + jax/numpy/linalg.py | 43 +++++++++++++++++++++++++++++++++++++++++++ tests/linalg_test.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+) diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index 5315b57af0ce..d865697b806f 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -313,6 +313,7 @@ jax.numpy.linalg eigvals eigvalsh inv + leastsq matrix_power matrix_rank multi_dot diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py index f5f367260687..405df741d003 100644 --- a/jax/numpy/linalg.py +++ b/jax/numpy/linalg.py @@ -483,3 +483,46 @@ def solve(a, b): for func in get_module_functions(np.linalg): if func.__name__ not in globals(): globals()[func.__name__] = _not_implemented(func) + + +@_wraps(onp.linalg.lstsq, lax_description=textwrap.dedent("""\ + It has two important differences: + + 1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and shows a deprecation + warning. In jax.numpy, the default rcond is `None`. + 2. In `np.linalg.lstsq` the residuals return an empty list for low-rank solutions. + Here, the residuals are returned in all cases, to make the function compatible + with jit. The non-jit compatible numpy behavior can be recovered by passing + numpy_resid=True. + """)) +def lstsq(a, b, rcond=None, *, numpy_resid=False): + # TODO: add lstsq to lax_linalg and implement this function via those wrappers. + a, b = _promote_arg_dtypes(a, b) + dtype = a.dtype + b_ndim = b.ndim + if b_ndim == 1: + b = b[:, None] + if a.ndim != 2: + raise TypeError( + f"{a.ndim}-dimensional array given. Array must be two-dimensional") + if b.ndim != 2: + raise TypeError( + f"{b_ndim}-dimensional array given. Array must be one or two-dimensional") + m, n = a.shape + if rcond is None: + rcond = np.finfo(dtype).eps * max(n, m) + elif rcond < 0: + rcond = np.finfo(dtype).eps + u, s, vt = svd(a, full_matrices=False) + mask = s >= rcond * s[0] + rank = mask.sum() + s_inv = np.where(mask, 1 / s, 0) + x = vt.conj().T @ (u.conj().T @ b * s_inv[:, None]) + resid = norm(b - a @ x, axis=0) ** 2 + # To match numpy, we would add the following condition. We don't do + # this because it makes lstsq incompatible with jit. + if numpy_resid and (rank < n or m <= n): + resid = np.asarray([]) + if b_ndim == 1: + x = x.ravel() + return x, resid, rank, s diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 2f10bae203fd..5908fc803f8b 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -842,6 +842,37 @@ def testMultiDot(self, shapes, dtype, rng_factory): self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol) + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": + "_lhs={}_rhs={}_rcond={}".format( + jtu.format_shape_dtype_string(lhs_shape, dtype), + jtu.format_shape_dtype_string(rhs_shape, dtype), + rcond), + "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, + "rcond": rcond, "rng_factory": rng_factory} + for lhs_shape, rhs_shape in [ + ((1, 1), (1, 1)), + ((4, 6), (4,)), + ((6, 6), (6, 1)), + ((8, 6), (8, 4)), + ] + for rcond in [-1, None, 0.5] + for dtype in float_types + complex_types + for rng_factory in [jtu.rand_default])) + @jtu.skip_on_devices("tpu") # SVD not implemented on TPU. + def testLstsq(self, lhs_shape, rhs_shape, dtype, rcond, rng_factory): + rng = rng_factory() + _skip_if_unsupported_type(dtype) + onp_fun = partial(onp.linalg.lstsq, rcond=rcond) + jnp_fun = partial(np.linalg.lstsq, rcond=rcond) + jnp_fun_numpy = partial(np.linalg.lstsq, rcond=rcond, numpy_resid=True) + args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] + tol = {onp.float32: 1e-6, onp.float64: 1e-12, + onp.complex64: 1e-6, onp.complex128: 1e-12} + + self._CheckAgainstNumpy(onp_fun, jnp_fun_numpy, args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol) + # Regression test for incorrect type for eigenvalues of a complex matrix. @jtu.skip_on_devices("tpu") # TODO(phawkins): No complex eigh implementation on TPU. def testIssue669(self): From b576cb9b6ed88f70d56d6eb1d079a215aba1b8ac Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 16 Apr 2020 11:28:00 -0700 Subject: [PATCH 2/7] Fix docs reference --- docs/jax.numpy.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index d865697b806f..b85a7ac26deb 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -313,7 +313,7 @@ jax.numpy.linalg eigvals eigvalsh inv - leastsq + lstsq matrix_power matrix_rank multi_dot From b741befe8d4600259d14ebdafee69821fc444a7c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 16 Apr 2020 23:11:05 -0700 Subject: [PATCH 3/7] Update based on review --- jax/numpy/linalg.py | 19 ++++++++++++------- tests/linalg_test.py | 16 +++++++++++----- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py index 405df741d003..f7fe5a7c6d6a 100644 --- a/jax/numpy/linalg.py +++ b/jax/numpy/linalg.py @@ -498,7 +498,8 @@ def solve(a, b): def lstsq(a, b, rcond=None, *, numpy_resid=False): # TODO: add lstsq to lax_linalg and implement this function via those wrappers. a, b = _promote_arg_dtypes(a, b) - dtype = a.dtype + if a.shape[0] != b.shape[0]: + raise ValueError("Leading dimensions of input arrays must match") b_ndim = b.ndim if b_ndim == 1: b = b[:, None] @@ -509,6 +510,7 @@ def lstsq(a, b, rcond=None, *, numpy_resid=False): raise TypeError( f"{b_ndim}-dimensional array given. Array must be one or two-dimensional") m, n = a.shape + dtype = a.dtype if rcond is None: rcond = np.finfo(dtype).eps * max(n, m) elif rcond < 0: @@ -516,13 +518,16 @@ def lstsq(a, b, rcond=None, *, numpy_resid=False): u, s, vt = svd(a, full_matrices=False) mask = s >= rcond * s[0] rank = mask.sum() - s_inv = np.where(mask, 1 / s, 0) - x = vt.conj().T @ (u.conj().T @ b * s_inv[:, None]) - resid = norm(b - a @ x, axis=0) ** 2 - # To match numpy, we would add the following condition. We don't do - # this because it makes lstsq incompatible with jit. + safe_s = np.where(mask, s, 1) + s_inv = np.where(mask, 1 / safe_s, 0)[:, np.newaxis] + uTb = np.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST) + x = np.matmul(vt.conj().T, uTb * s_inv, precision=lax.Precision.HIGHEST) + # Numpy returns empty residuals in some cases. We return residuals if numpy_resid and (rank < n or m <= n): - resid = np.asarray([]) + resid = np.asarray([]) + else: + b_estimate = np.matmul(a, x, precision=lax.Precision.HIGHEST) + resid = norm(b - b_estimate, axis=0) ** 2 if b_ndim == 1: x = x.ravel() return x, resid, rank, s diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 5908fc803f8b..54393ae64f5d 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -844,34 +844,40 @@ def testMultiDot(self, shapes, dtype, rng_factory): @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": - "_lhs={}_rhs={}_rcond={}".format( + "_lhs={}_rhs={}_lowrank={}_rcond={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), - rcond), + lowrank, rcond), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, - "rcond": rcond, "rng_factory": rng_factory} + "lowrank": lowrank, "rcond": rcond, "rng_factory": rng_factory} for lhs_shape, rhs_shape in [ ((1, 1), (1, 1)), ((4, 6), (4,)), ((6, 6), (6, 1)), ((8, 6), (8, 4)), ] + for lowrank in [True, False] for rcond in [-1, None, 0.5] for dtype in float_types + complex_types for rng_factory in [jtu.rand_default])) @jtu.skip_on_devices("tpu") # SVD not implemented on TPU. - def testLstsq(self, lhs_shape, rhs_shape, dtype, rcond, rng_factory): + def testLstsq(self, lhs_shape, rhs_shape, dtype, lowrank, rcond, rng_factory): rng = rng_factory() _skip_if_unsupported_type(dtype) onp_fun = partial(onp.linalg.lstsq, rcond=rcond) jnp_fun = partial(np.linalg.lstsq, rcond=rcond) jnp_fun_numpy = partial(np.linalg.lstsq, rcond=rcond, numpy_resid=True) - args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] tol = {onp.float32: 1e-6, onp.float64: 1e-12, onp.complex64: 1e-6, onp.complex128: 1e-12} + def args_maker(): + lhs = rng(lhs_shape, dtype) + if lowrank and lhs_shape[1] > 1: + lhs[:, -1] = lhs[:, :-1].mean(1) + return [lhs, rng(rhs_shape, dtype)] self._CheckAgainstNumpy(onp_fun, jnp_fun_numpy, args_maker, check_dtypes=False, tol=tol) self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol) + # jtu.check_grads(jnp_fun, args_maker(), order=2, atol=tol, rtol=tol) # Regression test for incorrect type for eigenvalues of a complex matrix. @jtu.skip_on_devices("tpu") # TODO(phawkins): No complex eigh implementation on TPU. From 0466a0a8d8985f82d39b42d268203706b99bddcb Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Sat, 18 Apr 2020 08:33:31 -0700 Subject: [PATCH 4/7] test gradient for lstsq --- jax/numpy/linalg.py | 25 +++++++++++++------------ tests/linalg_test.py | 5 ++++- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py index f7fe5a7c6d6a..9ccd939d480c 100644 --- a/jax/numpy/linalg.py +++ b/jax/numpy/linalg.py @@ -488,27 +488,27 @@ def solve(a, b): @_wraps(onp.linalg.lstsq, lax_description=textwrap.dedent("""\ It has two important differences: - 1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and shows a deprecation - warning. In jax.numpy, the default rcond is `None`. - 2. In `np.linalg.lstsq` the residuals return an empty list for low-rank solutions. - Here, the residuals are returned in all cases, to make the function compatible - with jit. The non-jit compatible numpy behavior can be recovered by passing - numpy_resid=True. + 1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future + the default will be `None`. Here, the default rcond is `None`. + 2. In `np.linalg.lstsq` the returned residuals are empty for low-rank or over-determined + solutions. Here, the residuals are returned in all cases, to make the function + compatible with jit. The non-jit compatible numpy behavior can be recovered by + passing numpy_resid=True. """)) def lstsq(a, b, rcond=None, *, numpy_resid=False): # TODO: add lstsq to lax_linalg and implement this function via those wrappers. a, b = _promote_arg_dtypes(a, b) if a.shape[0] != b.shape[0]: raise ValueError("Leading dimensions of input arrays must match") - b_ndim = b.ndim - if b_ndim == 1: + b_orig_ndim = b.ndim + if b_orig_ndim == 1: b = b[:, None] if a.ndim != 2: raise TypeError( f"{a.ndim}-dimensional array given. Array must be two-dimensional") if b.ndim != 2: raise TypeError( - f"{b_ndim}-dimensional array given. Array must be one or two-dimensional") + f"{b_original_ndim}-dimensional array given. Array must be one or two-dimensional") m, n = a.shape dtype = a.dtype if rcond is None: @@ -521,13 +521,14 @@ def lstsq(a, b, rcond=None, *, numpy_resid=False): safe_s = np.where(mask, s, 1) s_inv = np.where(mask, 1 / safe_s, 0)[:, np.newaxis] uTb = np.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST) - x = np.matmul(vt.conj().T, uTb * s_inv, precision=lax.Precision.HIGHEST) - # Numpy returns empty residuals in some cases. We return residuals + x = np.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST) + # Numpy returns empty residuals in some cases. To allow compilation, we + # default to returning full residuals in all cases. if numpy_resid and (rank < n or m <= n): resid = np.asarray([]) else: b_estimate = np.matmul(a, x, precision=lax.Precision.HIGHEST) resid = norm(b - b_estimate, axis=0) ** 2 - if b_ndim == 1: + if b_orig_ndim == 1: x = x.ravel() return x, resid, rank, s diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 54393ae64f5d..097974236b4c 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -877,7 +877,10 @@ def args_maker(): self._CheckAgainstNumpy(onp_fun, jnp_fun_numpy, args_maker, check_dtypes=False, tol=tol) self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol) - # jtu.check_grads(jnp_fun, args_maker(), order=2, atol=tol, rtol=tol) + + if np.finfo(dtype).bits == 64: + # Only check grad for first argument: + jtu.check_grads(lambda *args: jnp_fun(*args)[0], args_maker(), order=2, atol=1e-2, rtol=1e-2) # Regression test for incorrect type for eigenvalues of a complex matrix. @jtu.skip_on_devices("tpu") # TODO(phawkins): No complex eigh implementation on TPU. From 727d665e9aa9ce469fcf322bc00f8a3a7e0f8088 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 4 May 2020 11:59:34 -0700 Subject: [PATCH 5/7] Disable gradient test and mark as TODO --- jax/numpy/linalg.py | 1 + tests/linalg_test.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py index 9ccd939d480c..d0aee27f3178 100644 --- a/jax/numpy/linalg.py +++ b/jax/numpy/linalg.py @@ -497,6 +497,7 @@ def solve(a, b): """)) def lstsq(a, b, rcond=None, *, numpy_resid=False): # TODO: add lstsq to lax_linalg and implement this function via those wrappers. + # TODO: add custom jvp rule for more robust lstsq differentiation a, b = _promote_arg_dtypes(a, b) if a.shape[0] != b.shape[0]: raise ValueError("Leading dimensions of input arrays must match") diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 097974236b4c..e5c4034449d9 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -878,9 +878,9 @@ def args_maker(): self._CheckAgainstNumpy(onp_fun, jnp_fun_numpy, args_maker, check_dtypes=False, tol=tol) self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol) - if np.finfo(dtype).bits == 64: - # Only check grad for first argument: - jtu.check_grads(lambda *args: jnp_fun(*args)[0], args_maker(), order=2, atol=1e-2, rtol=1e-2) + # Disabled because grad is flaky for low-rank inputs. + # TODO: + # jtu.check_grads(lambda *args: jnp_fun(*args)[0], args_maker(), order=2, atol=1e-2, rtol=1e-2) # Regression test for incorrect type for eigenvalues of a complex matrix. @jtu.skip_on_devices("tpu") # TODO(phawkins): No complex eigh implementation on TPU. From cf61bed6f1629fb385c86c7f878e709b6b000834 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 6 May 2020 09:34:39 -0700 Subject: [PATCH 6/7] fix issues after rebase on master --- jax/numpy/linalg.py | 18 +++++++++--------- tests/linalg_test.py | 14 +++++++------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py index d0aee27f3178..f10e97f8abde 100644 --- a/jax/numpy/linalg.py +++ b/jax/numpy/linalg.py @@ -485,7 +485,7 @@ def solve(a, b): globals()[func.__name__] = _not_implemented(func) -@_wraps(onp.linalg.lstsq, lax_description=textwrap.dedent("""\ +@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\ It has two important differences: 1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future @@ -513,22 +513,22 @@ def lstsq(a, b, rcond=None, *, numpy_resid=False): m, n = a.shape dtype = a.dtype if rcond is None: - rcond = np.finfo(dtype).eps * max(n, m) + rcond = jnp.finfo(dtype).eps * max(n, m) elif rcond < 0: - rcond = np.finfo(dtype).eps + rcond = jnp.finfo(dtype).eps u, s, vt = svd(a, full_matrices=False) mask = s >= rcond * s[0] rank = mask.sum() - safe_s = np.where(mask, s, 1) - s_inv = np.where(mask, 1 / safe_s, 0)[:, np.newaxis] - uTb = np.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST) - x = np.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST) + safe_s = jnp.where(mask, s, 1) + s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis] + uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST) + x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST) # Numpy returns empty residuals in some cases. To allow compilation, we # default to returning full residuals in all cases. if numpy_resid and (rank < n or m <= n): - resid = np.asarray([]) + resid = jnp.asarray([]) else: - b_estimate = np.matmul(a, x, precision=lax.Precision.HIGHEST) + b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST) resid = norm(b - b_estimate, axis=0) ** 2 if b_orig_ndim == 1: x = x.ravel() diff --git a/tests/linalg_test.py b/tests/linalg_test.py index e5c4034449d9..dc3dd41edc6a 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -862,20 +862,20 @@ def testMultiDot(self, shapes, dtype, rng_factory): for rng_factory in [jtu.rand_default])) @jtu.skip_on_devices("tpu") # SVD not implemented on TPU. def testLstsq(self, lhs_shape, rhs_shape, dtype, lowrank, rcond, rng_factory): - rng = rng_factory() + rng = rng_factory(self.rng()) _skip_if_unsupported_type(dtype) - onp_fun = partial(onp.linalg.lstsq, rcond=rcond) - jnp_fun = partial(np.linalg.lstsq, rcond=rcond) - jnp_fun_numpy = partial(np.linalg.lstsq, rcond=rcond, numpy_resid=True) - tol = {onp.float32: 1e-6, onp.float64: 1e-12, - onp.complex64: 1e-6, onp.complex128: 1e-12} + onp_fun = partial(np.linalg.lstsq, rcond=rcond) + jnp_fun = partial(jnp.linalg.lstsq, rcond=rcond) + jnp_fun_numpy_resid = partial(jnp.linalg.lstsq, rcond=rcond, numpy_resid=True) + tol = {np.float32: 1e-6, np.float64: 1e-12, + np.complex64: 1e-6, np.complex128: 1e-12} def args_maker(): lhs = rng(lhs_shape, dtype) if lowrank and lhs_shape[1] > 1: lhs[:, -1] = lhs[:, :-1].mean(1) return [lhs, rng(rhs_shape, dtype)] - self._CheckAgainstNumpy(onp_fun, jnp_fun_numpy, args_maker, check_dtypes=False, tol=tol) + self._CheckAgainstNumpy(onp_fun, jnp_fun_numpy_resid, args_maker, check_dtypes=False, tol=tol) self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol) # Disabled because grad is flaky for low-rank inputs. From 0012454f1d3531430aac24fa416df1d80b6688b4 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 11 May 2020 12:25:59 -0700 Subject: [PATCH 7/7] Mention JVP issue in lstsq docstring. --- jax/numpy/linalg.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py index f10e97f8abde..12ab5682409d 100644 --- a/jax/numpy/linalg.py +++ b/jax/numpy/linalg.py @@ -494,6 +494,9 @@ def solve(a, b): solutions. Here, the residuals are returned in all cases, to make the function compatible with jit. The non-jit compatible numpy behavior can be recovered by passing numpy_resid=True. + + The lstsq function does not currently have a custom JVP rule, so the gradient is + poorly behaved for some inputs, particularly for low-rank `a`. """)) def lstsq(a, b, rcond=None, *, numpy_resid=False): # TODO: add lstsq to lax_linalg and implement this function via those wrappers.