imageretrievalnet.py
import torch
import torch.nn as nn
import torchvision

from pooling import *
from normalization import *
from efficientnet_pytorch import EfficientNet
import timm
from models.wsdan import WSDAN
# Maps the opt.pool option string to the corresponding pooling layer class
# (brought in via `from pooling import *`).
pool_dic = {
    "GeM": GeM,
    "SPoC": SPoC,
    "MAC": MAC,
    "RMAC": RMAC,
    "GeMmp": GeMmp,
}
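# Reference sketch (not part of the original file): the pooling layers listed in
# pool_dic live in the local `pooling` module, which is not shown here. As an
# illustration only, a typical GeM (generalized-mean) pooling layer looks roughly
# like the commented class below; the project's actual implementation may differ.
#
# class GeM(nn.Module):
#     def __init__(self, p=3.0, eps=1e-6):
#         super().__init__()
#         self.p = nn.Parameter(torch.ones(1) * p)  # learnable pooling exponent
#         self.eps = eps
#
#     def forward(self, x):
#         # Clamp, raise to p, average over the spatial dims, take the p-th root.
#         pooled = torch.nn.functional.avg_pool2d(
#             x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1)))
#         return pooled.pow(1.0 / self.p)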
class ImageRetrievaleffNet(nn.Module):
    def __init__(self, net, pool):
        super(ImageRetrievaleffNet, self).__init__()
        self.net = net
        self.norm = L2N()
        self.pool = pool

    def forward(self, x, test=False):
        """Extracts features with extract_features, pools and L2-normalizes them
        into a retrieval descriptor, applies the final linear layer, and returns
        (descriptor, logits)."""
        bs = x.size(0)
        # Convolutional feature extraction
        x = self.net.extract_features(x)
        # Pooling (in place of the stock self.net._avg_pooling) and final linear layer
        x = self.pool(x)
        o = x
        x = x.view(bs, -1)
        x = self.net._dropout(x)
        x = self.net._fc(x)
        # L2-normalize the pooled features to obtain the descriptor
        o = self.norm(o).squeeze(-1).squeeze(-1)
        return o, x
class ImageRetrievalresNet(nn.Module):
    def __init__(self, features, fc_cls, pool):
        super(ImageRetrievalresNet, self).__init__()
        self.features = nn.Sequential(*features)
        self.pool = pool
        self.norm = L2N()
        # fc_cls may be a single layer or a list of layers forming the head
        if type(fc_cls) == list:
            self.fc_cls = nn.Sequential(*fc_cls)
        else:
            self.fc_cls = fc_cls

    def forward(self, x, test=False):
        o = self.features(x)
        o = self.pool(o)
        # Squeeze only the spatial dims so a batch of size 1 keeps its batch dimension
        cls = self.fc_cls(o.squeeze(-1).squeeze(-1))
        o = self.norm(o).squeeze(-1).squeeze(-1)
        return o, cls
class ImageRetrieval_WSDAN(nn.Module):
    def __init__(self, net):
        super(ImageRetrieval_WSDAN, self).__init__()
        self.net = net
        self.norm = L2N()

    def forward(self, x, test=False):
        # WSDAN returns (logits, feature matrix, attention map)
        cls, o, attention_map = self.net(x)
        o = self.norm(o).squeeze(-1).squeeze(-1)
        return o, cls
class VITImageRetrievalNet(nn.Module):
    def __init__(self, net):
        super(VITImageRetrievalNet, self).__init__()
        self.net = net
        self.norm = L2N()

    def forward(self, x, test=False):
        # timm ViT: pre-head features serve as the descriptor, the head gives logits
        o = self.net.forward_features(x)
        x = self.net.head(o)
        o = self.norm(o).squeeze(-1).squeeze(-1)
        return o, x
def image_net(net_name, opt):
    """Builds a retrieval model from a backbone name and an options object
    providing opt.pool and opt.cls_num."""
    # Backbone construction. The WSDAN and ViT wrappers define their own heads,
    # so those branches return immediately.
    if net_name == 'resnet101':
        net = torchvision.models.resnet101(pretrained=True)
    elif net_name == 'resnet50':
        net = torchvision.models.resnet50(pretrained=True)
    elif 'ibn' in net_name:
        net = torch.hub.load('XingangPan/IBN-Net', net_name, pretrained=True)
    elif net_name == 'WSDAN':
        net = WSDAN(num_classes=opt.cls_num, M=32, net='inception_mixed_6e', pretrained=True)
        return ImageRetrieval_WSDAN(net)
    elif 'legacy' in net_name:
        net = timm.create_model(net_name, pretrained=True)
    elif 'vit' in net_name:
        net = timm.create_model(net_name, pretrained=True)
        net.head = nn.Linear(net.embed_dim, opt.cls_num)
        return VITImageRetrievalNet(net)
    elif 'efficient' in net_name:
        net = EfficientNet.from_pretrained(net_name, num_classes=opt.cls_num)
    else:
        raise ValueError('Unsupported or unknown architecture: {}!'.format(net_name))

    # Pooling selection, done after the backbone exists because the 'ori'
    # options reuse the backbone's own average pooling. An 'R-' prefix wraps
    # the chosen pooling in regional pooling (Rpool).
    if "R-" in opt.pool:
        if opt.pool == 'R-ori':
            pool = net.avgpool
        else:
            pool = pool_dic[opt.pool[2:]]()
        pool = Rpool(pool)
    else:
        if opt.pool == 'ori':
            pool = net.avgpool
        else:
            pool = pool_dic[opt.pool]()

    if 'efficient' in net_name:
        return ImageRetrievaleffNet(net, pool)

    # ResNet-style backbones: keep the convolutional trunk (drop avgpool/fc)
    # and attach a fresh classification head.
    features = list(net.children())[:-2]
    fc_cls = nn.Linear(in_features=2048, out_features=opt.cls_num, bias=True)
    return ImageRetrievalresNet(features, fc_cls, pool)
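

# Example usage (illustrative sketch, not part of the original file): the option
# attributes below mirror what this module actually reads (opt.pool, opt.cls_num);
# the option values and the rest of the training pipeline are assumptions.
#
# from argparse import Namespace
# opt = Namespace(pool='GeM', cls_num=1000)
# model = image_net('resnet50', opt)
# images = torch.randn(2, 3, 224, 224)
# descriptors, logits = model(images)  # L2-normalized descriptors + class logits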