vit_seg_configs.py
import ml_collections


def get_b16_config():
    """Returns the ViT-B/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 768
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 3072
    config.transformer.num_heads = 12
    config.transformer.num_layers = 12
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1

    config.classifier = 'seg'
    config.representation_size = None
    config.resnet_pretrained_path = None
    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz'
    config.patch_size = 16

    config.decoder_channels = (256, 128, 64, 16)
    config.n_classes = 2
    config.activation = 'softmax'
    return config
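

# Illustrative usage sketch (an assumption, not taken from this module): the
# ConfigDict returned by get_b16_config() is mutable, so task-specific fields
# can be overridden before a model is built. The override values and the path
# below are hypothetical and serve only as a demonstration.
def _example_override_b16():
    """Minimal sketch: adapt the ViT-B/16 config for a hypothetical task."""
    cfg = get_b16_config()
    cfg.n_classes = 9  # hypothetical 9-class segmentation task
    cfg.pretrained_path = '/path/to/ViT-B_16.npz'  # hypothetical local path
    return cfg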


def get_testing():
    """Returns a minimal configuration for testing."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 1
    config.transformer.num_heads = 1
    config.transformer.num_layers = 1
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config


def get_r50_b16_config():
    """Returns the ResNet-50 + ViT-B/16 configuration."""
    config = get_b16_config()
    config.patches.grid = (16, 16)
    config.resnet = ml_collections.ConfigDict()
    config.resnet.num_layers = (3, 4, 9)
    config.resnet.width_factor = 1

    config.classifier = 'seg'
    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
    config.decoder_channels = (256, 128, 64, 16)
    config.skip_channels = [512, 256, 64, 16]
    config.n_classes = 2
    config.n_skip = 3
    config.activation = 'softmax'
    return config
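

# Illustrative consistency check (an assumption, not taken from this module):
# the decoder settings above pair one skip-channel entry with each decoder
# stage, and n_skip should not exceed the number of skip channels available.
def _check_r50_b16_config():
    """Minimal sketch verifying the decoder/skip fields of get_r50_b16_config()."""
    cfg = get_r50_b16_config()
    assert len(cfg.decoder_channels) == len(cfg.skip_channels)  # 4 == 4
    assert cfg.n_skip <= len(cfg.skip_channels)                 # 3 <= 4
    return cfg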


def get_b32_config():
    """Returns the ViT-B/32 configuration."""
    config = get_b16_config()
    config.patches.size = (32, 32)
    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz'
    return config


def get_l16_config():
    """Returns the ViT-L/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1024
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 4096
    config.transformer.num_heads = 16
    config.transformer.num_layers = 24
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.representation_size = None

    # custom segmentation settings
    config.classifier = 'seg'
    config.resnet_pretrained_path = None
    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz'
    config.decoder_channels = (256, 128, 64, 16)
    config.n_classes = 2
    config.activation = 'softmax'
    return config
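

# Rough size estimate (an assumption, not taken from this module; ignores
# biases, layer norms, embeddings, and the decoder): each transformer block
# carries about 4*hidden_size**2 attention weights (Q, K, V, output) plus
# 2*hidden_size*mlp_dim MLP weights.
def _approx_l16_encoder_params():
    """Back-of-envelope weight count for the ViT-L/16 encoder defined above."""
    cfg = get_l16_config()
    h = cfg.hidden_size
    m = cfg.transformer.mlp_dim
    n = cfg.transformer.num_layers
    per_block = 4 * h * h + 2 * h * m  # ~12.6M for h=1024, m=4096
    return n * per_block               # ~302M over 24 layers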


def get_r50_l16_config():
    """Returns the ResNet-50 + ViT-L/16 configuration (customized)."""
    config = get_l16_config()
    config.patches.grid = (16, 16)
    config.resnet = ml_collections.ConfigDict()
    config.resnet.num_layers = (3, 4, 9)
    config.resnet.width_factor = 1

    config.classifier = 'seg'
    config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
    config.decoder_channels = (256, 128, 64, 16)
    config.skip_channels = [512, 256, 64, 16]
    config.n_classes = 2
    config.activation = 'softmax'
    return config


def get_l32_config():
    """Returns the ViT-L/32 configuration."""
    config = get_l16_config()
    config.patches.size = (32, 32)
    return config


def get_h14_config():
    """Returns the ViT-H/14 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (14, 14)})
    config.hidden_size = 1280
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 5120
    config.transformer.num_heads = 16
    config.transformer.num_layers = 32
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config
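

# Convenience registry sketch (an assumption, not taken from this module):
# maps string names to the constructors above so a configuration can be chosen
# by key, e.g. CONFIGS['R50-ViT-B_16'](). The key names here are hypothetical.
CONFIGS = {
    'ViT-B_16': get_b16_config,
    'ViT-B_32': get_b32_config,
    'ViT-L_16': get_l16_config,
    'ViT-L_32': get_l32_config,
    'ViT-H_14': get_h14_config,
    'R50-ViT-B_16': get_r50_b16_config,
    'R50-ViT-L_16': get_r50_l16_config,
    'testing': get_testing,
}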