Skip to content

Commit

Permalink
chore: simplified tolerance values fot unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Qazalbash committed Jan 31, 2025
1 parent 8864500 commit 4266a46
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 26 deletions.
6 changes: 2 additions & 4 deletions test/infer/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,7 @@ def model(labels):
mcmc.print_summary()
samples = mcmc.get_samples()
assert samples["logits"].shape == (num_samples, N)
# those coefficients are found by doing MAP inference using AutoDelta
expected_coefs = jnp.array([0.97, 2.05, 3.18])
assert_allclose(jnp.mean(samples["coefs"], 0), expected_coefs, atol=0.29)
assert_allclose(jnp.mean(samples["coefs"], 0), true_coefs, atol=0.2)

if "JAX_ENABLE_X64" in os.environ:
assert samples["coefs"].dtype == jnp.float64
Expand Down Expand Up @@ -899,7 +897,7 @@ def test_get_proposal_loc_and_scale(dense_mass):
expected_loc = jnp.stack(expected_loc)
expected_scale = jnp.stack(expected_scale)
assert_allclose(actual_loc, expected_loc, rtol=1e-4)
assert_allclose(actual_scale, expected_scale, atol=1e-6, rtol=0.234)
assert_allclose(actual_scale, expected_scale, atol=1e-6, rtol=0.3)


@pytest.mark.parametrize("shape", [(4,), (3, 2)])
Expand Down
27 changes: 7 additions & 20 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1839,13 +1839,11 @@ def fn(*args):

eps = 1e-3
atol = 0.01
rtol = 0.01
rtol = 0.05
if jax_dist is dist.EulerMaruyama:
atol = 0.064
rtol = 0.042
elif jax_dist is dist.NegativeBinomialLogits:
atol = 0.013
rtol = 0.044

for i in range(len(params)):
if jax_dist is dist.EulerMaruyama and i == 1:
Expand Down Expand Up @@ -1945,12 +1943,8 @@ def test_mean_var(jax_dist, sp_dist, params):
if jnp.all(jnp.isfinite(sp_mean)):
assert_allclose(jnp.mean(samples, 0), d_jax.mean, rtol=0.05, atol=1e-2)
if jnp.all(jnp.isfinite(sp_var)):
rtol = 0.05
atol = 1e-2
if jax_dist is dist.InverseGamma:
rtol = 0.054
assert_allclose(
jnp.std(samples, 0), jnp.sqrt(d_jax.variance), rtol=rtol, atol=atol
jnp.std(samples, 0), jnp.sqrt(d_jax.variance), rtol=0.06, atol=1e-2
)
elif jax_dist in [dist.LKJ, dist.LKJCholesky]:
if jax_dist is dist.LKJCholesky:
Expand Down Expand Up @@ -1979,8 +1973,8 @@ def test_mean_var(jax_dist, sp_dist, params):
)
expected_std = expected_std * (1 - jnp.identity(dimension))

assert_allclose(jnp.mean(corr_samples, axis=0), expected_mean, atol=0.011)
assert_allclose(jnp.std(corr_samples, axis=0), expected_std, atol=0.011)
assert_allclose(jnp.mean(corr_samples, axis=0), expected_mean, atol=0.02)
assert_allclose(jnp.std(corr_samples, axis=0), expected_std, atol=0.02)
elif jax_dist in [dist.VonMises]:
# circular mean = sample mean
assert_allclose(d_jax.mean, jnp.mean(samples, 0), rtol=0.05, atol=1e-2)
Expand Down Expand Up @@ -2434,11 +2428,7 @@ def test_biject_to(constraint, shape):

# test inv
z = transform.inv(y)
atol = 1e-5
rtol = 1e-5
if constraint in [constraints.l1_ball]:
atol = 5e-5
assert_allclose(x, z, atol=atol, rtol=rtol)
assert_allclose(x, z, atol=1e-4, rtol=1e-5)

# test domain, currently all is constraints.real or constraints.real_vector
assert_array_equal(transform.domain(z), jnp.ones(batch_shape))
Expand Down Expand Up @@ -2575,11 +2565,8 @@ def test_bijective_transforms(transform, event_shape, batch_shape):
else:
expected = jnp.log(jnp.abs(grad(transform)(x)))
inv_expected = jnp.log(jnp.abs(grad(transform.inv)(y)))
atol = 1e-6
if isinstance(transform, transforms.ComposeTransform):
atol = 2.2e-6
assert_allclose(actual, expected, atol=atol)
assert_allclose(actual, -inv_expected, atol=atol)
assert_allclose(actual, expected, atol=1e-5)
assert_allclose(actual, -inv_expected, atol=1e-5)


@pytest.mark.parametrize("batch_shape", [(), (5,)])
Expand Down
2 changes: 1 addition & 1 deletion test/test_distributions_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def test_cholesky_update(chol_batch_shape, vec_batch_shape, dim, coef):
xxt = x[..., None] @ x[..., None, :]
expected = jnp.linalg.cholesky(A + coef * xxt)
actual = cholesky_update(jnp.linalg.cholesky(A), x, coef)
assert_allclose(actual, expected, atol=3.8e-4, rtol=8e-4)
assert_allclose(actual, expected, atol=1e-3, rtol=1e-3)


@pytest.mark.parametrize("n", [10, 100, 1000])
Expand Down
2 changes: 1 addition & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def test_bijective_transforms(transform, shape):
if isinstance(transform, less_stable_transforms):
atol = 1e-2
elif isinstance(transform, (L1BallTransform, RecursiveLinearTransform)):
atol = 0.099
atol = 0.1
assert jnp.allclose(x1, x2, atol=atol)

log_abs_det_jacobian = transform.log_abs_det_jacobian(x1, y)
Expand Down

0 comments on commit 4266a46

Please sign in to comment.