-
Notifications
You must be signed in to change notification settings - Fork 270
/
rkd.py
69 lines (50 loc) · 1.92 KB
/
rkd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
'''
From https://github.com/lenscloth/RKD/blob/master/metric/loss.py
'''
class RKD(nn.Module):
    '''
    Relational Knowledge Distillation
    https://arxiv.org/pdf/1904.05068.pdf

    Matches the pairwise structure of a batch of student features to that
    of the teacher features using two terms:

        loss = w_dist * rkd_dist(feat_s, feat_t)
             + w_angle * rkd_angle(feat_s, feat_t)

    Both feature tensors are expected to be 2-D (N x C, one row per sample)
    with the same N; the channel count C may differ between them because
    only N x N distances and N x N x N angles are compared.

    Args:
        w_dist: weight of the distance-wise term.
        w_angle: weight of the angle-wise term.
    '''

    def __init__(self, w_dist, w_angle):
        super(RKD, self).__init__()
        self.w_dist = w_dist
        self.w_angle = w_angle

    def forward(self, feat_s, feat_t):
        '''
        Return the weighted sum of the distance-wise and angle-wise losses.

        Args:
            feat_s: student features, N x C_s.
            feat_t: teacher features, N x C_t.

        Returns:
            Scalar loss tensor.
        '''
        dist_term = self.rkd_dist(feat_s, feat_t)
        angle_term = self.rkd_angle(feat_s, feat_t)
        return self.w_dist * dist_term + self.w_angle * angle_term

    def rkd_dist(self, feat_s, feat_t):
        '''
        Distance-wise loss: smooth L1 between the mean-normalised pairwise
        distance matrices of the student and the teacher.

        A batch without any positive distance (e.g. a single sample) yields
        a zero loss instead of NaN.
        '''
        feat_t_dist = self._mean_normalized_pdist(feat_t)
        feat_s_dist = self._mean_normalized_pdist(feat_s)
        return F.smooth_l1_loss(feat_s_dist, feat_t_dist)

    def rkd_angle(self, feat_s, feat_t):
        '''
        Angle-wise loss: smooth L1 between the cosines of the angles formed
        by every triplet of samples in the student and teacher batches.

        Note that the intermediate tensors hold N * N * N cosines.
        '''
        feat_t_angle = self._triplet_cosines(feat_t)
        feat_s_angle = self._triplet_cosines(feat_s)
        return F.smooth_l1_loss(feat_s_angle, feat_t_angle)

    def pdist(self, feat, squared=False, eps=1e-12):
        '''
        Return the N x N matrix of pairwise Euclidean distances between the
        rows of ``feat``.

        Args:
            feat: 2-D tensor, N x C.
            squared: when True, return squared distances (no square root).
            eps: lower bound applied to the squared distances before the
                square root.

        Returns:
            N x N tensor whose diagonal is exactly 0.
        '''
        feat_square = feat.pow(2).sum(dim=1)
        feat_prod = torch.mm(feat, feat.t())
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b ; the clamp keeps the
        # value strictly positive so rounding cannot produce a negative
        # argument for sqrt.
        feat_dist = (feat_square.unsqueeze(0) + feat_square.unsqueeze(1)
                     - 2 * feat_prod).clamp(min=eps)
        if not squared:
            feat_dist = feat_dist.sqrt()
        # Work on a copy so that the in-place diagonal write below does not
        # touch the tensor produced by sqrt/clamp.
        feat_dist = feat_dist.clone()
        diag = torch.arange(feat.size(0), device=feat.device)
        feat_dist[diag, diag] = 0
        return feat_dist

    def _mean_normalized_pdist(self, feat):
        # Pairwise distances divided by the mean of the positive entries.
        feat_dist = self.pdist(feat, squared=False)
        positive = feat_dist[feat_dist > 0]
        if positive.numel() == 0:
            # No pair of distinct samples (e.g. N == 1): the mean of an empty
            # selection would be NaN and poison the whole loss, so leave the
            # all-zero matrix unnormalised.
            return feat_dist
        return feat_dist / positive.mean()

    def _triplet_cosines(self, feat):
        # Flattened cosines <e_ij, e_ik> for every (i, j, k), where e_ij is
        # the unit vector from sample i to sample j.
        # N x C --> N x N x C
        feat_vd = feat.unsqueeze(0) - feat.unsqueeze(1)
        norm_feat_vd = F.normalize(feat_vd, p=2, dim=2)
        return torch.bmm(norm_feat_vd, norm_feat_vd.transpose(1, 2)).view(-1)