How to lay out training data with stateful LSTMs and batch_size > 1

Background

I would like to do mini-batch training of "stateful" LSTMs in Keras. My input training data is in a large matrix "X" whose dimensions are m x n where

m = number-of-subsequences
n = number-of-time-steps-per-sequence

Each row of X contains a subsequence which picks up where the subsequence on the preceding row leaves off. So given a long sequence of data,

Data = ( t01, t02, t03, ... )

where "tK" means the token at position K in the original data, the sequence is laid out in X like so:

X = [
  t01 t02 t03 t04
  t05 t06 t07 t08
  t09 t10 t11 t12
  t13 t14 t15 t16
  t17 t18 t19 t20
  t21 t22 t23 t24
]
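A minimal NumPy sketch of this layout, assuming the long sequence is held in a 1-D array (string labels stand in for the tokens). X is simply a row-major reshape of Data, which is why each row picks up where the previous row leaves off.

```python
import numpy as np

# Stand-in labels "t01" ... "t24" for the long sequence Data.
data = np.array([f"t{k:02d}" for k in range(1, 25)])

n = 4                    # time steps per subsequence (row)
m = len(data) // n       # number of subsequences (rows)

# Row-major reshape: row i+1 continues exactly where row i ends.
X = data[: m * n].reshape(m, n)
print(X)
```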

Question

My question is about what happens when I do mini-batch training with stateful LSTMs on data laid out this way. Specifically, mini-batch training typically trains on "contiguous" groups of rows at a time. So if I use a mini-batch size of 2, X would be split into three mini-batches X1, X2 and X3, where

X1 = [
  t01 t02 t03 t04
  t05 t06 t07 t08
]

X2 = [
  t09 t10 t11 t12
  t13 t14 t15 t16
]

X3 = [
  t17 t18 t19 t20
  t21 t22 t23 t24
]

This kind of mini-batching is incompatible with stateful LSTMs: the hidden state left after processing the last column of one batch is not the hidden state for the time step just before the first column of the next batch.

To see this, note that the stateful LSTM processes the mini-batches from left to right, like this:

------ X1 ------+------- X2 ------+------- X3 -----
t01 t02 t03 t04 | t09 t10 t11 t12 | t17 t18 t19 t20
t05 t06 t07 t08 | t13 t14 t15 t16 | t21 t22 t23 t24

implying that

- Token t04 comes immediately before t09
- Token t08 comes immediately before t13
- Token t12 comes immediately before t17
- Token t16 comes immediately before t21
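These implied adjacencies can be reproduced with a short NumPy sketch. It relies on the documented Keras behavior for stateful RNNs (the final state of sample i in one batch becomes the initial state of sample i in the next batch); the slicing below is an assumption that mimics contiguous, unshuffled mini-batching.

```python
import numpy as np

data = np.array([f"t{k:02d}" for k in range(1, 25)])
X = data.reshape(6, 4)
batch_size = 2

# Naive mini-batching: contiguous groups of rows.
batches = [X[i:i + batch_size] for i in range(0, len(X), batch_size)]

# Row r of batch k+1 starts from the state left by row r of batch k,
# so the model effectively treats these token pairs as adjacent.
implied = []
for prev, nxt in zip(batches[:-1], batches[1:]):
    for r in range(batch_size):
        implied.append((str(prev[r, -1]), str(nxt[r, 0])))

for a, b in implied:
    print(f"{a} -> {b}")
```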

But I want mini-batching to group rows so that we get this kind of temporal alignment across mini-batches:

------ X1 ------+------- X2 ------+------- X3 -----
t01 t02 t03 t04 | t05 t06 t07 t08 | t09 t10 t11 t12
t13 t14 t15 t16 | t17 t18 t19 t20 | t21 t22 t23 t24

What is the standard way to accomplish this when training LSTMs in Keras?

Thanks for any pointers here.

Lupita answered 2/2, 2018 at 7:15 Comment(0)

Solution 1 - Batch size = 1

Well, since it seems you actually have only one sequence (although divided, it's still one single sequence, right?), you do indeed have to train with a batch size of 1.

If you don't want to change or reorganize your data, just:

length = 4      # time steps per row, by your description
features = 1    # one variable over time, as it seems
X = X.reshape((-1, length, features))
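A small sketch of Solution 1 with concrete data, assuming integer ids 1 to 24 stand in for t01 to t24. In Keras this would be paired with a stateful LSTM whose input shape fixes the batch size at 1, trained with `fit(..., batch_size=1, shuffle=False)` so the samples are visited strictly in order.

```python
import numpy as np

length = 4     # time steps per row, as in the question
features = 1   # one variable per time step

# Ids 1..24 stand in for t01..t24; this is the question's 6 x 4 matrix X.
X = np.arange(1, 25).reshape(6, length)

# Add the trailing feature axis an LSTM expects: (samples, time steps, features).
X = X.reshape((-1, length, features))

# With batch_size=1 and no shuffling, samples are visited in order,
# so walking them one by one replays the original sequence without gaps.
replayed = np.concatenate([X[i, :, 0] for i in range(len(X))])
print(X.shape, replayed)
```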

Solution 2 - Regroup for length = 8

Still using a batch size of 1, reshape your input data (before passing it to the model) so that each sample is twice as long.

The final result will be exactly the same as if you were training with your described mini-batches of size 2. (But be sure to set the batch size to 1 in your model's input shape; otherwise this will give you wrong results.)

X = X.reshape((-1, 2 * length, features)) 

This will give you:

X = [
  [t01 t02 t03 t04 t05 t06 t07 t08]
  [t09 t10 t11 t12 t13 t14 t15 t16]
  [t17 t18 t19 t20 t21 t22 t23 t24]
]
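To check this reshape, here is the same operation on integer ids 1 to 24 standing in for t01 to t24 (an assumption for illustration):

```python
import numpy as np

length, features = 4, 1
X = np.arange(1, 25).reshape(6, length, features)   # ids 1..24 for t01..t24

# Merge each pair of consecutive rows into one sample of length 2 * length.
X2 = X.reshape((-1, 2 * length, features))
print(X2[:, :, 0])
```

The reshape only works when the number of rows is even; with an odd count, the last row would have to be dropped (or padded) first.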

Solution 3 - Only possible if you actually have two different sequences

By your description, it seems you have only one sequence. If you did have two different/independent sequences, you could then make a batch of size 2.

If splitting your sequence in two (and losing connection between them) is not a problem, you can rearrange your data:

X = X.reshape((2,-1,length, features))

Then:

X0 = X[:,0]
X1 = X[:,1]
...

You can also group it into a single array:

X = X.reshape((2,-1,length, features))
X = np.swapaxes(X,0,1).reshape((-1,length,features))

Then:

X0 = X[0:2]   # first mini-batch: t01-t04 and t13-t16
X1 = X[2:4]   # second mini-batch: t05-t08 and t17-t20
...
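The two arrangements of Solution 3 can be checked against each other with a short sketch (again assuming ids 1 to 24 for t01 to t24). It confirms that mini-batch k of the interleaved array consists of rows 2k and 2k+1, that this equals `X4[:, k]` of the 4-D array, and that the interleaved rows start with t01, t13, t05, t17, t09, t21.

```python
import numpy as np

length, features = 4, 1
X = np.arange(1, 25).reshape(6, length, features)   # ids 1..24 for t01..t24

# 4-D form: axis 0 selects the half of the sequence, axis 1 the step.
X4 = X.reshape((2, -1, length, features))

# Interleaved form: rows alternate between the two halves.
Xi = np.swapaxes(X4, 0, 1).reshape((-1, length, features))

batch_size = 2
for k in range(X4.shape[1]):
    # Mini-batch k of the interleaved array is rows 2k and 2k+1 ...
    batch = Xi[k * batch_size:(k + 1) * batch_size]
    # ... which is exactly X4[:, k] from the 4-D form.
    assert np.array_equal(batch, X4[:, k])

print(Xi[:, :, 0])
```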

And you can pass the complete X to training, as long as you explicitly set the batch size to 2 in the model's input shape and keep the rows in order (no shuffling).
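For larger batch sizes (the comments below mention 32), the same reshape/swapaxes idea generalizes. This is a sketch: the helper names `to_stateful_layout` and `is_continuous` are hypothetical, and the continuity check assumes consecutive integer token ids. In Keras the result would be fed with `fit(..., batch_size=B, shuffle=False)`, since `fit` shuffles samples by default, which would break the row-to-row alignment.

```python
import numpy as np

def to_stateful_layout(seq, batch_size, length):
    """Hypothetical helper: split `seq` into `batch_size` contiguous streams and
    interleave them so row r of mini-batch k+1 continues row r of mini-batch k.
    Tokens that do not fill a complete mini-batch are dropped from the end."""
    seq = np.asarray(seq)
    steps = len(seq) // (batch_size * length)       # mini-batches per epoch
    streams = seq[: batch_size * steps * length].reshape(batch_size, steps, length)
    return np.swapaxes(streams, 0, 1).reshape(-1, length)

def is_continuous(X, batch_size):
    """Hypothetical check that every row slot continues across mini-batches."""
    batches = X.reshape(-1, batch_size, X.shape[1])
    return all(
        np.all(prev[:, -1] + 1 == nxt[:, 0])        # assumes consecutive integer ids
        for prev, nxt in zip(batches[:-1], batches[1:])
    )

seq = np.arange(1, 25)                              # ids 1..24 for t01..t24
naive = seq.reshape(-1, 4)                          # the question's original X
stateful = to_stateful_layout(seq, batch_size=2, length=4)

print(is_continuous(naive, 2), is_continuous(stateful, 2))   # False True
```

For the 32 parallel streams discussed below, the call would be `to_stateful_layout(seq, batch_size=32, length=...)`, with the model's input shape fixing the batch size at 32.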

Eustache answered 2/2, 2018 at 13:42 Comment(6)
Solution #2 does not seem equivalent to what I am asking, since it processes each subsequence (which is twice as long) serially. But I want training to progress through multiple subsequences of the full sequence in parallel (i.e., as a "mini-batch"). – Lupita
In parallel? But then the model will never understand the data as a whole sequence. It will not have the states at the start of segment 2 to understand that it continues segment 1. – Formfitting
My sequence is very long: effectively 1M time steps (words from various articles), which I've concatenated into a single long sequence. So if I break this long sequence into 32 subsequences, each containing 1M/32 time steps, I get 32 "independent" parallel sequences that I can train on per batch. I don't want to use masking and break the sequence at document boundaries, because the documents differ greatly in length, so I would waste a lot of GPU ops training on masked values. Please comment on my proposed answer if you have thoughts there. Thanks! – Lupita
You will get many inconsistencies like that. You're telling the model that the beginning of one article is continuing the end of another article. You should really consider using a shape like (divisions * number_of_articles, length // divisions, features). – Formfitting
I have a "beginning of new document" feature that signals when a token starts a new document. Hopefully the LSTM will learn to reset upon receiving this signal. By the way, this kind of concatenation is common; see the PTB dataset used in this paper by Zaremba and Sutskever: arxiv.org/abs/1409.2329. – Lupita
In my (personal) opinion, this only makes things more complicated. If you don't do that, you can train many independent sequences in parallel in each batch. – Formfitting
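One way to build the "beginning of new document" feature mentioned in these comments is sketched below. The document contents and the two-feature layout (token id plus flag) are assumptions for illustration; in a real model the token id would more likely go through an embedding than be fed as a raw number.

```python
import numpy as np

# Hypothetical tokenized documents (integer token ids) of different lengths.
docs = [[5, 8, 2], [7, 1], [3, 9, 4, 6]]

# Concatenate all documents into one long sequence.
tokens = np.concatenate([np.asarray(d) for d in docs])

# Binary flag: 1.0 at the first token of each document, 0.0 elsewhere.
starts = np.cumsum([0] + [len(d) for d in docs[:-1]])
is_doc_start = np.zeros(len(tokens), dtype=np.float32)
is_doc_start[starts] = 1.0

# Stack into (time steps, features): raw token id plus the flag.
features = np.stack([tokens.astype(np.float32), is_doc_start], axis=-1)
print(features.shape)
```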

Thanks. It seems Daniel implies I can use a batch size of 2 if I reorganize X like this:

X = [
  t01 t02 t03 t04 
  t13 t14 t15 t16
  t05 t06 t07 t08
  t17 t18 t19 t20 
  t09 t10 t11 t12
  t21 t22 t23 t24
]

Is this a correct interpretation?

Lupita answered 2/2, 2018 at 18:23 Comment(2)
With the drawback that the model will not recognize t13 as following t12. – Formfitting
Thanks. That's fine for me! In practice there will be few such discontinuities with the data I have. – Lupita
