# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from logging import getLogger

import torch
import math

from src.utils import AllReduce

logger = getLogger()

def init_msn_loss(
    num_views=1,
    tau=0.1,
    me_max=True,
    return_preds=False
):
    """
    Make unsupervised MSN loss

    :param num_views: number of anchor views
    :param tau: cosine similarity temperature
    :param me_max: whether to perform me-max regularization
    :param return_preds: whether to return anchor predictions
    """
    softmax = torch.nn.Softmax(dim=1)

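    # sharpen() raises each probability to the power 1/T and renormalizes:
    #   sharp_p[i, k] = p[i, k]**(1/T) / sum_j p[i, j]**(1/T)
    # With T < 1 this makes the target distribution more peaked (confident).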
    def sharpen(p, T):
        sharp_p = p**(1./T)
        sharp_p /= torch.sum(sharp_p, dim=1, keepdim=True)
        return sharp_p

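    # snn() computes, for each L2-normalized query embedding, a softmax over
    # its cosine similarities to the prototypes, then averages the prototype
    # labels under those weights: a soft nearest-neighbour prediction.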
    def snn(query, supports, support_labels, temp=tau):
        """ Soft Nearest Neighbours similarity classifier """
        query = torch.nn.functional.normalize(query)
        supports = torch.nn.functional.normalize(supports)
        return softmax(query @ supports.T / temp) @ support_labels

    def loss(
        anchor_views,
        target_views,
        prototypes,
        proto_labels,
        T=0.25,
        use_entropy=False,
        use_sinkhorn=False,
        sharpen=sharpen,
        snn=snn
    ):
        # Step 1: compute anchor predictions
        probs = snn(anchor_views, prototypes, proto_labels)

        # Step 2: compute targets for anchor predictions
        with torch.no_grad():
            targets = sharpen(snn(target_views, prototypes, proto_labels), T=T)
            if use_sinkhorn:
                targets = distributed_sinkhorn(targets)
            targets = torch.cat([targets for _ in range(num_views)], dim=0)
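            # The concatenation above tiles the targets so that all num_views
            # anchor views of an image share the same sharpened target.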

        # Step 3: compute cross-entropy loss H(targets, queries)
        loss = torch.mean(torch.sum(torch.log(probs**(-targets)), dim=1))
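        # Note: log(probs**(-targets)) == -targets * log(probs) elementwise,
        # so the line above is the usual cross-entropy, averaged over views.
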
        # Step 4: compute me-max regularizer
        rloss = 0.
        if me_max:
            avg_probs = AllReduce.apply(torch.mean(probs, dim=0))
            rloss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs)))
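        # rloss equals log(K) - H(avg_probs): it is minimized (at zero) when
        # the batch-averaged prediction is uniform over the K prototypes,
        # encouraging full use of the prototype set.
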
        sloss = 0.
        if use_entropy:
            sloss = torch.mean(torch.sum(torch.log(probs**(-probs)), dim=1))

        # -- logging
        with torch.no_grad():
            num_ps = float(len(set(targets.argmax(dim=1).tolist())))
            max_t = targets.max(dim=1).values.mean()
            min_t = targets.min(dim=1).values.mean()
            log_dct = {'np': num_ps, 'max_t': max_t, 'min_t': min_t}

        if return_preds:
            return loss, rloss, sloss, log_dct, targets

        return loss, rloss, sloss, log_dct

    return loss

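
# Sinkhorn-Knopp post-processing for the targets: alternately rescales the
# rows and columns of the (optionally world-size aggregated) assignment
# matrix toward uniform marginals 1/K and 1/B, so that assignments are
# roughly balanced across prototypes.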
@torch.no_grad()
def distributed_sinkhorn(Q, num_itr=3, use_dist=True):
    _got_dist = use_dist and torch.distributed.is_available() \
        and torch.distributed.is_initialized() \
        and (torch.distributed.get_world_size() > 1)

    if _got_dist:
        world_size = torch.distributed.get_world_size()
    else:
        world_size = 1

    Q = Q.T
    B = Q.shape[1] * world_size  # number of samples to assign
    K = Q.shape[0]  # how many prototypes

    # make the matrix sum to 1
    sum_Q = torch.sum(Q)
    if _got_dist:
        torch.distributed.all_reduce(sum_Q)
    Q /= sum_Q

    for it in range(num_itr):
        # normalize each row: total weight per prototype must be 1/K
        sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
        if _got_dist:
            torch.distributed.all_reduce(sum_of_rows)
        Q /= sum_of_rows
        Q /= K

        # normalize each column: total weight per sample must be 1/B
        Q /= torch.sum(Q, dim=0, keepdim=True)
        Q /= B

    Q *= B  # the columns must sum to 1 so that Q is an assignment
    return Q.T
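

# A minimal usage sketch (illustrative only): the shapes and hyper-parameters
# below are made up, and it assumes `src.utils.AllReduce` acts as an identity
# op in a single-process run. Each image contributes V anchor views and one
# target view; `proto_labels` is taken here to be the identity matrix, i.e.
# a one-hot label per prototype.
if __name__ == '__main__':
    B, D, K, V = 4, 256, 16, 2  # batch size, embed dim, prototypes, anchor views
    msn = init_msn_loss(num_views=V, tau=0.1, me_max=True, return_preds=False)
    anchor_views = torch.randn(V * B, D)  # V anchor embeddings per image
    target_views = torch.randn(B, D)      # one target embedding per image
    prototypes = torch.randn(K, D)
    proto_labels = torch.eye(K)           # one-hot label per prototype
    ce, memax, ent, logs = msn(anchor_views, target_views, prototypes, proto_labels)
    print('cross-entropy =', ce.item(), 'me-max reg =', memax.item(), 'logs =', logs)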