Skip to content

Commit

Permalink
test gradient for lstsq
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Apr 17, 2020
1 parent 2ec71c0 commit 89cd7f7
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,10 @@ def args_maker():

self._CheckAgainstNumpy(onp_fun, jnp_fun_numpy, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol)
# jtu.check_grads(jnp_fun, args_maker(), order=2, atol=tol, rtol=tol)

if np.finfo(dtype).bits == 64:
# Only check grad for first argument:
jtu.check_grads(lambda *args: jnp_fun(*args)[0], args_maker(), order=2, atol=1e-2, rtol=1e-2)

# Regression test for incorrect type for eigenvalues of a complex matrix.
@jtu.skip_on_devices("tpu") # TODO(phawkins): No complex eigh implementation on TPU.
Expand Down

0 comments on commit 89cd7f7

Please sign in to comment.