From 53921f63f6f22ac7fb3fde475c96702c198efedc Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Tue, 3 Oct 2023 17:48:11 +0200 Subject: [PATCH] Seed the criterion tests --- tests/test_criterions.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/tests/test_criterions.py b/tests/test_criterions.py index 7e809d9d..1a982860 100644 --- a/tests/test_criterions.py +++ b/tests/test_criterions.py @@ -276,10 +276,16 @@ def _compute_grads(output, inputs): return [input_.grad for input_ in inputs] -def test_infonce(): +def _sample_dist_matrices(seed): + rng = torch.Generator().manual_seed(42) + pos_dist = torch.randn(100, generator=rng) + neg_dist = torch.randn(100, 100, generator=rng) + return pos_dist, neg_dist + - pos_dist = torch.randn(100,) - neg_dist = torch.randn(100, 100) +@pytest.mark.parametrize("seed", [42, 4242, 424242]) +def test_infonce(seed): + pos_dist, neg_dist = _sample_dist_matrices(seed) ref_loss, ref_align, ref_uniform = _reference_infonce(pos_dist, neg_dist) loss, align, uniform = cebra_criterions.infonce(pos_dist, neg_dist) @@ -290,11 +296,9 @@ def test_infonce(): assert torch.allclose(align + uniform, loss) -def test_infonce_gradients(): - - rng = torch.Generator().manual_seed(42) - pos_dist = torch.randn(100, generator=rng) - neg_dist = torch.randn(100, 100, generator=rng) +@pytest.mark.parametrize("seed", [42, 4242, 424242]) +def test_infonce_gradients(seed): + pos_dist, neg_dist = _sample_dist_matrices(seed) for i in range(3): pos_dist_ = pos_dist.clone() @@ -312,7 +316,7 @@ def test_infonce_gradients(): grad = _compute_grads(loss, [pos_dist_, neg_dist_]) # NOTE(stes) default relative tolerance is 1e-5 - assert torch.allclose(loss_ref, loss, rtol = 1e-4) + assert torch.allclose(loss_ref, loss, rtol=1e-4) if i == 0: assert grad[0] is not None