Fix broadcasting in InfoNCE loss #86

Merged · Oct 3, 2023 · 6 commits
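For context (editor's note, not part of the PR): the bug is a silent broadcasting mismatch. Without keepdim=True, the row-wise maximum of an (n, n) matrix of negative distances has shape (n,), and PyTorch aligns trailing dimensions when broadcasting, so the shift is subtracted per column rather than per row; with a square batch no shape error is ever raised. A minimal sketch:

    import torch

    neg_dist = torch.randn(4, 4)  # square batch: the bug is silent here

    c_wrong = neg_dist.max(dim=1).values                 # shape (4,)
    c_right = neg_dist.max(dim=1, keepdim=True).values   # shape (4, 1)

    # (4,) broadcasts against columns, (4, 1) against rows.
    print(torch.equal(neg_dist - c_wrong, neg_dist - c_right))  # False
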
16 changes: 13 additions & 3 deletions cebra/models/criterions.py
@@ -81,15 +81,25 @@ def infonce(
     """InfoNCE implementation
 
     See :py:class:`BaseInfoNCE` for reference.
+
+    Note:
+        - The behavior of this function changed beginning in CEBRA 0.3.0.
+          The InfoNCE implementation is numerically stabilized.
     """
     with torch.no_grad():
-        c, _ = neg_dist.max(dim=1)
+        c, _ = neg_dist.max(dim=1, keepdim=True)
         c = c.detach()
-    pos_dist = pos_dist - c
+
+    pos_dist = pos_dist - c.squeeze(1)
     neg_dist = neg_dist - c
     align = (-pos_dist).mean()
     uniform = torch.logsumexp(neg_dist, dim=1).mean()
-    return align + uniform, align, uniform
+
+    c_mean = c.mean()
+    align_corrected = align - c_mean
+    uniform_corrected = uniform + c_mean
+
+    return align + uniform, align_corrected, uniform_corrected
 
 
 class ContrastiveLoss(nn.Module):
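
The correction terms follow from the per-row log-sum-exp identity logsumexp(x) = c + logsumexp(x - c): the shifted uniform comes out as the true value minus c.mean(), the shifted align picks up +c.mean(), and the two offsets cancel in the total loss. A quick numerical check of this bookkeeping (a sketch mirroring the patched function, not code from the PR):

    import torch

    pos = torch.randn(8)
    neg = torch.randn(8, 12)

    # Unstabilized reference values.
    align_true = (-pos).mean()
    uniform_true = torch.logsumexp(neg, dim=1).mean()

    # Stabilized computation, as in the patched infonce().
    c = neg.max(dim=1, keepdim=True).values           # shape (8, 1)
    align = (-(pos - c.squeeze(1))).mean()            # align_true + c.mean()
    uniform = torch.logsumexp(neg - c, dim=1).mean()  # uniform_true - c.mean()

    # The shifts cancel in the sum; the corrections undo them individually.
    assert torch.allclose(align + uniform, align_true + uniform_true)
    assert torch.allclose(align - c.mean(), align_true, atol=1e-6)
    assert torch.allclose(uniform + c.mean(), uniform_true, atol=1e-6)
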
123 changes: 121 additions & 2 deletions tests/test_criterions.py
@@ -42,16 +42,26 @@ def ref_euclidean_similarity(ref: torch.Tensor, pos: torch.Tensor,
 @torch.jit.script
 def ref_infonce(pos_dist: torch.Tensor, neg_dist: torch.Tensor):
     with torch.no_grad():
-        c, _ = neg_dist.max(dim=1)
+        c, _ = neg_dist.max(dim=1, keepdim=True)
         c = c.detach()
-    pos_dist = pos_dist - c
+    pos_dist = pos_dist - c.squeeze(1)
     neg_dist = neg_dist - c
 
     align = (-pos_dist).mean()
     uniform = torch.logsumexp(neg_dist, dim=1).mean()
     return align + uniform, align, uniform
 
 
+@torch.jit.script
+def ref_infonce_not_stable(pos_dist: torch.Tensor, neg_dist: torch.Tensor):
+    pos_dist = pos_dist
+    neg_dist = neg_dist
+
+    align = (-pos_dist).mean()
+    uniform = torch.logsumexp(neg_dist, dim=1).mean()
+    return align + uniform, align, uniform
+
+
 class ReferenceInfoNCE(nn.Module):
     """The InfoNCE loss.
     Attributes:
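
The new ref_infonce_not_stable above drops the max-shift entirely, giving the tests an unshifted baseline that the corrected align/uniform values should reproduce. As a reminder of what the shift buys in general: a hand-rolled log-sum-exp overflows in float32 where the shifted identity stays finite (illustrative sketch only; torch.logsumexp is itself internally stabilized):

    import torch

    big = torch.full((4, 4), 100.0)

    # Naive form: exp(100.) overflows float32 to inf.
    print(torch.log(torch.exp(big).sum(dim=1)))  # tensor([inf, inf, inf, inf])

    # Max-shifted form: c + logsumexp(x - c) stays finite.
    c = big.max(dim=1, keepdim=True).values
    print(c.squeeze(1) + torch.logsumexp(big - c, dim=1))  # 100 + log(4) ≈ 101.3863
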
@@ -208,3 +218,112 @@ def test_infonce_reference_new_equivalence(temperature):
 def test_alias():
     assert cebra_criterions.InfoNCE == cebra_criterions.FixedCosineInfoNCE
     assert cebra_criterions.InfoMSE == cebra_criterions.FixedEuclideanInfoNCE
+
+
+def _reference_dot_similarity(ref, pos, neg):
+    pos_dist = torch.zeros(ref.shape[0])
+    neg_dist = torch.zeros(ref.shape[0], neg.shape[0])
+    for d in range(ref.shape[1]):
+        for i in range(len(ref)):
+            pos_dist[i] += ref[i, d] * pos[i, d]
+            for j in range(len(neg)):
+                neg_dist[i, j] += ref[i, d] * neg[j, d]
+    return pos_dist, neg_dist
+
+
+def _reference_euclidean_similarity(ref, pos, neg):
+    pos_dist = torch.zeros(ref.shape[0])
+    neg_dist = torch.zeros(ref.shape[0], neg.shape[0])
+    for d in range(ref.shape[1]):
+        for i in range(len(ref)):
+            pos_dist[i] += -(ref[i, d] - pos[i, d])**2
+            for j in range(len(neg)):
+                neg_dist[i, j] += -(ref[i, d] - neg[j, d])**2
+    return pos_dist, neg_dist
+
+
+def _reference_infonce(pos_dist, neg_dist):
+    align = -pos_dist.mean()
+    uniform = torch.logsumexp(neg_dist, dim=1).mean()
+    return align + uniform, align, uniform
+
+
+def test_similiarities():
+
+    ref = torch.randn(10, 3)
+    pos = torch.randn(10, 3)
+    neg = torch.randn(12, 3)
+
+    pos_dist, neg_dist = _reference_dot_similarity(ref, pos, neg)
+    pos_dist_2, neg_dist_2 = cebra_criterions.dot_similarity(ref, pos, neg)
+
+    assert torch.allclose(pos_dist, pos_dist_2)
+    assert torch.allclose(neg_dist, neg_dist_2)
+
+    pos_dist, neg_dist = _reference_euclidean_similarity(ref, pos, neg)
+    pos_dist_2, neg_dist_2 = cebra_criterions.euclidean_similarity(
+        ref, pos, neg)
+
+    assert torch.allclose(pos_dist, pos_dist_2)
+    assert torch.allclose(neg_dist, neg_dist_2)
+
+
+def _compute_grads(output, inputs):
+    for input_ in inputs:
+        input_.grad = None
+        assert input_.requires_grad
+    output.backward()
+    return [input_.grad for input_ in inputs]
+
+
+def test_infonce():
+
+    pos_dist = torch.randn(100,)
+    neg_dist = torch.randn(100, 100)
+
+    ref_loss, ref_align, ref_uniform = _reference_infonce(pos_dist, neg_dist)
+    loss, align, uniform = cebra_criterions.infonce(pos_dist, neg_dist)
+
+    assert torch.allclose(ref_loss, loss)
+    assert torch.allclose(ref_align, align, atol=0.0001)
+    assert torch.allclose(ref_uniform, uniform)
+    assert torch.allclose(align + uniform, loss)
+
+
+def test_infonce_gradients():
+
+    rng = torch.Generator().manual_seed(42)
+    pos_dist = torch.randn(100, generator=rng)
+    neg_dist = torch.randn(100, 100, generator=rng)
+
+    for i in range(3):
+        pos_dist_ = pos_dist.clone()
+        neg_dist_ = neg_dist.clone()
+        pos_dist_.requires_grad_(True)
+        neg_dist_.requires_grad_(True)
+        loss_ref = _reference_infonce(pos_dist_, neg_dist_)[i]
+        grad_ref = _compute_grads(loss_ref, [pos_dist_, neg_dist_])
+
+        pos_dist_ = pos_dist.clone()
+        neg_dist_ = neg_dist.clone()
+        pos_dist_.requires_grad_(True)
+        neg_dist_.requires_grad_(True)
+        loss = cebra_criterions.infonce(pos_dist_, neg_dist_)[i]
+        grad = _compute_grads(loss, [pos_dist_, neg_dist_])
+
+        # NOTE(stes) default relative tolerance is 1e-5
+        assert torch.allclose(loss_ref, loss, rtol=1e-4)
+
+        if i == 0:
+            assert grad[0] is not None
+            assert grad[1] is not None
+            assert torch.allclose(grad_ref[0], grad[0])
+            assert torch.allclose(grad_ref[1], grad[1])
+        if i == 1:
+            assert grad[0] is not None
+            assert grad[1] is None
+            assert torch.allclose(grad_ref[0], grad[0])
+        if i == 2:
+            assert grad[0] is None
+            assert grad[1] is not None
+            assert torch.allclose(grad_ref[1], grad[1])
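
In test_infonce_gradients, the index i selects loss (0), align (1), and uniform (2) from the returned tuple. Because the shift c is detached, the corrected align should carry no gradient into neg_dist, and uniform none into pos_dist, which the is None checks pin down. The same property in isolation (a sketch; the import alias is assumed to match the test module's):

    import torch
    import cebra.models.criterions as cebra_criterions  # assumed import path

    pos = torch.randn(10, requires_grad=True)
    neg = torch.randn(10, 10, requires_grad=True)

    align = cebra_criterions.infonce(pos, neg)[1]
    align.backward()
    assert pos.grad is not None  # align depends on the positive distances
    assert neg.grad is None      # detached shift: no gradient into negatives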