TensorFlow, best way to save state in RNNs?

I currently have the following code for a series of chained-together RNNs in TensorFlow. I am not using MultiRNNCell, since I want to do something later on with the output of each layer.

for r in range(RNNS):
    with tf.variable_scope('recurent_%d' % r) as scope:
        state = [tf.zeros((BATCH_SIZE, sz)) for sz in rnn_func.state_size]
        time_outputs = [None] * TIME_STEPS

        for t in range(TIME_STEPS):
            rnn_input = getTimeStep(rnn_outputs[r - 1], t)
            time_outputs[t], state = rnn_func(rnn_input, state)
            time_outputs[t] = tf.reshape(time_outputs[t], (-1, 1, RNN_SIZE))
            scope.reuse_variables()
        rnn_outputs[r] = tf.concat(1, time_outputs)

Currently I have a fixed number of time steps. However, I would like to change it to use only one time step while remembering the state between batches. I would therefore need to create a state variable for each layer and assign it the final state of that layer. Something like this:

for r in range(RNNS):
    with tf.variable_scope('recurent_%d' % r) as scope:
        saved_state = tf.get_variable('saved_state', ...)
        rnn_outputs[r], state = rnn_func(rnn_outputs[r - 1], saved_state)
        saved_state = tf.assign(saved_state, state)

Then, for each of the layers, I would need to evaluate the saved state in my sess.run call in addition to calling my training op, and I would need to do this for every RNN layer. This seems like kind of a hassle: I would have to track every saved state and evaluate it in run. Also, run would then need to copy the state from my GPU to host memory, which would be inefficient and unnecessary. Is there a better way of doing this?

Merrile answered 22/6, 2016 at 13:10 Comment(5)

- Rubeola: Is this for prediction time? Why do you want to run one time step per state? Need more info to provide a useful answer.
- Merrile: I think I figured it out using control dependencies. I wanted to use it to generate a sequence.
- Rubeola: For posterity, are you doing something like this? with tf.control_dependencies([tf.assign(saved_state, state)]): rnn_outputs[r] = tf.identity(rnn_outputs[r])
- Merrile: Yes, that is what I am doing.
- Alanalana: Have you tried tf.nn.state_saving_rnn()?

Here is code that updates the LSTM's initial state between batches, when state_is_tuple=True, by defining state variables. It also supports multiple layers.

We define two functions: one that creates the state variables, initialized to zero, and one that returns an operation we can pass to session.run in order to update the state variables with the LSTM's last hidden state.

def get_state_variables(batch_size, cell):
    # For each layer, get the initial state and make a variable out of it
    # to enable updating its value.
    state_variables = []
    for state_c, state_h in cell.zero_state(batch_size, tf.float32):
        state_variables.append(tf.contrib.rnn.LSTMStateTuple(
            tf.Variable(state_c, trainable=False),
            tf.Variable(state_h, trainable=False)))
    # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
    return tuple(state_variables)


def get_state_update_op(state_variables, new_states):
    # Add an operation to update the train states with the last state tensors
    update_ops = []
    for state_variable, new_state in zip(state_variables, new_states):
        # Assign the new state to the state variables on this layer
        update_ops.extend([state_variable[0].assign(new_state[0]),
                           state_variable[1].assign(new_state[1])])
    # Return a tuple in order to combine all update_ops into a single operation.
    # The tuple's actual value should not be used.
    return tf.tuple(update_ops)

We can use that to update the LSTM's state after each batch. Note that I use tf.nn.dynamic_rnn for unrolling:

data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
# Use an LSTM cell (not a GRU cell), so that the state is an LSTMStateTuple
cell_layer = tf.contrib.rnn.LSTMCell(256)
cell = tf.contrib.rnn.MultiRNNCell([cell_layer] * num_layers)

# For each layer, get the initial state. states will be a tuple of LSTMStateTuples.
states = get_state_variables(batch_size, cell)

# Unroll the LSTM
outputs, new_states = tf.nn.dynamic_rnn(cell, data, initial_state=states)

# Add an operation to update the train states with the last state tensors.
update_op = get_state_update_op(states, new_states)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run([outputs, update_op], {data: ...})

The main difference from approaches that use a single state variable is that state_is_tuple=True makes the LSTM's state an LSTMStateTuple containing two tensors (the cell state and the hidden state) instead of just one. Using multiple layers then makes the LSTM's state a tuple of LSTMStateTuples, one per layer.
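
To make that structure concrete, here is a small sketch (using the same tf.contrib API as above; the batch size and cell size are made up) showing what cell.zero_state returns for a two-layer LSTM:

import tensorflow as tf

cell = tf.contrib.rnn.MultiRNNCell(
    [tf.contrib.rnn.LSTMCell(256) for _ in range(2)])
zero_state = cell.zero_state(batch_size=32, dtype=tf.float32)

# zero_state is a tuple with one LSTMStateTuple per layer
for i, layer_state in enumerate(zero_state):
    # Each LSTMStateTuple holds a cell state (c) and a hidden state (h)
    print(i, layer_state.c.shape, layer_state.h.shape)  # (32, 256) (32, 256)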

Resetting to zero

When using a trained model for prediction / decoding, you might want to reset the state to zero. Then, you can make use of this function:

def get_state_reset_op(state_variables, cell, batch_size):
    # Return an operation to set each variable in a list of LSTMStateTuples to zero
    zero_states = cell.zero_state(batch_size, tf.float32)
    return get_state_update_op(state_variables, zero_states)

For example, continuing from the setup above:

reset_state_op = get_state_reset_op(states, cell, batch_size)
# Reset the state to zero before feeding input
sess.run([reset_state_op])
sess.run([outputs, update_op], {data: ...})
Grogshop answered 20/12, 2016 at 10:33 Comment(5)

- Zaibatsu: As I understand this code, it is intended for a stacked LSTM in TF. If I only have one LSTM, then the only difference is that the functions don't need the for loops, right?
- Grogshop: @AndrewDraganov Yes, if you don't use MultiRNNCell, you don't need the for loops. cell.zero_state will return an LSTMStateTuple instead of a list of LSTMStateTuples.
- Mohair: For training this works great, thanks! However, for prediction, would you do something like output, curr_state = sess.run([prediction, update_op], {data: ..}) and then unpack curr_state and concat it to data for the next iteration of the prediction loop? Any insights on the mechanics of prediction following from your training solution would be brilliant!
- Grogshop: @ruohoruotsi Good point! For prediction, you'll want to reset the state to zero for every new sample, then feed the sample to the model with update_op in the same call. This way the model updates its state automatically, and you don't have to concat curr_state to data at all. (A sketch of such a loop follows below.)
- Occlusion: I could not make resetting to zero work for some reason. After I obtain the model weights and biases, I want to test the model on test data, without batching and with a greater number of observations. However, TensorFlow throws an error about a dimension mismatch (I understand why: it's because of the concat function in LSTMCell).
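
Based on the prediction discussion in the comments above, a minimal sketch of such a loop might look like this. It reuses states, cell, data, outputs, update_op, sess, and get_state_reset_op from the answer; sequences is a hypothetical iterable of inputs, each split into chunks shaped (batch_size, max_length, frame_size):

reset_state_op = get_state_reset_op(states, cell, batch_size)

for sequence in sequences:
    # Start each new sequence from a fresh zero state
    sess.run(reset_state_op)
    for chunk in sequence:
        # Running update_op alongside `outputs` writes the last hidden state
        # back into the state variables, carrying it over to the next chunk
        out, _ = sess.run([outputs, update_op], {data: chunk})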

I am now saving the RNN states using tf.control_dependencies. Here is an example.

saved_states = [tf.get_variable('saved_state_%d' % i, shape=(BATCH_SIZE, sz),
                                trainable=False,
                                initializer=tf.constant_initializer())
                for i, sz in enumerate(rnn.state_size)]
W = tf.get_variable('W', shape=(2 * RNN_SIZE, RNN_SIZE),
                    initializer=tf.truncated_normal_initializer(0.0, 1 / np.sqrt(2 * RNN_SIZE)))
b = tf.get_variable('b', shape=(RNN_SIZE,), initializer=tf.constant_initializer())

rnn_output, states = rnn(last_output, saved_states)
# Everything inside this block only runs after the new states are assigned,
# so evaluating the outputs below also saves the state
with tf.control_dependencies([tf.assign(var, new) for var, new in zip(saved_states, states)]):
    dense_input = tf.concat(1, (last_output, rnn_output))

dense_output = tf.tanh(tf.matmul(dense_input, W) + b)
last_output = dense_output + last_output

I just make sure that part of my graph is dependent on saving the state.
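
As a self-contained illustration of this pattern (a minimal sketch with made-up sizes, not the exact graph above), the key point is that the output tensor is made to depend on the assign op:

import numpy as np
import tensorflow as tf

BATCH_SIZE, RNN_SIZE = 4, 8
cell = tf.contrib.rnn.BasicLSTMCell(RNN_SIZE, state_is_tuple=False)

# Non-trainable variable that persists the state between session.run calls
saved_state = tf.get_variable('saved_state', shape=(BATCH_SIZE, 2 * RNN_SIZE),
                              trainable=False,
                              initializer=tf.constant_initializer())
inputs = tf.placeholder(tf.float32, (BATCH_SIZE, RNN_SIZE))

output, state = cell(inputs, saved_state)
# The output only becomes available after the state has been saved,
# so evaluating `output` implicitly updates `saved_state`
with tf.control_dependencies([tf.assign(saved_state, state)]):
    output = tf.identity(output)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
out = sess.run(output, {inputs: np.zeros((BATCH_SIZE, RNN_SIZE), np.float32)})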

Merrile answered 27/6, 2016 at 12:18 Comment(0)

These two links are also related and useful for this question:

https://github.com/tensorflow/tensorflow/issues/2695
https://github.com/tensorflow/tensorflow/issues/2838

Geodesy answered 12/7, 2016 at 9:37 Comment(0)
