why use LSTMCell not use LSTM directly #181

Open
morestart opened this issue Aug 2, 2022 · 3 comments
@morestart

No description provided.

@AndreiMoraru123

AndreiMoraru123 commented Jan 21, 2023

@morestart, you probably already know the answer by now. However, in case anyone else is wondering: `nn.LSTM` in PyTorch is a multi-layer network, which is why you can select the number of layers. `nn.LSTMCell`, on the other hand, is a single cell that processes one time step at a time. The author uses the latter here because the attention has to be recomputed at every step of the decoding loop during training. With `nn.LSTM` you could not do that, because the layer connections and the forward pass over the whole sequence are hard-coded inside the module.
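To make that concrete, here is a minimal, hypothetical sketch (the names `encoder_out`, `attn`, and all dimensions are illustrative, not the repo's actual code) of why a per-step cell is needed: the attention weights depend on the current hidden state, so the recurrence cannot be unrolled inside `nn.LSTM`:

```python
import torch
import torch.nn as nn

# Toy dimensions (illustrative only)
batch, enc_dim, dec_dim, num_pixels, steps = 4, 16, 32, 10, 5

encoder_out = torch.randn(batch, num_pixels, enc_dim)  # e.g. encoded image features
attn = nn.Linear(dec_dim, num_pixels)                  # toy attention scorer
cell = nn.LSTMCell(enc_dim, dec_dim)

h = torch.zeros(batch, dec_dim)
c = torch.zeros(batch, dec_dim)
outputs = []
for t in range(steps):
    # The attention weights depend on the *current* hidden state h,
    # so they must be recomputed before every single LSTM step.
    alpha = torch.softmax(attn(h), dim=1)                # (batch, num_pixels)
    context = (alpha.unsqueeze(2) * encoder_out).sum(1)  # (batch, enc_dim)
    h, c = cell(context, (h, c))                         # one step of the recurrence
    outputs.append(h)
outputs = torch.stack(outputs, dim=1)                    # (batch, steps, dec_dim)
```

With `nn.LSTM` the whole sequence is consumed in one forward call, leaving no hook to inject a freshly computed context vector between steps.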

@thanhtvt

thanhtvt commented Jan 30, 2023

@AndreiMoraru123 So if I set the number of layers in `nn.LSTM` to 2, is that the same as building a two-layer loop out of `LSTMCell`s?

@AndreiMoraru123

@thanhtvt Exactly! And this is precisely the example PyTorch provides in the docs.

From the nn.LSTM page:

rnn = nn.LSTM(10, 20, 2)       # input_size=10, hidden_size=20, num_layers=2
input = torch.randn(5, 3, 10)  # (seq_len=5, batch=3, input_size=10 — must match the LSTM's input size)
h0 = torch.randn(2, 3, 20)     # (num_layers=2, batch=3, hidden_size=20)
c0 = torch.randn(2, 3, 20)     # same shape as h0
output, (hn, cn) = rnn(input, (h0, c0))  # output has shape [5, 3, 20]: (seq_len, batch, hidden_size)

Then on the nn.LSTMCell page, it's essentially the same computation, written as an explicit for loop over time:

rnn = nn.LSTMCell(10, 20)      # (input_size, hidden_size)
input = torch.randn(2, 3, 10)  # (time_steps, batch, input_size)
hx = torch.randn(3, 20)        # (batch, hidden_size)
cx = torch.randn(3, 20)
output = []
for i in range(input.size(0)):
    hx, cx = rnn(input[i], (hx, cx))
    output.append(hx)
output = torch.stack(output, dim=0)  # shape [2, 3, 20]: the [3, 20] hx's stacked along the time dimension
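To make the two-layer equivalence from the question above concrete, here is a minimal sketch (shapes only; the weights are random, so only the dimensions match `nn.LSTM(10, 20, 2)`, not the numbers) of a 2-layer stack of `LSTMCell`s looped over time:

```python
import torch
import torch.nn as nn

time_steps, batch, input_size, hidden_size = 2, 3, 10, 20

# One LSTMCell per layer; layer 2 consumes layer 1's hidden state.
cells = nn.ModuleList([nn.LSTMCell(input_size, hidden_size),
                       nn.LSTMCell(hidden_size, hidden_size)])
x = torch.randn(time_steps, batch, input_size)
h = [torch.zeros(batch, hidden_size) for _ in cells]
c = [torch.zeros(batch, hidden_size) for _ in cells]

output = []
for t in range(time_steps):           # outer loop: time steps
    inp = x[t]
    for l, cell in enumerate(cells):  # inner loop: layers
        h[l], c[l] = cell(inp, (h[l], c[l]))
        inp = h[l]                    # feed this layer's output upward
    output.append(h[-1])              # the top layer's h is the output
output = torch.stack(output, dim=0)   # (time_steps, batch, hidden_size) = [2, 3, 20]
```

So "number of layers" in `nn.LSTM` corresponds to the inner (layer) loop here, while `nn.LSTM` also hard-codes the outer (time) loop — which is exactly the flexibility the per-step attention needs back.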
