TextNet.py
import torch
from torch import nn


class TextNet(nn.Module):
    def __init__(self, y_dim, bit, norm=True, mid_num1=1024 * 8, mid_num2=1024 * 8, hiden_layer=2):
        """
        :param y_dim: dimension of the input tag (text) vectors
        :param bit: number of bits in the final binary code
        :param norm: if True, L2-normalize the tanh output in forward()
        :param mid_num1: width of the first hidden layer
        :param mid_num2: width of the remaining hidden layers
        :param hiden_layer: total number of linear layers in the MLP
        """
        super(TextNet, self).__init__()
        self.module_name = "txt_model"
        # With a single layer, map the input directly to `bit` dimensions.
        mid_num1 = mid_num1 if hiden_layer > 1 else bit
        modules = [nn.Linear(y_dim, mid_num1)]
        if hiden_layer >= 2:
            modules += [nn.ReLU(inplace=True)]
            pre_num = mid_num1
            # Intermediate layers: first mid_num1 -> mid_num2, then mid_num2 -> mid_num2.
            for i in range(hiden_layer - 2):
                if i == 0:
                    modules += [nn.Linear(mid_num1, mid_num2), nn.ReLU(inplace=True)]
                else:
                    modules += [nn.Linear(mid_num2, mid_num2), nn.ReLU(inplace=True)]
                pre_num = mid_num2
            # Final projection down to the code length.
            modules += [nn.Linear(pre_num, bit)]
        self.fc = nn.Sequential(*modules)
        self.norm = norm

    def forward(self, x):
        out1 = self.fc(x)       # raw (pre-activation) codes
        out = torch.tanh(out1)  # squash into (-1, 1)
        if self.norm:
            # L2-normalize each row so codes lie on the unit sphere.
            norm_x = torch.norm(out, dim=1, keepdim=True)
            out = out / norm_x
        return out1, out

    def freeze_grad(self):
        # Freeze all parameters, e.g. when this branch should stay fixed during training.
        for p in self.parameters():
            p.requires_grad = False
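

# A minimal usage sketch (not part of the original file): the tag dimension,
# code length, and batch size below are illustrative assumptions, not values
# prescribed by the repository.
if __name__ == "__main__":
    model = TextNet(y_dim=512, bit=64)
    tags = torch.randn(4, 512)    # dummy batch of 4 tag vectors
    raw, code = model(tags)
    # Both outputs have shape (batch, bit); `code` is the tanh-and-normalized version of `raw`.
    print(raw.shape, code.shape)  # torch.Size([4, 64]) torch.Size([4, 64])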