text_net.py
"""
Module for natural language classification.
"""
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable


class TextClassificationNet(nn.Module):
    """Sentence (or sentence-pair) classifier: recurrent encoder + MLP head."""

    def __init__(
        self,
        n_words,
        word_embed_dim,
        encoder_dim,
        n_enc_layers,
        dpout_model,
        dpout_fc,
        fc_dim,
        bsize,
        n_classes,
        pool_type,
        linear_fc,
        bidirectional,
        rnn,
        double_in,
    ):
        super(TextClassificationNet, self).__init__()
        # Store settings.
        self.encoder_dim = encoder_dim
        self.n_enc_layers = n_enc_layers
        self.dpout_fc = dpout_fc
        self.fc_dim = fc_dim
        self.n_classes = n_classes
        self.linear_fc = linear_fc
        self.bidirectional = bidirectional
        self.rnn = rnn
        self.double_in = double_in

        # Construct encoder and classifier.
        self.encoder = RecurrentEncoder(
            n_enc_layers, word_embed_dim, encoder_dim, pool_type, dpout_model, bsize, bidirectional, rnn
        )
        # Sentence-pair mode concatenates (u, v, |u - v|, u * v), i.e. 4x the
        # encoder output width; a bidirectional encoder doubles it again.
        feature_multiplier = 4 if self.double_in else 1
        self.inputdim = feature_multiplier * self.encoder_dim
        if self.bidirectional:
            self.inputdim *= 2
        if self.linear_fc:
            # Purely linear head (no dropout, no non-linearity).
            self.classifier = nn.Sequential(
                nn.Linear(self.inputdim, self.fc_dim),
                nn.Linear(self.fc_dim, self.fc_dim),
                nn.Linear(self.fc_dim, self.n_classes),
            )
        else:
            # MLP head with dropout and tanh non-linearities.
            self.classifier = nn.Sequential(
                nn.Dropout(p=self.dpout_fc),
                nn.Linear(self.inputdim, self.fc_dim),
                nn.Tanh(),
                nn.Dropout(p=self.dpout_fc),
                nn.Linear(self.fc_dim, self.fc_dim),
                nn.Tanh(),
                nn.Dropout(p=self.dpout_fc),
                nn.Linear(self.fc_dim, self.n_classes),
            )

    def forward(self, s1, s2=None):
        if s2 is not None:
            # Sentence-pair classification: encode both sentences and combine.
            u = self.encoder(s1)
            v = self.encoder(s2)
            features = torch.cat((u, v, torch.abs(u - v), u * v), 1)
        else:
            # Single-sentence classification.
            features = self.encoder(s1)
        output = self.classifier(features)
        return output


class RecurrentEncoder(nn.Module):
    """Encodes a padded batch of word embeddings into fixed-size sentence vectors."""

    def __init__(
        self, n_enc_layers, word_embed_dim, encoder_dim, pool_type, dpout_model, bsize, bidirectional, rnn
    ):
        super(RecurrentEncoder, self).__init__()
        self.bsize = bsize
        self.n_enc_layers = n_enc_layers
        self.word_embed_dim = word_embed_dim
        self.encoder_dim = encoder_dim
        self.pool_type = pool_type
        self.dpout_model = dpout_model
        self.bidirectional = bidirectional
        self.rnn = rnn
        # `rnn=True` selects a vanilla RNN, otherwise an LSTM.
        net_cls = nn.RNN if self.rnn else nn.LSTM
        self.encoder = net_cls(
            self.word_embed_dim,
            self.encoder_dim,
            self.n_enc_layers,
            bidirectional=bidirectional,
            dropout=self.dpout_model,
        )

    def is_cuda(self):
        # Either all weights are on CPU or they are all on GPU.
        return 'cuda' in str(self.encoder.bias_hh_l0.device)

    def forward(self, sent_tuple):
        # sent: Variable(seqlen x bsize x worddim) batch of word embeddings
        # sent_len: per-sentence lengths (bsize,), e.g. [max_len, ..., min_len]
        sent, sent_len = sent_tuple
        self.encoder.flatten_parameters()

        # Sort the batch by decreasing length (keep indices to undo the sort).
        sent_len_sorted, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len)
        sent_len_sorted, idx_sort = sent_len_sorted.copy(), idx_sort.copy()
        idx_unsort = np.argsort(idx_sort)
        idx_sort = torch.from_numpy(idx_sort).cuda() if self.is_cuda() \
            else torch.from_numpy(idx_sort)
        sent = sent.index_select(1, Variable(idx_sort))

        # Pack the padded batch so the recurrent layer skips the padding.
        sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len_sorted)
        sent_output = self.encoder(sent_packed)[0]  # seqlen x bsize x (2*)encoder_dim
        sent_output = nn.utils.rnn.pad_packed_sequence(sent_output)[0]

        # Undo the sort so rows line up with the original batch order.
        idx_unsort = torch.from_numpy(idx_unsort).cuda() if self.is_cuda() \
            else torch.from_numpy(idx_unsort)
        sent_output = sent_output.index_select(1, Variable(idx_unsort))

        # Pool over the time dimension to get one vector per sentence.
        if self.pool_type == "mean":
            # Divide each summed sentence by its own (un-sorted) length.
            sent_len = Variable(torch.FloatTensor(sent_len.copy())).unsqueeze(1)
            if self.is_cuda():
                sent_len = sent_len.cuda()
            emb = torch.sum(sent_output, 0).squeeze(0)
            emb = emb / sent_len.expand_as(emb)
        elif self.pool_type == "max":
            emb = torch.max(sent_output, 0)[0]
            if emb.ndimension() == 3:
                emb = emb.squeeze(0)

        assert emb.ndimension() == 2
        return emb
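

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It instantiates the
# model with made-up hyperparameters and runs a CPU forward pass on random,
# already-embedded dummy sentence pairs, just to illustrate the expected input
# format ((seqlen x bsize x worddim) embeddings plus a length array) and the
# resulting output shape. All sizes and values below are illustrative
# assumptions, not settings prescribed by this codebase.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    seq_len, batch_size, embed_dim = 10, 4, 50
    model = TextClassificationNet(
        n_words=1000,
        word_embed_dim=embed_dim,
        encoder_dim=64,
        n_enc_layers=1,
        dpout_model=0.0,
        dpout_fc=0.0,
        fc_dim=32,
        bsize=batch_size,
        n_classes=3,
        pool_type="max",
        linear_fc=False,
        bidirectional=False,
        rnn=False,        # False selects the LSTM encoder, True a vanilla RNN
        double_in=True,   # sentence-pair features: (u, v, |u - v|, u * v)
    )
    # Dummy pre-embedded batches with per-sentence lengths (descending).
    s1 = (torch.randn(seq_len, batch_size, embed_dim), np.array([10, 9, 7, 5]))
    s2 = (torch.randn(seq_len, batch_size, embed_dim), np.array([10, 8, 6, 4]))
    logits = model(s1, s2)
    print(logits.shape)  # expected: torch.Size([4, 3])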