I currently have the following code for a series of chained RNNs in TensorFlow. I am not using MultiRNN, since I want to do something later on with the output of each layer.
for r in range(RNNS):
    with tf.variable_scope('recurent_%d' % r) as scope:
        state = [tf.zeros((BATCH_SIZE, sz)) for sz in rnn_func.state_size]
        time_outputs = [None] * TIME_STEPS
        for t in range(TIME_STEPS):
            rnn_input = getTimeStep(rnn_outputs[r - 1], t)
            time_outputs[t], state = rnn_func(rnn_input, state)
            time_outputs[t] = tf.reshape(time_outputs[t], (-1, 1, RNN_SIZE))
            scope.reuse_variables()
        rnn_outputs[r] = tf.concat(1, time_outputs)
Currently I have a fixed number of time steps. However, I would like to change it to have only one time step but remember the state between batches. I would therefore need to create a state variable for each layer and assign it that layer's final state. Something like this:
for r in range(RNNS):
    with tf.variable_scope('recurent_%d' % r) as scope:
        saved_state = tf.get_variable('saved_state', ...)
        rnn_outputs[r], state = rnn_func(rnn_outputs[r - 1], saved_state)
        saved_state = tf.assign(saved_state, state)
Then, for every RNN layer, I would need to evaluate the saved state in my sess.run call as well as my training function. This seems like kind of a hassle: I would have to track every saved state and evaluate it in run. Also, run would then copy the state from the GPU to host memory, which is inefficient and unnecessary. Is there a better way of doing this?
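To make the goal concrete, here is a minimal pure-Python sketch (deliberately not TensorFlow) of the behaviour I am after: feeding one time step per call while each layer keeps its own state should give the same outputs as the unrolled loop over a fixed number of time steps. The scalar `rnn_func`, its weights, and the `StatefulStack` class are hypothetical stand-ins for illustration only, not my real cell or any TensorFlow API.

```python
import math


def rnn_func(x, state, w_in=0.5, w_rec=0.9):
    """Toy scalar RNN cell (hypothetical stand-in for the real cell)."""
    new_state = math.tanh(w_in * x + w_rec * state)
    return new_state, new_state  # (output, new state)


def run_unrolled(inputs, num_layers):
    """Fixed number of time steps: each layer loops over the whole sequence."""
    layer_inputs = list(inputs)
    for _ in range(num_layers):
        state = 0.0  # every layer starts from a zero state
        outputs = []
        for x in layer_inputs:
            out, state = rnn_func(x, state)
            outputs.append(out)
        layer_inputs = outputs  # this layer's outputs feed the next layer
    return layer_inputs


class StatefulStack:
    """One time step per call; each layer's state is kept between calls."""

    def __init__(self, num_layers):
        self.saved_states = [0.0] * num_layers

    def step(self, x):
        for r in range(len(self.saved_states)):
            # Read the saved state, run one step, and store the new state.
            x, self.saved_states[r] = rnn_func(x, self.saved_states[r])
        return x


inputs = [0.1, -0.3, 0.7, 0.2]
unrolled = run_unrolled(inputs, num_layers=3)

stack = StatefulStack(num_layers=3)
stepped = [stack.step(x) for x in inputs]  # one "batch" per time step
```

Here `stepped` and `unrolled` hold the same values, because the saved state plays the role of the loop-carried `state`. What I am looking for is the graph equivalent of `saved_states`, updated as part of the training step so that it never has to leave the GPU.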
tf.nn.state_saving_rnn()? – Alanalana