Skip to content

Commit

Permalink
Update training augments config.
Browse files Browse the repository at this point in the history
  • Loading branch information
mzr1996 committed Oct 21, 2021
1 parent fbc7a86 commit 9a92334
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 27 deletions.
15 changes: 6 additions & 9 deletions configs/_base_/models/t2t-vit-t-14.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# model settings
embed_dims = 384
num_classes = 1000

model = dict(
type='ImageClassifier',
Expand All @@ -25,7 +26,7 @@
neck=None,
head=dict(
type='VisionTransformerClsHead',
num_classes=1000,
num_classes=num_classes,
in_channels=embed_dims,
loss=dict(
type='LabelSmoothLoss',
Expand All @@ -34,11 +35,7 @@
),
topk=(1, 5),
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
train_cfg=dict(
cutmixup=dict(
mixup_alpha=0.8,
cutmix_alpha=1.0,
prob=1.0,
switch_prob=0.5,
mode='batch',
num_classes=1000)))
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
]))
15 changes: 6 additions & 9 deletions configs/_base_/models/t2t-vit-t-19.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# model settings
embed_dims = 448
num_classes = 1000

model = dict(
type='ImageClassifier',
Expand All @@ -25,7 +26,7 @@
neck=None,
head=dict(
type='VisionTransformerClsHead',
num_classes=1000,
num_classes=num_classes,
in_channels=embed_dims,
loss=dict(
type='LabelSmoothLoss',
Expand All @@ -34,11 +35,7 @@
),
topk=(1, 5),
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
train_cfg=dict(
cutmixup=dict(
mixup_alpha=0.8,
cutmix_alpha=1.0,
prob=1.0,
switch_prob=0.5,
mode='batch',
num_classes=1000)))
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
]))
15 changes: 6 additions & 9 deletions configs/_base_/models/t2t-vit-t-24.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# model settings
embed_dims = 512
num_classes = 1000

model = dict(
type='ImageClassifier',
Expand All @@ -25,7 +26,7 @@
neck=None,
head=dict(
type='VisionTransformerClsHead',
num_classes=1000,
num_classes=num_classes,
in_channels=embed_dims,
loss=dict(
type='LabelSmoothLoss',
Expand All @@ -34,11 +35,7 @@
),
topk=(1, 5),
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
train_cfg=dict(
cutmixup=dict(
mixup_alpha=0.8,
cutmix_alpha=1.0,
prob=1.0,
switch_prob=0.5,
mode='batch',
num_classes=1000)))
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
]))

0 comments on commit 9a92334

Please sign in to comment.