Stateful LSTM: When to reset states?

Given X with dimensions (m samples, n sequences, and k features), and y labels with dimensions (m samples, 0/1):

Suppose I want to train a stateful LSTM (going by the Keras definition, where stateful=True means that cell states are not reset between sequences per sample -- please correct me if I'm wrong!). Are states supposed to be reset on a per-epoch basis or a per-sample basis?

Example:

for e in range(epochs):
    for m in range(X.shape[0]):        # for each sample
        for n in range(X.shape[1]):    # for each sequence
            # train_on_batch for model...
            # model.reset_states()     (1) I believe this is 'stateful = False'?
        # model.reset_states()         (2) wouldn't this make more sense?
    # model.reset_states()             (3) This is what I usually see...

In summary, I am not sure whether to reset states after each sequence or after each epoch (i.e., after all m samples in X have been trained on).

Advice is much appreciated.

Puissant answered 10/8, 2017 at 21:8 Comment(0)

If you use stateful=True, you would typically reset the state at the end of each epoch, or every couple of samples. If you want to reset the state after each sample, then this would be equivalent to just using stateful=False.
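
For illustration, here is a minimal sketch (not part of the original answer) of the "every couple of samples" option, assuming a stateful model already built and compiled with batch size 1, X and y shaped (samples, timesteps, features) and (samples, labels), and made-up values for epochs and reset_every:

epochs = 10      # hypothetical number of epochs
reset_every = 4  # hypothetical: how many consecutive samples share state

for e in range(epochs):
    for i in range(X.shape[0]):
        # Train on a single sample, keeping the batch dimension of size 1.
        model.train_on_batch(X[i:i + 1], y[i:i + 1])
        if (i + 1) % reset_every == 0:
            model.reset_states()  # reset every reset_every samples
    model.reset_states()          # and once more at the end of the epoch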

Regarding the loops you provided:

for e in range(epochs):
    for m in range(X.shape[0]):        # for each sample
        for n in range(X.shape[1]):    # for each sequence

note that the dimensions of X are not exactly

 (m samples, n sequences, k features)

The dimensions are actually

(batch size, number of timesteps, number of features)
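
For concreteness, a small made-up example of data in that shape (the array names and sizes are arbitrary):

import numpy as np

# 32 samples, each a sequence of 10 timesteps with 3 features.
X = np.random.rand(32, 10, 3)
y = np.random.randint(0, 2, size=(32, 1))

print(X.shape)  # (32, 10, 3) -- X.shape[1] is the number of timesteps, not something to loop over by hand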

Hence, you are not supposed to have the inner loop:

for n in range(X.shape[1]):

Now, regarding the loop

for m in range(X.shape[0]):

since the enumeration over batches is done automatically by Keras, you don't have to implement this loop either (unless you want to reset the states every couple of samples). So if you want to reset only at the end of each epoch, you need only the outer loop.

Here is an example of such an architecture (taken from this blog post):

from keras.models import Sequential
from keras.layers import Dense, LSTM

batch_size = 1
model = Sequential()
model.add(LSTM(16, batch_input_shape=(batch_size, X.shape[1], X.shape[2]), stateful=True))
model.add(Dense(y.shape[1], activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Run one epoch at a time (shuffle=False preserves the sample order),
# then reset the LSTM states manually at the end of each epoch.
for i in range(300):
    model.fit(X, y, epochs=1, batch_size=batch_size, verbose=2, shuffle=False)
    model.reset_states()
Corral answered 10/8, 2017 at 21:31 Comment(4)
So "stateful" actually finds relationships between samples, given sequence lengths (aka timesteps) and features... What if you have uneven sequence lengths among samples? I guess I could pad X_train, but that could be costly for the shorter sequences. Thanks Miriam! – Puissant
Yep, padding seems like a reasonable solution in such a case (see the short sketch after these comments). You can read more about it in the discussion here: github.com/fchollet/keras/issues/85 – Corral
How do you reset states in model.evaluate? – Nertie
@MiriamFarber What kind of relationship does stateful find? Is statefulness in Keras different from the one used in RNNs (LSTMs)? Kindly give some reference to learn about the relationship that Andy discussed. – Brunet
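
If padding is the route taken, here is a minimal sketch (not from the original comments) using Keras' pad_sequences utility on made-up ragged data:

from keras.preprocessing.sequence import pad_sequences

# Hypothetical samples with uneven numbers of timesteps.
sequences = [[1, 2, 3], [4, 5], [6]]

# Pad with zeros (at the front by default) so every sample has the same length.
X_padded = pad_sequences(sequences, maxlen=3, padding='pre')
print(X_padded.shape)  # (3, 3)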

Alternatively, it seems possible to use a custom callback. This avoids calling fit in a loop, which is costly. Something similar to Tensorflow LSTM/GRU reset states once per epoch and not for each new batch:

import tensorflow as tf

gru_layer = model.layers[1]  # the stateful recurrent layer whose states should be reset

class CustomCallback(tf.keras.callbacks.Callback):
    def __init__(self, gru_layer):
        super().__init__()
        self.gru_layer = gru_layer

    def on_epoch_end(self, epoch, logs=None):
        # Reset the recurrent states once per epoch instead of looping over fit().
        self.gru_layer.reset_states()

# EarlyS is assumed to be an EarlyStopping callback defined elsewhere.
model.fit(train_dataset, validation_data=validation_dataset,
          epochs=EPOCHS, callbacks=[EarlyS, CustomCallback(gru_layer)], verbose=1)
Valentinevalentino answered 21/2, 2022 at 20:57 Comment(0)
