-
Notifications
You must be signed in to change notification settings - Fork 18
/
res50_distill_mv1_img.py
83 lines (78 loc) · 3.17 KB
/
res50_distill_mv1_img.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Inherit the student (MobileNetV1) base config; everything below is layered on top.
_base_ = ['../../mobilenet_v1/mobilenet_v1.py']
# model settings
# NOTE(review): presumably forwarded to the distributed (DDP) wrapper so that
# parameters unused in a forward pass do not raise an error — confirm in the runner.
find_unused_parameters = True

# distillation settings
# Handed to the distiller as its `use_logit` option below.
use_logit = True

# config settings
# One switch per loss; each is passed as that loss's `use_this` entry in
# `distill_cfg` below. In this config only the NKD switch is on.
# NOTE(review): presumably `use_this=False` makes ClassificationDistiller skip
# the loss — confirm in the distiller implementation.
srrl = False
mgd = False
wsld = False
dkd = False
kd = False
nkd = True

# method details
# `distill_cfg` holds six entries, each wrapping exactly one loss spec in a
# `methods` list. SRRL and MGD additionally declare student_channels=1024 and
# teacher_channels=2048 (presumably the final feature widths of the student
# and teacher backbones — TODO confirm).
distiller = dict(
    type='ClassificationDistiller',
    teacher_pretrained='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
    use_logit=use_logit,
    distill_cfg=[
        dict(methods=[
            dict(
                type='SRRLLoss',
                name='loss_srrl',
                use_this=srrl,
                student_channels=1024,
                teacher_channels=2048,
                alpha=1.0,
                beta=1.0,
            ),
        ]),
        dict(methods=[
            dict(
                type='MGDLoss',
                name='loss_mgd',
                use_this=mgd,
                student_channels=1024,
                teacher_channels=2048,
                alpha_mgd=0.00007,
                lambda_mgd=0.15,
            ),
        ]),
        dict(methods=[
            dict(
                type='WSLDLoss',
                name='loss_wsld',
                use_this=wsld,
                temp=2.0,
                alpha=2.5,
                num_classes=1000,
            ),
        ]),
        dict(methods=[
            dict(
                type='DKDLoss',
                name='loss_dkd',
                use_this=dkd,
                temp=1.0,
                alpha=1.0,
                beta=0.5,
            ),
        ]),
        dict(methods=[
            dict(
                type='NKDLoss',
                name='loss_nkd',
                use_this=nkd,
                temp=1.0,
                gamma=1.5,
            ),
        ]),
        dict(methods=[
            dict(
                type='KDLoss',
                name='loss_kd',
                use_this=kd,
                temp=1.0,
                alpha=0.5,
            ),
        ]),
    ],
)
# Standalone config files for the student and teacher networks.
student_cfg = 'configs/mobilenet_v1/mobilenet_v1.py'
# NOTE(review): this path uses the older `resnet50_b32x8_imagenet` naming while
# `teacher_pretrained` uses the newer `resnet50_8xb32_in1k` checkpoint name —
# confirm this config file exists in the checked-out mmcls version.
teacher_cfg = 'configs/resnet/resnet50_b32x8_imagenet.py'
# NOTE(review): under the mmcv Config convention, `_delete_=True` presumably
# replaces the inherited optimizer_config rather than merging into it, and
# `grad_clip` presumably enables gradient-norm clipping at max_norm=5.0 — confirm.
optimizer_config = dict(
    _delete_=True,
    grad_clip=dict(max_norm=5.0),
)