Skip to content

Commit

Permalink
Seed the criterion tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stes committed Oct 3, 2023
1 parent 17b6b88 commit 53921f6
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions tests/test_criterions.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,10 +276,16 @@ def _compute_grads(output, inputs):
return [input_.grad for input_ in inputs]


def test_infonce():
def _sample_dist_matrices(seed):
rng = torch.Generator().manual_seed(42)
pos_dist = torch.randn(100, generator=rng)
neg_dist = torch.randn(100, 100, generator=rng)
return pos_dist, neg_dist


pos_dist = torch.randn(100,)
neg_dist = torch.randn(100, 100)
@pytest.mark.parametrize("seed", [42, 4242, 424242])
def test_infonce(seed):
pos_dist, neg_dist = _sample_dist_matrices(seed)

ref_loss, ref_align, ref_uniform = _reference_infonce(pos_dist, neg_dist)
loss, align, uniform = cebra_criterions.infonce(pos_dist, neg_dist)
Expand All @@ -290,11 +296,9 @@ def test_infonce():
assert torch.allclose(align + uniform, loss)


def test_infonce_gradients():

rng = torch.Generator().manual_seed(42)
pos_dist = torch.randn(100, generator=rng)
neg_dist = torch.randn(100, 100, generator=rng)
@pytest.mark.parametrize("seed", [42, 4242, 424242])
def test_infonce_gradients(seed):
pos_dist, neg_dist = _sample_dist_matrices(seed)

for i in range(3):
pos_dist_ = pos_dist.clone()
Expand All @@ -312,7 +316,7 @@ def test_infonce_gradients():
grad = _compute_grads(loss, [pos_dist_, neg_dist_])

# NOTE(stes) default relative tolerance is 1e-5
assert torch.allclose(loss_ref, loss, rtol = 1e-4)
assert torch.allclose(loss_ref, loss, rtol=1e-4)

if i == 0:
assert grad[0] is not None
Expand Down

0 comments on commit 53921f6

Please sign in to comment.