Background
I would like to do mini-batch training of "stateful" LSTMs in Keras. My input training data is in a large matrix "X" whose dimensions are m x n where
m = number-of-subsequences
n = number-of-time-steps-per-sequence
Each row of X contains a subsequence which picks up where the subsequence on the preceding row leaves off. So given a long sequence of data,
Data = ( t01, t02, t03, ... )
where "tK" means the token at position K in the original data, the sequence is layed out in X like so:
X = [
t01 t02 t03 t04
t05 t06 t07 t08
t09 t10 t11 t12
t13 t14 t15 t16
t17 t18 t19 t20
t21 t22 t23 t24
]
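For concreteness, here is the layout above expressed in NumPy (the sequence length of 24 and the 6 x 4 shape are just the values from this example):

import numpy as np

# Stand-in for the long sequence ( t01, t02, ..., t24 )
data = np.arange(1, 25)
m, n = 6, 4
# Row i+1 picks up exactly where row i leaves off
X = data.reshape(m, n)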
Question
My question is about what happens when I do mini-batch training on data laid out this way with stateful LSTMs. Specifically, mini-batch training typically trains on "contiguous" groups of rows at a time. So if I use a mini-batch size of 2, then X is split into three mini-batches X1, X2 and X3 where
X1 = [
t01 t02 t03 t04
t05 t06 t07 t08
]
X2 = [
t09 t10 t11 t12
t13 t14 t15 t16
]
X3 = [
t17 t18 t19 t20
t21 t22 t23 t24
]
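This is just contiguous row slicing, which as far as I understand is what Keras does when shuffle=False; a sketch:

batch_size = 2
# Contiguous row groups: batches[0] is X1, batches[1] is X2, batches[2] is X3
batches = [X[i:i + batch_size] for i in range(0, m, batch_size)]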
Notice that this type of mini-batching is at odds with stateful LSTM training: in a stateful LSTM, the final hidden state of row i in one batch becomes the initial hidden state of row i in the next batch, but here the last time step of a row in one batch is not the time step immediately preceding the first time step of the corresponding row in the next batch.
To see this, notice that the mini-batches will effectively be processed left to right, as though laid end to end like this:
------ X1 ------+------- X2 ------+------- X3 -----
t01 t02 t03 t04 | t09 t10 t11 t12 | t17 t18 t19 t20
t05 t06 t07 t08 | t13 t14 t15 t16 | t21 t22 t23 t24
implying that
- Token t04 comes immediately before t09
- Token t08 comes immediately before t13
- Token t12 comes immediately before t17
- Token t16 comes immediately before t21
But I want mini-batching to group rows so that we get this kind of temporal alignment across mini-batches:
------ X1 ------+------- X2 ------+------- X3 -----
t01 t02 t03 t04 | t05 t06 t07 t08 | t09 t10 t11 t12
t13 t14 t15 t16 | t17 t18 t19 t20 | t21 t22 t23 t24
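One way I can imagine producing this ordering, though I don't know whether it's the idiomatic approach, is to split the sequence into batch_size contiguous streams and interleave them. A NumPy sketch, reusing data, m, and n from above:

batch_size = 2
num_batches = m // batch_size                 # 3 batches of 2 rows each
# Shape (stream, batch, step): stream 0 holds t01..t12, stream 1 holds t13..t24
streams = data.reshape(batch_size, num_batches, n)
# Shape (batch, stream, step): row r of every batch continues stream r
X_aligned = streams.transpose(1, 0, 2)
# X_aligned[0] is [[t01..t04], [t13..t16]], matching the X1 I want above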
What is the standard way to accomplish this when training stateful LSTMs in Keras?
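For reference, the kind of stateful setup I have in mind looks roughly like the following (using the tf.keras API; the layer sizes and feature dimension of 1 are placeholders, not my actual model):

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense

# stateful=True requires a fixed batch size via batch_input_shape
model = Sequential([
    LSTM(32, stateful=True, batch_input_shape=(batch_size, n, 1)),
    Dense(1),
])
model.compile(optimizer="adam", loss="mse")
# Batches must be fed in order (shuffle=False), and state is reset
# manually between passes over the sequence with model.reset_states()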
Thanks for any pointers here.