Why set return_sequences=True and stateful=True for tf.keras.layers.LSTM?

I am learning TensorFlow 2.0 by following the tutorial. In the RNN example, I found this code:

def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
  model = tf.keras.Sequential([
    tf.keras.layers.Embedding(vocab_size, embedding_dim, 
                              batch_input_shape=[batch_size, None]),
    tf.keras.layers.LSTM(rnn_units, 
                        return_sequences=True, 
                        stateful=True, 
                        recurrent_initializer='glorot_uniform'),
    tf.keras.layers.Dense(vocab_size)
  ])
  return model

My question is: why does the code set the arguments return_sequences=True and stateful=True? What happens if the default arguments are used instead?

Eucken answered 22/3, 2019 at 8:56 Comment(0)
O
20

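The example in the tutorial is about text generation. Each batch fed to the network has shape (64, 100), i.e. (batch_size, sequence_length), holding integer character IDs. The model's output for that batch has shape:

(64, 100, 65) # (batch_size, sequence_length, vocab_size)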

  1. return_sequences=True

The intention is to predict a character at every time step, i.e. for every character in the input sequence, the next character needs to be predicted.

So return_sequences=True is set so that the LSTM emits an output at every time step, which the Dense layer then maps to a model output of shape (64, 100, 65). If this argument were set to False, only the output of the last time step would be returned, so for a batch of 64 the output would be (64, 65), i.e. for every sequence of 100 characters, only the last predicted character would be returned.
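The shape difference can be sketched without Keras. This is a toy stand-in, not the Keras implementation: it uses a plain tanh recurrent cell instead of an LSTM, and the names toy_rnn, w_in, w_rec and w_dense (and the sizes 8 and 16) are made up for illustration.

```python
import numpy as np

def toy_rnn(inputs, w_in, w_rec, return_sequences=False):
    """Run a plain tanh recurrent cell over inputs of shape (batch, time, features)."""
    batch_size, time_steps, _ = inputs.shape
    state = np.zeros((batch_size, w_rec.shape[0]))
    outputs = []
    for t in range(time_steps):
        # New state depends on the current input and the previous state
        state = np.tanh(inputs[:, t, :] @ w_in + state @ w_rec)
        outputs.append(state)
    if return_sequences:
        return np.stack(outputs, axis=1)  # (batch, time, units)
    return outputs[-1]                    # (batch, units): last time step only

rng = np.random.default_rng(0)
batch_size, seq_len, embedding_dim, rnn_units, vocab_size = 64, 100, 8, 16, 65

x = rng.normal(size=(batch_size, seq_len, embedding_dim))  # stands in for embedded characters
w_in = rng.normal(scale=0.1, size=(embedding_dim, rnn_units))
w_rec = rng.normal(scale=0.1, size=(rnn_units, rnn_units))
w_dense = rng.normal(scale=0.1, size=(rnn_units, vocab_size))  # stands in for Dense(vocab_size)

all_steps = toy_rnn(x, w_in, w_rec, return_sequences=True)
last_step = toy_rnn(x, w_in, w_rec, return_sequences=False)

logits_all = all_steps @ w_dense   # one score vector per time step
logits_last = last_step @ w_dense  # one score vector per sequence

print(logits_all.shape)   # (64, 100, 65)
print(logits_last.shape)  # (64, 65)
```

The last time step of all_steps is the same array that return_sequences=False gives back, so the flag only controls how much of the sequence is exposed, not how it is computed.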

  2. stateful=True

From the documentation, "If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch."

In the diagram below from the tutorial, you can see that setting stateful=True helps the LSTM make better predictions by providing the context of the previous prediction.

[image: diagram from the tutorial]
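The effect of carrying state over can be sketched with a toy layer. This is an illustration under assumptions, not how Keras implements it: ToyRNN is a hypothetical class using a plain tanh cell rather than an LSTM. It shows that feeding a sequence in two consecutive chunks with stateful=True gives the same outputs as feeding it in one go, while a stateless layer starts the second chunk from zeros.

```python
import numpy as np

class ToyRNN:
    """Hypothetical tanh recurrent layer with an optional carried-over state."""

    def __init__(self, input_dim, units, stateful=False, seed=0):
        rng = np.random.default_rng(seed)
        self.w_in = rng.normal(scale=0.5, size=(input_dim, units))
        self.w_rec = rng.normal(scale=0.5, size=(units, units))
        self.units = units
        self.stateful = stateful
        self.state = None  # last state per sample, kept only when stateful

    def reset_states(self):
        self.state = None

    def __call__(self, inputs):
        batch_size, time_steps, _ = inputs.shape
        if self.stateful and self.state is not None:
            state = self.state  # row i continues from row i of the previous batch
        else:
            state = np.zeros((batch_size, self.units))
        outputs = []
        for t in range(time_steps):
            state = np.tanh(inputs[:, t, :] @ self.w_in + state @ self.w_rec)
            outputs.append(state)
        if self.stateful:
            self.state = state
        return np.stack(outputs, axis=1)

x = np.random.default_rng(1).normal(size=(2, 6, 3))  # 2 samples, 6 steps, 3 features
first_half, second_half = x[:, :3, :], x[:, 3:, :]

reference = ToyRNN(3, 4)(x)  # whole sequence in one call

stateful_layer = ToyRNN(3, 4, stateful=True)
stateful_layer(first_half)
carried = stateful_layer(second_half)  # starts from the state left by first_half

stateless_layer = ToyRNN(3, 4, stateful=False)
stateless_layer(first_half)
fresh = stateless_layer(second_half)   # starts from zeros again

print(np.allclose(carried, reference[:, 3:, :]))  # True
print(np.allclose(fresh, reference[:, 3:, :]))    # False
```

All three layers share the same seed, so they have identical weights and the only difference is whether the state survives between calls.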

Orts answered 22/3, 2019 at 11:6 Comment(4)
Can you cite the tutorial from which the image is taken? My understanding is that stateful=True shares the context across batches, not the predictions. - Goatskin
It is from the same tutorial mentioned in the question: tensorflow.org/tutorials/text/… - Orts
@rtrtrt Well, each prediction is made in a separate batch; for example, the first batch will have the first character for, say, 4000 samples, and the next batch will have the second character for the next 4000 samples. - Percentage
In (64, 100, 65) # (batch_size, sequence_length, vocab_size), shouldn't it be number_of_features_per_time_step instead of vocab_size? When return_sequences=False, it only returns the features of the last time step (for each example in a batch) instead of each of the 100 time steps, and the dimension becomes (64, 65). - Tarsometatarsus
P
17

Return Sequences

Let's look at typical model architectures built using LSTMs.

Sequence to sequence models:

[image: sequence-to-sequence LSTM diagram]

We feed in a sequence of inputs (x's), one batch at a time, and each LSTM cell returns an output (y_i). So if your input is of size batch_size x time_steps x input_size, then the LSTM output will be batch_size x time_steps x output_size. This is called a sequence-to-sequence model because an input sequence is converted into an output sequence. Typical uses of this model are taggers (POS tagger, NER tagger). In Keras this is achieved by setting return_sequences=True.

Sequence classification - Many to one Architecture

[image: many-to-one LSTM diagram]

In the many-to-one architecture we use the output state of only the last LSTM cell. This kind of architecture is normally used for classification problems, like predicting whether a movie review (represented as a sequence of words) is positive or negative. In Keras, if we set return_sequences=False, the model returns the output state of only the last LSTM cell.
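As a rough sketch of how the two kinds of output are consumed downstream: a tagger takes one decision per time step, a classifier takes one decision per sequence. The scores and label names below are made up for illustration and do not come from a trained model.

```python
import numpy as np

# Sequence to sequence: scores of shape (batch=1, time_steps=3, n_tags=3)
tags = ["NOUN", "VERB", "DET"]
token_scores = np.array([[[0.1, 0.2, 0.7],
                          [0.8, 0.1, 0.1],
                          [0.2, 0.6, 0.2]]])
# argmax over the last axis gives one tag index per time step
predicted_tags = [[tags[i] for i in row] for row in token_scores.argmax(axis=-1)]
print(predicted_tags)   # [['DET', 'NOUN', 'VERB']]

# Many to one: scores of shape (batch=1, n_classes=2), one row per review
labels = ["negative", "positive"]
review_scores = np.array([[0.3, 0.7]])
predicted_label = [labels[i] for i in review_scores.argmax(axis=-1)]
print(predicted_label)  # ['positive']
```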

Stateful

An LSTM cell is composed of many gates, as shown in the figure below from this blog post. The states/gates of the previous cell are used to calculate the state of the current cell. In Keras, if stateful=False, the states are reset after each batch. If stateful=True, the states from the previous batch for index i will be used as the initial state for index i in the next batch. So state information gets propagated between batches with stateful=True. Check this link for an explanation of the usefulness of statefulness with an example.

http://colah.github.io/posts/2015-08-Understanding-LSTMs/
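For the state carried over at index i to be meaningful, row i of one batch should continue where row i of the previous batch stopped. One possible way to lay out a long sequence like that is sketched below; make_stateful_batches is a hypothetical helper, not part of Keras or the tutorial.

```python
import numpy as np

def make_stateful_batches(sequence, batch_size, steps_per_batch):
    """Split one long sequence into batches whose rows line up across batches."""
    sequence = np.asarray(sequence)
    n_batches = len(sequence) // (batch_size * steps_per_batch)
    usable = sequence[: n_batches * batch_size * steps_per_batch]  # drop the remainder
    # Each row is one contiguous stream of the original sequence
    streams = usable.reshape(batch_size, n_batches * steps_per_batch)
    # Batch k takes the k-th window of every stream
    return [streams[:, k * steps_per_batch:(k + 1) * steps_per_batch]
            for k in range(n_batches)]

batches = make_stateful_batches(np.arange(24), batch_size=2, steps_per_batch=4)
for batch in batches:
    print(batch.tolist())
# [[0, 1, 2, 3], [12, 13, 14, 15]]
# [[4, 5, 6, 7], [16, 17, 18, 19]]
# [[8, 9, 10, 11], [20, 21, 22, 23]]
```

Row 0 reads 0..3, then 4..7, then 8..11, so the state left behind after one batch is the right starting point for the same row in the next batch. Shuffling the batches would break that alignment.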

Palaeozoic answered 22/3, 2019 at 11:40 Comment(2)
What does it mean "for index i"? - Goatskin
@rtrtrt it means the LSTM cell unwrapped at time step i - Palaeozoic
A
4

Let's see the differences when playing around with the arguments:

import numpy as np
import tensorflow as tf

tf.keras.backend.clear_session()
tf.set_random_seed(42)  # TF 1.x name; in TF 2.x this is tf.random.set_seed(42)
# Two sequences of three time steps, each step with three features
X = np.array([[[1,2,3],[4,5,6],[7,8,9]],[[1,2,3],[4,5,6],[0,0,0]]], dtype=np.float32)
model = tf.keras.Sequential([tf.keras.layers.LSTM(4, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform')])
print(tf.keras.backend.get_value(model(X)).shape)
# (2, 3, 4)
print(tf.keras.backend.get_value(model(X)))
# [[[-0.16141939  0.05600287  0.15932009  0.15656665]
#   [-0.10788933  0.          0.23865232  0.13983202]
#   [-0.          0.          0.23865232  0.0057992 ]]
#
#  [[-0.16141939  0.05600287  0.15932009  0.15656665]
#   [-0.10788933  0.          0.23865232  0.13983202]
#   [-0.07900514  0.07872108  0.06463861  0.29855606]]]

So, if return_sequences is set to True, the model returns the full sequence it predicts.

tf.keras.backend.clear_session()
tf.set_random_seed(42)  # TF 1.x name; in TF 2.x this is tf.random.set_seed(42)
model = tf.keras.Sequential([
    tf.keras.layers.LSTM(4, return_sequences=False, stateful=True, recurrent_initializer='glorot_uniform')])
print(tf.keras.backend.get_value(model(X)).shape)
# (2, 4)
print(tf.keras.backend.get_value(model(X)))
# [[-0.          0.          0.23865232  0.0057992 ]
#  [-0.07900514  0.07872108  0.06463861  0.29855606]]

So, as the documentation states, if return_sequences is set to False, the model returns only the last output.

As for stateful, it is a bit harder to dive into. Essentially, when there are multiple batches of inputs, the last cell state at batch i will be the initial state at batch i+1. However, I think you will be more than fine going with the default settings.

Anglia answered 22/3, 2019 at 10:59 Comment(0)
