rhino_phc_haus-4scale_swint_2xb2-36e_dotav2.py
# Inherit from the 12-epoch RHINO baseline (4-scale, ResNet-50, 2 GPUs x 2
# imgs/GPU) on DOTA-v2 and extend it to a 36-epoch Swin-T schedule.
_base_ = './rhino-4scale_r50_2xb2-12e_dotav2.py'

# Train for 36 epochs, validating every 12 epochs.
max_epochs = 36
train_cfg = dict(
    type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=12)

# Decay the learning rate by 10x at epochs 27 and 33.
param_scheduler = [
    dict(
        type='MultiStepLR',
        begin=0,
        end=max_epochs,
        by_epoch=True,
        milestones=[27, 33],
        gamma=0.1)
]

# Hungarian matching costs: focal classification cost, Hausdorff distance cost
# on rotated (xywha) boxes, and a KLD-based Gaussian distribution cost.
costs = [
    dict(type='mmdet.FocalLossCost', weight=2.0),
    dict(type='HausdorffCost', weight=5.0, box_format='xywha'),
    dict(
        type='GDCost',
        loss_type='kld',
        fun='log1p',
        tau=1,
        sqrt=False,
        weight=5.0)
]

depths = [2, 2, 6, 2]
pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth'  # noqa

model = dict(
    version='v2',
    # Replace the ResNet-50 backbone from the base config with Swin-Tiny,
    # initialized from the ImageNet-pretrained checkpoint above.
    backbone=dict(
        _delete_=True,
        type='mmdet.SwinTransformer',
        embed_dims=96,
        depths=depths,
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.2,
        patch_norm=True,
        out_indices=(1, 2, 3),
        with_cp=False,
        convert_weights=True,
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    # Map the selected Swin-T stage outputs (192/384/768 channels) to 256.
    neck=dict(
        _delete_=True,
        type='mmdet.ChannelMapper',
        in_channels=[192, 384, 768],
        out_channels=256,
        num_outs=4),
    # RHINO head with a KLD-based Gaussian distribution loss as the IoU term.
    bbox_head=dict(
        type='RHINOPositiveHungarianClassificationHead',
        loss_iou=dict(
            type='GDLoss',
            loss_type='kld',
            fun='log1p',
            tau=1,
            sqrt=False,
            loss_weight=5.0)),
    # Cap the number of denoising query groups at 30.
    dn_cfg=dict(group_cfg=dict(max_num_groups=30)),
    train_cfg=dict(
        assigner=dict(match_costs=costs),
        dn_assigner=dict(type='DNGroupHungarianAssigner', match_costs=costs),
    ))

# optimizer: no weight decay on positional embeddings, relative position bias
# tables, or normalization layers.
optim_wrapper = dict(
    type='OptimWrapper',
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))
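
For context, a minimal sketch of how a config like this is typically loaded and inspected through the MMEngine config system; this assumes `mmengine` is installed and that the base config referenced by `_base_` sits next to this file. Values inherited from the base config are merged with the overrides defined above.

# Sketch only: mmengine and the base config are assumed to be available.
from mmengine.config import Config

cfg = Config.fromfile('rhino_phc_haus-4scale_swint_2xb2-36e_dotav2.py')

print(cfg.train_cfg.max_epochs)   # 36, overriding the 12-epoch base schedule
print(cfg.model.backbone.type)    # 'mmdet.SwinTransformer'
print(cfg.model.neck.in_channels) # [192, 384, 768]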