-
Notifications
You must be signed in to change notification settings - Fork 270
/
fsp.py
31 lines (23 loc) · 892 Bytes
/
fsp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
class FSP(nn.Module):
    '''
    A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning
    http://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf

    Flow of Solution Procedure (FSP) distillation loss: the mean squared
    error between the student's FSP matrix and the teacher's FSP matrix,
    each built from a pair of 4-D feature maps (see ``fsp_matrix``).
    '''
    def __init__(self):
        super(FSP, self).__init__()

    def forward(self, fm_s1, fm_s2, fm_t1, fm_t2):
        '''Return the MSE between the student and teacher FSP matrices.

        Args:
            fm_s1, fm_s2: student feature maps, indexed as (N, C, H, W).
            fm_t1, fm_t2: teacher feature maps, indexed as (N, C, H, W).

        Returns:
            Scalar tensor produced by ``F.mse_loss`` with its default
            (mean) reduction.

        NOTE(review): the student and teacher FSP matrices presumably must
        have identical shapes (same batch size and channel counts); with
        mismatched shapes ``F.mse_loss`` broadcasts or raises — confirm
        against callers.
        '''
        fsp_s = self.fsp_matrix(fm_s1, fm_s2)
        fsp_t = self.fsp_matrix(fm_t1, fm_t2)
        return F.mse_loss(fsp_s, fsp_t)

    def fsp_matrix(self, fm1, fm2):
        '''Compute the FSP (Gram-like) matrix between two feature maps.

        If the two maps have different spatial sizes, the larger one
        (compared by height first, then width) is reduced to the other's
        size with ``F.adaptive_avg_pool2d`` so that their spatial positions
        line up.

        Args:
            fm1: tensor indexed as (N, C1, H1, W1).
            fm2: tensor indexed as (N, C2, H2, W2).

        Returns:
            Tensor of shape (N, C1, C2): for each sample, the inner product of
            every channel of ``fm1`` with every channel of ``fm2`` over the
            shared spatial grid, divided by the number of spatial positions.
        '''
        h1, w1 = fm1.size(2), fm1.size(3)
        h2, w2 = fm2.size(2), fm2.size(3)
        if (h1, w1) != (h2, w2):
            # Tuple comparison orders by height, then width. When h1 > h2 this
            # pools fm1 to fm2's size exactly as before; the other branch
            # handles a larger fm2 (or equal heights with differing widths),
            # which previously made torch.bmm fail on mismatched lengths.
            if (h1, w1) > (h2, w2):
                fm1 = F.adaptive_avg_pool2d(fm1, (h2, w2))
            else:
                fm2 = F.adaptive_avg_pool2d(fm2, (h1, w1))

        # reshape (rather than view) also accepts non-contiguous inputs; it
        # returns a view whenever the memory layout allows one.
        fm1 = fm1.reshape(fm1.size(0), fm1.size(1), -1)                  # (N, C1, HW)
        fm2 = fm2.reshape(fm2.size(0), fm2.size(1), -1).transpose(1, 2)  # (N, HW, C2)

        # Batched channel-by-channel inner products, normalised by HW.
        fsp = torch.bmm(fm1, fm2) / fm1.size(2)
        return fsp