models_vit_group_channels.py

# --------------------------------------------------------
# References:
# MAE: https://github.com/facebookresearch/mae
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# --------------------------------------------------------
from functools import partial

import torch
import torch.nn as nn

import timm.models.vision_transformer
from timm.models.vision_transformer import PatchEmbed

from util.pos_embed import get_2d_sincos_pos_embed, get_1d_sincos_pos_embed_from_grid
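
# Shape bookkeeping used throughout this file: with G channel groups and
# L patches per group, every token has width D = embed_dim, split as
#   D = (embed_dim - channel_embed) positional dims + channel_embed group dims,
# so the 2-D sin-cos positional embedding and the 1-D sin-cos group embedding
# concatenate to a full D-dim token embedding.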


class GroupChannelsVisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """Vision Transformer that embeds each group of input channels with its own
    patch embedding, with support for global average pooling.
    """
    def __init__(self, global_pool=True, channel_embed=256,
                 channel_groups=((0, 1, 2, 6), (3, 4, 5, 7), (8, 9)), **kwargs):
        super().__init__(**kwargs)

        img_size = kwargs['img_size']
        patch_size = kwargs['patch_size']
        embed_dim = kwargs['embed_dim']

        self.channel_groups = channel_groups
        # One patch embedding per channel group, each consuming len(group) channels
        self.patch_embed = nn.ModuleList([PatchEmbed(img_size, patch_size, len(group), embed_dim)
                                          for group in channel_groups])
        num_patches = self.patch_embed[0].num_patches

        # Positional and channel embeddings; concatenated they fill embed_dim
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim - channel_embed))
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(num_patches ** .5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        num_groups = len(channel_groups)
        self.channel_embed = nn.Parameter(torch.zeros(1, num_groups, channel_embed))
        chan_embed = get_1d_sincos_pos_embed_from_grid(self.channel_embed.shape[-1], torch.arange(num_groups).numpy())
        self.channel_embed.data.copy_(torch.from_numpy(chan_embed).float().unsqueeze(0))

        # Extra (zero-initialized) channel embedding for the cls token to fill embed_dim
        self.channel_cls_embed = nn.Parameter(torch.zeros(1, 1, channel_embed))

        self.global_pool = global_pool
        if self.global_pool:
            norm_layer = kwargs['norm_layer']
            self.fc_norm = norm_layer(embed_dim)
            del self.norm  # remove the original norm

    def forward_features(self, x):
        b, c, h, w = x.shape

        # Embed each channel group separately
        x_c_embed = []
        for i, group in enumerate(self.channel_groups):
            x_c = x[:, group, :, :]
            x_c_embed.append(self.patch_embed[i](x_c))  # (N, L, D)

        x = torch.stack(x_c_embed, dim=1)  # (N, G, L, D)
        _, G, L, D = x.shape

        # Add channel embed
        channel_embed = self.channel_embed.unsqueeze(2)  # (1, G, 1, cD)
        pos_embed = self.pos_embed[:, 1:, :].unsqueeze(1)  # (1, 1, L, pD)

        # Channel embed is shared across (x, y) positions; pos embed is shared across groups
        channel_embed = channel_embed.expand(-1, -1, pos_embed.shape[2], -1)  # (1, G, L, cD)
        pos_embed = pos_embed.expand(-1, channel_embed.shape[1], -1, -1)  # (1, G, L, pD)
        pos_channel = torch.cat((pos_embed, channel_embed), dim=-1)  # (1, G, L, D)

        # Add pos embed w/o cls token, then flatten groups into one sequence
        x = x + pos_channel  # (N, G, L, D)
        x = x.view(b, -1, D)  # (N, G*L, D)

        # cls token pairs the cls pos embed with the dedicated channel cls embed
        cls_pos_channel = torch.cat((self.pos_embed[:, :1, :], self.channel_cls_embed), dim=-1)  # (1, 1, D)
        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = cls_pos_channel + self.cls_token.expand(b, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # (N, 1 + G*L, D)
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        if self.global_pool:
            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
            outcome = self.fc_norm(x)
        else:
            x = self.norm(x)
            outcome = x[:, 0]

        return outcome
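

# Example token count, derived from forward_features above: img_size=128 and
# patch_size=8 give L = (128 // 8) ** 2 = 256 patches per group, so with the
# default G = 3 channel groups the transformer processes 1 + G * L = 769 tokens.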


def vit_base_patch16(**kwargs):
    model = GroupChannelsVisionTransformer(
        channel_embed=256, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_large_patch16(**kwargs):
    model = GroupChannelsVisionTransformer(
        channel_embed=256, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_huge_patch14(**kwargs):
    # channel_embed is not passed here, so the constructor default (256) applies
    model = GroupChannelsVisionTransformer(
        embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


if __name__ == '__main__':
    # Smoke test: 12-channel input split into the default channel groups
    tensor = torch.rand(1, 12, 128, 128)
    model = vit_base_patch16(img_size=128, patch_size=8, in_chans=12)
    output = model(tensor)
    print(output.shape)
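
    # A minimal extra sketch (the grouping below is illustrative, not one
    # prescribed by this file): channel_groups can be overridden as long as
    # the group indices together cover the in_chans input channels.
    custom = vit_base_patch16(img_size=128, patch_size=8, in_chans=10,
                              channel_groups=((0, 1, 2), (3, 4, 5, 6), (7, 8, 9)))
    # expected (1, 1000): timm's pre-0.5 forward() applies the default
    # classification head to the pooled features
    print(custom(torch.rand(1, 10, 128, 128)).shape)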