-
Notifications
You must be signed in to change notification settings - Fork 30
/
util.py
37 lines (26 loc) · 1.16 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch.nn as nn
import collections
def embedding_size_from_name(name):
    """Extract the embedding dimensionality encoded in a vector-set name.

    The size is the final dot-separated token with its last character
    dropped, e.g. ``"glove.6B.100d"`` -> ``100``.
    """
    last_token = name.strip().split('.')[-1]
    return int(last_token[:-1])
def print_dim(name, tensor):
    """Debugging helper: print *name* together with the tensor's size."""
    message = "%s -> %s" % (name, tensor.size())
    print(message)
class RNNWrapper(nn.Module):
    """
    Thin adapter giving GRU and LSTM modules one uniform interface.

    For a GRU the wrapper is transparent: arguments and results pass straight
    through. For an LSTM the cell state is dropped, so the forward call yields
    only (output, hidden state) — exactly like a GRU. Callers therefore never
    need to know which recurrent module type sits underneath.
    """

    LSTM = 'LSTM'
    GRU = 'GRU'

    def __init__(self, rnn):
        super(RNNWrapper, self).__init__()
        # Only the two supported recurrent module types are accepted.
        assert isinstance(rnn, (nn.LSTM, nn.GRU))
        self.rnn_type = self.LSTM if isinstance(rnn, nn.LSTM) else self.GRU
        self.rnn = rnn

    def forward(self, *inputs):
        output, state = self.rnn(*inputs)
        if self.rnn_type == self.LSTM:
            # An LSTM returns state as the pair (h, c); keep only h.
            state, _cell = state
        return output, state
# Dataset descriptor: vocabulary size, index of the padding token, and the
# pretrained embedding vectors.
Metadata = collections.namedtuple('Metadata', ['vocab_size', 'padding_idx', 'vectors'])