Given a trained LSTM model, I want to perform inference for single timesteps, i.e. `seq_length = 1` in the example below. After each timestep, the internal LSTM states (memory and hidden) need to be remembered for the next 'batch'. At the very beginning of inference, the initial LSTM states `init_c` and `init_h` are computed from the input. These are then stored in an `LSTMStateTuple` object, which is passed to the LSTM. During training this state is updated every timestep. For inference, however, I want the state to be saved between batches: the initial states only need to be computed at the very beginning, and after that the LSTM state should be saved after each 'batch' (n=1).
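Here is a framework-free sketch (plain Python, not TensorFlow) of why carrying the state across calls works. Running a sequence in one call and running it one timestep per call, passing each call's final state into the next call, give identical outputs. The 1-unit cell, its weights, `initial_state` and `run_sequence` are hypothetical stand-ins for the trained model, the `init_c`/`init_h` computation, and one `session.run` call.

```python
import math
from collections import namedtuple

# Stand-in for tf.nn.rnn_cell.LSTMStateTuple, which is also a namedtuple of (c, h).
LSTMStateTuple = namedtuple("LSTMStateTuple", ("c", "h"))

# Hypothetical toy weights for a 1-unit LSTM: (input weight, recurrent weight, bias) per gate.
W = {"i": (0.5, 0.1, 0.0), "f": (0.4, 0.2, 1.0), "o": (0.3, -0.1, 0.0), "g": (0.9, 0.5, 0.0)}

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def gate(name, x, h):
    wx, wh, b = W[name]
    return wx * x + wh * h + b

def lstm_step(x, state):
    """One LSTM timestep: returns (output, new_state)."""
    i = sigmoid(gate("i", x, state.h))
    f = sigmoid(gate("f", x, state.h))
    o = sigmoid(gate("o", x, state.h))
    g = math.tanh(gate("g", x, state.h))
    c = f * state.c + i * g
    h = o * math.tanh(c)
    return h, LSTMStateTuple(c, h)

def initial_state(x0):
    # Hypothetical: derive (c, h) from the first input, like init_c/init_h in the question.
    return LSTMStateTuple(c=0.1 * x0, h=0.0)

def run_sequence(xs, state):
    """Mimics one session.run: unroll len(xs) steps from `state`, return outputs and final state."""
    outputs = []
    for x in xs:
        out, state = lstm_step(x, state)
        outputs.append(out)
    return outputs, state

xs = [0.5, -1.0, 2.0, 0.25]

# Training-style: the whole sequence in one call.
full_outputs, _ = run_sequence(xs, initial_state(xs[0]))

# Inference-style: seq_length = 1 per call; each call's final state is saved
# and fed back in as the initial state of the next call.
saved_state = initial_state(xs[0])  # computed only once, at the very beginning
step_outputs = []
for x in xs:
    out, saved_state = run_sequence([x], saved_state)
    step_outputs.extend(out)

print(step_outputs == full_outputs)  # prints True
```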
I found this related StackOverflow question: "Tensorflow, best way to save state in RNNs?". However, that approach only works if `state_is_tuple=False`, and this behavior is soon to be deprecated by TensorFlow (see rnn_cell.py). Keras seems to have a nice wrapper that makes stateful LSTMs possible, but I don't know the best way to achieve this in TensorFlow. This issue on the TensorFlow GitHub is also related to my question: https://github.com/tensorflow/tensorflow/issues/2838
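With `state_is_tuple=True`, the state of a `MultiRNNCell` is a nested tuple (one `LSTMStateTuple(c, h)` per layer) rather than one concatenated tensor, which is why a single placeholder or variable for the whole state no longer fits. One possible workaround, sketched here as an assumption rather than an established recipe, is to flatten the nested state into a flat list of `c`/`h` arrays (one placeholder each) and rebuild the tuple afterwards. `flatten_state` and `unflatten_state` are hypothetical helpers, and the namedtuple is a stand-in for `tf.nn.rnn_cell.LSTMStateTuple`. TensorFlow's internal `tensorflow.python.util.nest` module provides `flatten` and `pack_sequence_as` for the same job.

```python
from collections import namedtuple

# Stand-in for tf.nn.rnn_cell.LSTMStateTuple (a namedtuple with fields c and h).
LSTMStateTuple = namedtuple("LSTMStateTuple", ("c", "h"))

def flatten_state(state):
    """Turn a tuple of per-layer LSTMStateTuples into a flat list [c0, h0, c1, h1, ...]."""
    flat = []
    for layer in state:
        flat.extend([layer.c, layer.h])
    return flat

def unflatten_state(flat):
    """Inverse of flatten_state: rebuild the per-layer tuple from [c0, h0, c1, h1, ...]."""
    if len(flat) % 2 != 0:
        raise ValueError("expected an even number of entries (one c and one h per layer)")
    return tuple(LSTMStateTuple(flat[i], flat[i + 1]) for i in range(0, len(flat), 2))

# Two layers, as with num_lstm_layers = 2; strings stand in for the state arrays.
state = (LSTMStateTuple("c0", "h0"), LSTMStateTuple("c1", "h1"))
flat = flatten_state(state)
print(flat)                            # ['c0', 'h0', 'c1', 'h1']
print(unflatten_state(flat) == state)  # True
```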
Does anyone have good suggestions for building a stateful LSTM model?
```python
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.nn.rnn_cell)

seq_length = 1  # single-timestep inference
inputs = tf.placeholder(tf.float32, shape=[None, seq_length, 84, 84], name="inputs")
targets = tf.placeholder(tf.float32, shape=[None, seq_length], name="targets")
num_lstm_layers = 2

with tf.variable_scope("LSTM") as scope:
    # One separate cell object per layer; `[lstm_cell] * num_lstm_layers`
    # would put the same cell object in every layer.
    cells = [tf.nn.rnn_cell.LSTMCell(512, initializer=initializer, state_is_tuple=True)
             for _ in range(num_lstm_layers)]
    self.lstm = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)

    init_c = ...  # compute initial LSTM memory state using contents in placeholder 'inputs'
    init_h = ...  # compute initial LSTM hidden state using contents in placeholder 'inputs'
    # With state_is_tuple=True, MultiRNNCell takes a tuple of one LSTMStateTuple per layer.
    self.state = tuple(tf.nn.rnn_cell.LSTMStateTuple(init_c, init_h)
                       for _ in range(num_lstm_layers))

    outputs = []
    for step in range(seq_length):
        if step != 0:
            scope.reuse_variables()
        # CNN features, as input for LSTM
        x_t = ...
        # LSTM step through time
        output, self.state = self.lstm(x_t, self.state)
        outputs.append(output)
```
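On the inference side, one way to organize things around a graph like this is a small wrapper that holds the state between calls, computes it from the input only on the first call, and can be reset. This is a sketch with a hypothetical `StatefulRunner` class; `init_fn` and `step_fn` are assumed callables that would each wrap a `sess.run(...)` call (feeding the saved state and fetching the new `self.state`), replaced here by toy lambdas so the example runs without TensorFlow.

```python
class StatefulRunner:
    """Hypothetical helper that keeps recurrent state between single-timestep calls.

    init_fn(x) -> state         computes the initial state from the first input
    step_fn(x, state) -> (y, s) runs one timestep, returns output and new state
    """

    def __init__(self, init_fn, step_fn):
        self._init_fn = init_fn
        self._step_fn = step_fn
        self._state = None

    def reset(self):
        # Forget the saved state, e.g. at the start of a new sequence.
        self._state = None

    def step(self, x):
        if self._state is None:
            self._state = self._init_fn(x)  # only at the very beginning
        y, self._state = self._step_fn(x, self._state)
        return y

# Toy stand-ins: the "state" is a running sum seeded from the first input.
runner = StatefulRunner(init_fn=lambda x: 10 * x, step_fn=lambda x, s: (s + x, s + x))
print([runner.step(x) for x in [1, 2, 3]])  # [11, 13, 16]
runner.reset()
print(runner.step(5))                       # 55
```

Keeping the state outside the graph as plain values makes `reset()` trivial; the trade-off is copying the state between Python and the session on every step.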