Why use LSTMCell and not LSTM directly? #181
Comments
@morestart, you probably already know the answer, but in case anyone else is wondering: LSTM in PyTorch is a multi-layer network, which is why you can select the number of layers. LSTMCell, on the other hand, is just a single cell. The author uses the latter here because of the way the attention has to be computed at each step of the training process. With a multi-layer LSTM you could not do that, as the layer connections and the forward pass over the whole sequence are hard-coded.
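For intuition, here is a minimal sketch of that pattern (not the author's actual code; the sizes and the `attention` scoring module are illustrative assumptions): at every decoding step, the current hidden state is used to compute an attention-weighted context over the encoder output, and that context is concatenated with the word embedding before being fed to the LSTMCell. A hard-wired `nn.LSTM` forward pass over the whole sequence gives you no hook to do this per step.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only.
embed_dim, hidden_dim, encoder_dim = 32, 64, 128

cell = nn.LSTMCell(embed_dim + encoder_dim, hidden_dim)
attention = nn.Linear(hidden_dim + encoder_dim, 1)  # toy attention scorer

def decode_step(embedded, encoder_out, h, c):
    # Score every encoder position against the current hidden state h...
    expanded = h.unsqueeze(1).expand(-1, encoder_out.size(1), -1)
    scores = attention(torch.cat([expanded, encoder_out], dim=2))
    alpha = torch.softmax(scores, dim=1)           # (batch, num_positions, 1)
    context = (alpha * encoder_out).sum(dim=1)     # (batch, encoder_dim)
    # ...then feed the attention context together with the word embedding.
    h, c = cell(torch.cat([embedded, context], dim=1), (h, c))
    return h, c

# One decoding step: batch of 3, 49 encoder positions.
encoder_out = torch.randn(3, 49, encoder_dim)
h, c = torch.zeros(3, hidden_dim), torch.zeros(3, hidden_dim)
embedded = torch.randn(3, embed_dim)               # current word embedding
h, c = decode_step(embedded, encoder_out, h, c)
```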
@AndreiMoraru123 So if I set the number of layers in the LSTM to 2, is it the same as building a for loop with two LSTMCells?
@thanhtvt Exactly! And this is precisely the example PyTorch provides in the docs. If you take a look at the LSTM page:

```python
rnn = nn.LSTM(10, 20, 2)        # (10 = input size, 20 = hidden size, 2 = number of layers)
input = torch.randn(5, 3, 10)   # (5 = sequence length, 3 = batch size,
                                #  10 = last dimension, must equal the LSTM's input size)
h0 = torch.randn(2, 3, 20)      # (2 = number of layers again, 3 = batch size must match,
                                #  20 = hidden size must match)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))  # output has size [5, 3, 20]: same sequence length
                                         # and batch size as the input, hidden size last
```

Then the LSTMCell page shows pretty much the same thing, but using a for loop over the time steps:

```python
rnn = nn.LSTMCell(10, 20)       # (input_size, hidden_size)
input = torch.randn(2, 3, 10)   # (time_steps, batch, input_size)
hx = torch.randn(3, 20)         # (batch, hidden_size)
cx = torch.randn(3, 20)
output = []
for i in range(input.size(0)):
    hx, cx = rnn(input[i], (hx, cx))
    output.append(hx)
output = torch.stack(output, dim=0)  # output.size() will be [2, 3, 20]: the [3, 20] hx's
                                     # stacked along the first (time) dimension
```
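And to make the 2-layer case from the question above concrete, here is a minimal sketch (my own illustration, not from the PyTorch docs) that emulates `nn.LSTM(10, 20, 2)` with two LSTMCells, one per layer, inside the same time loop:

```python
import torch
import torch.nn as nn

cell1 = nn.LSTMCell(10, 20)   # first layer: input_size -> hidden_size
cell2 = nn.LSTMCell(20, 20)   # second layer: hidden_size -> hidden_size

input = torch.randn(5, 3, 10)  # (time_steps, batch, input_size)
h1, c1 = torch.zeros(3, 20), torch.zeros(3, 20)  # states for layer 1
h2, c2 = torch.zeros(3, 20), torch.zeros(3, 20)  # states for layer 2

outputs = []
for t in range(input.size(0)):
    h1, c1 = cell1(input[t], (h1, c1))  # layer 1 consumes the raw input
    h2, c2 = cell2(h1, (h2, c2))        # layer 2 consumes layer 1's hidden state
    outputs.append(h2)                  # the top layer's hidden state is the output

output = torch.stack(outputs, dim=0)   # [5, 3, 20], matching nn.LSTM(10, 20, 2)
```

With `dropout=0`, a real `nn.LSTM(10, 20, 2)` computes the same thing (up to weight initialization); it just runs the whole loop in optimized code, which is why it is faster when you do not need per-step control.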