Let me show some specific examples:
# LSTM example:
>>> import torch
>>> import torch.nn as nn
>>> rnn = nn.LSTM(10, 20, 2)              # input_size=10, hidden_size=20, num_layers=2
>>> input = torch.randn(5, 3, 10)         # (seq_len=5, batch=3, input_size=10)
>>> h0 = torch.randn(2, 3, 20)            # (num_layers=2, batch=3, hidden_size=20)
>>> c0 = torch.randn(2, 3, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0))
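For reference, with PyTorch's default batch_first=False layout, the resulting shapes come out as:

>>> output.shape                          # (seq_len, batch, hidden_size)
torch.Size([5, 3, 20])
>>> hn.shape                              # (num_layers, batch, hidden_size)
torch.Size([2, 3, 20])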
# LSTMCell example:
>>> rnn = nn.LSTMCell(10, 20)             # input_size=10, hidden_size=20
>>> input = torch.randn(5, 3, 10)         # (seq_len=5, batch=3, input_size=10)
>>> hx = torch.randn(3, 20)               # (batch=3, hidden_size=20)
>>> cx = torch.randn(3, 20)
>>> output = []
>>> for i in range(5):                    # one call per time step
...     hx, cx = rnn(input[i], (hx, cx))
...     output.append(hx)
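To get a single tensor comparable to the LSTM's output, the collected hidden states can be stacked along a new time dimension:

>>> output = torch.stack(output, dim=0)   # 5 tensors of (3, 20) -> (5, 3, 20)
>>> output.shape
torch.Size([5, 3, 20])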
The key difference:

- LSTM: the third argument, `2`, is `num_layers`, the number of stacked recurrent layers. The network unrolls into seq_len * num_layers = 5 * 2 = 10 cells: no explicit loop, but more cells (a hand-rolled two-layer version is sketched below).
- LSTMCell: the `for` loop runs seq_len = 5 times, and the hidden state produced by the i-th step is fed into the (i+1)-th step. There is only one cell, which is truly recurrent.
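For concreteness, here is a minimal hand-rolled version of the num_layers=2 case built from two LSTMCells (the names cell1, cell2, and outputs are my own, not from the snippets above); at each step, layer 2 consumes layer 1's hidden state:

>>> cell1 = nn.LSTMCell(10, 20)           # layer 1 reads the raw input
>>> cell2 = nn.LSTMCell(20, 20)           # layer 2 reads layer 1's hidden state
>>> h1, c1 = torch.zeros(3, 20), torch.zeros(3, 20)
>>> h2, c2 = torch.zeros(3, 20), torch.zeros(3, 20)
>>> outputs = []
>>> for i in range(5):
...     h1, c1 = cell1(input[i], (h1, c1))
...     h2, c2 = cell2(h1, (h2, c2))
...     outputs.append(h2)                # the top layer's hidden state is the output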
If we set num_layers=1 in the LSTM, or add one more LSTMCell as in the sketch above, the two snippets compute the same thing (given the same weights).
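One way to check the single-layer equivalence is to copy the LSTM's layer-0 parameters (weight_ih_l0 and friends are the actual nn.LSTM attribute names) into an LSTMCell and compare outputs; up to floating-point tolerance, this sketch should print True:

>>> lstm = nn.LSTM(10, 20, 1)             # num_layers=1
>>> cell = nn.LSTMCell(10, 20)
>>> with torch.no_grad():                 # copy layer-0 weights into the cell
...     cell.weight_ih.copy_(lstm.weight_ih_l0)
...     cell.weight_hh.copy_(lstm.weight_hh_l0)
...     cell.bias_ih.copy_(lstm.bias_ih_l0)
...     cell.bias_hh.copy_(lstm.bias_hh_l0)
>>> x = torch.randn(5, 3, 10)
>>> out, _ = lstm(x)                      # initial states default to zeros
>>> hx, cx = torch.zeros(3, 20), torch.zeros(3, 20)
>>> steps = []
>>> for t in range(5):
...     hx, cx = cell(x[t], (hx, cx))
...     steps.append(hx)
>>> torch.allclose(out, torch.stack(steps), atol=1e-6)
True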
Obviously, it is easier to apply parallel computing in LSTM: the whole sequence is handed to one fused, optimized implementation (e.g., cuDNN on GPU) instead of seq_len separate Python-level calls.
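As a rough, hardware-dependent illustration of that point, timing both on the same CPU tensor usually shows the fused call winning by a wide margin (t_lstm and t_cell are names I made up for this sketch):

>>> import time
>>> x = torch.randn(100, 32, 10)          # longer sequence: seq_len=100
>>> lstm, cell = nn.LSTM(10, 20, 1), nn.LSTMCell(10, 20)
>>> t0 = time.perf_counter(); _ = lstm(x); t_lstm = time.perf_counter() - t0
>>> hx, cx = torch.zeros(32, 20), torch.zeros(32, 20)
>>> t0 = time.perf_counter()
>>> for t in range(100):
...     hx, cx = cell(x[t], (hx, cx))
>>> t_cell = time.perf_counter() - t0
>>> # expect t_lstm < t_cell: one fused kernel vs. 100 separate Python calls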