TensorFlow: How to pass the output from the previous time step as input to the next time step
A

4

14

This may be a duplicate of the question "How can I feed last output y(t-1) as input for generating y(t) in tensorflow RNN?"

I want to pass the output of an RNN at time step T as the input at time step T+1, i.e. input_RNN(T+1) = output_RNN(T). According to the documentation, both tf.nn.rnn and tf.nn.dynamic_rnn explicitly take the complete input for all time steps up front.

I looked at the seq2seq example at https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/seq2seq.py. It uses a loop and calls cell(input, state), where the cell can be an LSTM, a GRU, or any other RNN cell. I checked the documentation for the data types and shapes of the arguments to cell(), but I only found the constructor of the form cell(num_neurons). What is the correct way to pass the output back in as input? I don't want to use other libraries or wrappers built on top of TensorFlow, such as Keras. Any suggestions?

Aesthetic answered 24/9, 2016 at 21:14 Comment(0)
R
3

One way to do this is to write your own RNN cell, together with your own multi-RNN cell. This way you can store the output of the last RNN cell internally and simply access it in the next time step. Check this blogpost for more info. You can also add e.g. encoders or decoders directly in the cell, so that you can process the data before feeding it to the cell or after retrieving it from the cell.
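
The idea of a cell that remembers its own last output can be sketched without any framework. This is only an illustration, not TensorFlow code: toy_cell, FeedbackCell, run, and the scalar weights are hypothetical stand-ins, and a real implementation would subclass the RNN cell class of your TensorFlow version.

```python
import math


def toy_cell(x, h):
    """Hypothetical stand-in for cell(input, state): returns (output, new_state)."""
    new_h = math.tanh(0.5 * x + 0.8 * h)
    return new_h, new_h


class FeedbackCell:
    """Wraps a cell and carries the previous output inside the state.

    The wrapped state is the pair (inner_state, previous_output). The external
    input is only used on the first step, while previous_output is still None.
    """

    def __init__(self, cell):
        self.cell = cell

    def zero_state(self):
        return (0.0, None)

    def __call__(self, x, state):
        inner_state, prev_output = state
        cell_input = x if prev_output is None else prev_output
        output, new_inner = self.cell(cell_input, inner_state)
        return output, (new_inner, output)


def run(cell, inputs):
    """Plain unrolled loop that calls the cell once per time step."""
    state = cell.zero_state()
    outputs = []
    for x in inputs:
        out, state = cell(x, state)
        outputs.append(out)
    return outputs


# Only the first input value matters; later values are replaced by the feedback.
feedback_outputs = run(FeedbackCell(toy_cell), [1.0, 99.0, -99.0])
```

Because the previous output travels inside the state, the outer loop does not need to know about the feedback at all, which is what makes this approach compatible with functions that take the complete input up front.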

Another possibility is to use the function tf.nn.raw_rnn which lets you control what happens before and after the calls to the RNN cells. The following code snippet shows how to use this function, credits go to this article.

from tensorflow.python.ops.rnn import _transpose_batch_time
import tensorflow as tf


def sampling_rnn(self, cell, initial_state, input_, seq_lengths):

    # raw_rnn expects time major inputs as TensorArrays
    max_time = ...  # this is the max time step per batch
    inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time, clear_after_read=False)
    inputs_ta = inputs_ta.unstack(_transpose_batch_time(input_))  # input_ is the input placeholder
    input_dim = input_.get_shape()[-1].value  # the dimensionality of the input to each time step
    output_dim = ...  # the dimensionality of the model's output at each time step

    def loop_fn(time, cell_output, cell_state, loop_state):
        """
        Loop function that allows to control input to the rnn cell and manipulate cell outputs.
        :param time: current time step
        :param cell_output: output from previous time step or None if time == 0
        :param cell_state: cell state from previous time step
        :param loop_state: custom loop state to share information between different iterations of this loop fn
        :return: tuple consisting of
          elements_finished: tensor of size [batch_size] which is True for sequences that have reached their end,
            needed because of variable sequence size
          next_input: input to next time step
          next_cell_state: cell state forwarded to next time step
          emit_output: The first return argument of raw_rnn. This is not necessarily the output of the RNN cell,
            but could e.g. be the output of a dense layer attached to the rnn layer.
          next_loop_state: loop state forwarded to the next time step
        """
        if cell_output is None:
            # time == 0, used for initialization before first call to cell
            next_cell_state = initial_state
            # the emit_output in this case tells TF how future emits look
            emit_output = tf.zeros([output_dim])
        else:
            # t > 0, called right after call to cell, i.e. cell_output is the output from time t-1.
            # here you can do whatever you want with cell_output before assigning it to emit_output.
            # In this case, we don't do anything
            next_cell_state = cell_state
            emit_output = cell_output

        # check which elements are finished
        elements_finished = (time >= seq_lengths)
        finished = tf.reduce_all(elements_finished)

        # assemble cell input for upcoming time step
        current_output = emit_output if cell_output is not None else None
        input_original = inputs_ta.read(time)  # tensor of shape (None, input_dim)

        if current_output is None:
            # this is the initial step, i.e. there is no output from a previous time step, what we feed here
            # can highly depend on the data. In this case we just assign the actual input in the first time step.
            next_in = input_original
        else:
            # time > 0, so just use previous output as next input
            # here you could do fancier things, whatever you want to do before passing the data into the rnn cell
            # if here you were to pass input_original then you would get the normal behaviour of dynamic_rnn
            next_in = current_output

        next_input = tf.cond(finished,
                             lambda: tf.zeros([self.batch_size, input_dim], dtype=tf.float32),  # copy through zeros
                             lambda: next_in)  # if not finished, feed the previous output as next input

        # set shape manually, otherwise it is not defined for the last dimensions
        next_input.set_shape([None, input_dim])

        # loop state not used in this example
        next_loop_state = None
        return (elements_finished, next_input, next_cell_state, emit_output, next_loop_state)

    outputs_ta, last_state, _ = tf.nn.raw_rnn(cell, loop_fn)
    outputs = _transpose_batch_time(outputs_ta.stack())
    final_state = last_state

    return outputs, final_state

As a side note: It is not clear if relying on the model's outputs during training is a good idea. Especially in the beginning, the outputs of the model can be quite bad, so your training might never converge or might not learn anything meaningful.
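
To make the calling order of such a loop function concrete, here is a framework-free emulation for a batch of scalar sequences. It is a simplified sketch and not the TensorFlow implementation: it assumes, as the tf.nn.raw_rnn documentation describes, that the loop function is called once before the first cell call (with cell_output set to None) and once after every cell call with the incremented time. The names toy_cell, make_loop_fn, and emulate_raw_rnn are hypothetical, and the loop state is omitted. It also masks the next input per batch entry instead of using a single flag for the whole batch.

```python
import math


def toy_cell(x, h):
    """Hypothetical stand-in for cell(input, state): returns (output, new_state)."""
    new_h = math.tanh(0.5 * x + 0.8 * h)
    return new_h, new_h


def make_loop_fn(first_inputs, seq_lengths):
    """Builds a loop function for a batch of scalar sequences (one float per entry)."""

    def loop_fn(time, cell_output, cell_state):
        finished = [time >= n for n in seq_lengths]
        if cell_output is None:
            # Initial call, before the cell has run: start from real data.
            next_input = list(first_inputs)
            next_state = [0.0] * len(seq_lengths)
        else:
            # Per-entry masking: a finished entry gets a zero input, every
            # other entry gets its own previous output.
            next_input = [0.0 if done else out
                          for done, out in zip(finished, cell_output)]
            next_state = cell_state
        return finished, next_input, next_state

    return loop_fn


def emulate_raw_rnn(cell, loop_fn):
    """Simplified emulation: call loop_fn, then the cell, and repeat."""
    time = 0
    finished, inputs, state = loop_fn(time, None, None)
    emitted, call_times = [], [time]
    while not all(finished):
        results = [cell(x, h) for x, h in zip(inputs, state)]
        outputs = [out for out, _ in results]
        new_state = [h for _, h in results]
        # Entries that were already finished before this step emit zeros.
        emitted.append([0.0 if done else out
                        for done, out in zip(finished, outputs)])
        time += 1
        finished, inputs, state = loop_fn(time, outputs, new_state)
        call_times.append(time)
    return emitted, call_times


# Two sequences of lengths 3 and 2, both starting from the value 1.0.
emitted, call_times = emulate_raw_rnn(toy_cell, make_loop_fn([1.0, 1.0], [3, 2]))
```

In this emulation the loop function is invoked with time values 0 through 3 for a longest sequence of length 3, so its last call sees time equal to the maximum length. The example in the raw_rnn documentation places the TensorArray read inside tf.cond, which avoids reading at that index.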

Reitz answered 14/3, 2018 at 9:40 Comment(9)
So at test time, when I want to generate a sequence, I technically only need to provide the single initial data point, right? After all, next_in = current_output says that the output of the current time step is used as the input to the next one. (I want to confirm this because the result changes for inputs that share the same initial time-step value but differ in later values, which should not matter since they are not needed to generate the output.) - Bummer
In your code, inputs_ta is a TensorArray of size max_time, finished elements are detected via time >= seq_lengths, and input_original is read with inputs_ta.read(time). Shouldn't this raise an error? When time == seq_lengths, inputs_ta.read(seq_lengths) would be out of bounds (I'm assuming TensorArray indices start at 0). - Bummer
@Bummer This error shouldn't happen because time goes from 0 to max_time-1; in other words, each entry in seq_lengths is less than or equal to max_time. - Reitz
@Bummer I just noticed a potential problem: finished will only be true when all elements in the batch have finished (because of the tf.reduce_all statement). This is something you might want to avoid, i.e. you may want a "finished" boolean per batch entry. It's possible to implement, but a bit more complicated. In general, in my experience, working with custom RNN cells is easier than with tf.nn.raw_rnn. - Reitz
Sorry, I still don't quite understand. Given max_time = ... # this is the max time step per batch: if the max time step in a batch is T, then there is a sequence in the batch with seq_lengths = T (in fact, all sequences in each of my batches have the same length). To finish, it has to satisfy time >= seq_lengths, i.e. time has to reach max_time for the batch to finish. I agree with your second comment. Do you know of any examples with custom RNN cells? - Bummer
Maybe the inputs_ta.read(time) should happen inside tf.cond. - Bummer
@Bummer With the statement time >= seq_lengths we want to detect when we should not do anything for a particular batch entry. E.g. if an entry has length K < T, where T is the max time step, we must only do something in time steps 0 ... K-1, and at time steps >= K there is nothing to do. In that sense, batch entries with length T are a special case. Because TF calls this loop fn only for time steps 0 ... T-1, the statement time >= seq_lengths will never be executed for time == max_seq_length. - Reitz
@Bummer An example with custom RNN cells is linked in the post. It is straightforward to implement, but if you have questions, maybe consider opening a new post. Good luck! - Reitz
Thanks, will check it out. - Bummer
M
0

Define an init_state placeholder together with your network layers:

init_state = tf.placeholder(tf.float32, [batch_size,hidden])
basic_cell = tf.contrib.rnn.BasicRNNCell(num_units = hidden)
state_series, current_state = tf.nn.dynamic_rnn(basic_cell, x, dtype=tf.float32, initial_state = init_state)

Then, outside your training loop, initialize the zero state:

_init_state = np.zeros([batch_size,hidden], dtype=np.float32)

Inside your training loop, run the session with _init_state in your feed_dict, and make the returned _current_state your new _init_state for the next step:

_training_op, _state_series, _current_state = sess.run(
                [training_op, state_series, current_state],  feed_dict={x: xdb, y: ydb, init_state:_init_state})

_init_state = _current_state
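
What this hand-off achieves in the forward direction can be checked with a small framework-free sketch. Everything in it is a hypothetical stand-in: toy_cell plays the role of the RNN cell and run_chunk plays the role of one sess.run call that returns the state series and the final state.

```python
import math


def toy_cell(x, h):
    """Hypothetical stand-in for cell(input, state): returns (output, new_state)."""
    new_h = math.tanh(0.5 * x + 0.8 * h)
    return new_h, new_h


def run_chunk(inputs, init_state):
    """Stand-in for one sess.run call: returns (state_series, current_state)."""
    state = init_state
    series = []
    for x in inputs:
        out, state = toy_cell(x, state)
        series.append(out)
    return series, state


sequence = [0.1, -0.4, 0.7, 0.2, -0.9, 0.3]

# One pass over the whole sequence.
full_series, _ = run_chunk(sequence, 0.0)

# The same sequence in chunks of two, feeding each final state into the next call.
chunked_series = []
init_state = 0.0
for start in range(0, len(sequence), 2):
    series, current_state = run_chunk(sequence[start:start + 2], init_state)
    chunked_series.extend(series)
    init_state = current_state
```

The chunked run reproduces the outputs of the single pass, so the forward computation is unchanged. The sketch says nothing about gradients: a state that is fed back in as a plain value does not connect one call to the previous one for backpropagation.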
Menell answered 23/3, 2018 at 9:59 Comment(1)
These code snippets do not directly solve the OP's question. I think what you are trying to say is that you can do this iteratively for every time step in the sequence, i.e. by manually copying the internal state of the RNN from one step to the next. Albeit inefficient, this works at inference time, but not during training. You basically unroll the RNN manually, so you lose all gradients except the one at the last time step, and thus backprop through time won't give you the desired result. - Reitz
D
0

I think one trick is to use tf.contrib.seq2seq.InferenceHelper, because this helper can simply pass the output on as the input to the next time step, as this issue and this question discuss. Here is my own code (inspired by this question) that works:

"""
construct Decoder
"""
cell = tf.contrib.rnn.LSTMCell(rnn_size, initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2))

# use a start token during both training and inference
start_tokens = tf.tile(tf.constant([START_ARRAY], dtype=tf.float32), [BATCH_SIZE, 1], name='start_tokens')

# training decoder
with tf.variable_scope("decoder"):
    # construct a helper that passes the output on to the next time step
    training_helper = tf.contrib.seq2seq.InferenceHelper(
        sample_fn=lambda outputs: outputs,
        sample_shape=[decoder_hidden_units],
        sample_dtype=tf.float32,
        start_inputs=start_tokens,
        end_fn=lambda sample_ids: False)

    training_decoder = tf.contrib.seq2seq.BasicDecoder(cell, training_helper,
                                                       initial_state=cell.zero_state(dtype=tf.float32,
                                                                                     batch_size=[BATCH_SIZE]).
                                                       clone(cell_state=encoder_state))

    training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(training_decoder,
                                                                      impute_finished=True,
                                                                      maximum_iterations=max_iters)

The prediction version of the decoder is identical to this training decoder, so you can run inference with it directly.
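
The roles of the helper arguments can be illustrated with a framework-free decode loop. This is a rough emulation under stated assumptions, not the tf.contrib.seq2seq implementation: toy_cell and decode are hypothetical, and the loop only mirrors the idea that sample_fn maps the cell output to the next input, end_fn decides when to stop, and maximum_iterations caps the number of steps.

```python
import math


def toy_cell(x, h):
    """Hypothetical stand-in for cell(input, state): returns (output, new_state)."""
    new_h = math.tanh(0.5 * x + 0.8 * h)
    return new_h, new_h


def decode(cell, start_input, initial_state, sample_fn, end_fn, maximum_iterations):
    """Feeds each sampled output back in as the next input until end_fn fires."""
    outputs = []
    cell_input, state = start_input, initial_state
    for _ in range(maximum_iterations):
        output, state = cell(cell_input, state)
        sample = sample_fn(output)
        outputs.append(output)
        if end_fn(sample):
            break
        cell_input = sample
    return outputs


# Identity sampling and an end_fn that never fires, as in the decoder above:
# decoding then runs for exactly maximum_iterations steps.
decoded = decode(toy_cell, 1.0, 0.0,
                 sample_fn=lambda output: output,
                 end_fn=lambda sample: False,
                 maximum_iterations=4)
```

With an end_fn that always returns False, the length of the decoded sequence is controlled entirely by maximum_iterations, which is why the snippet above passes max_iters to dynamic_decode.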

Depository answered 11/3, 2019 at 13:1 Comment(0)
T
0

Maybe not the fastest way, but you could also use model.train_on_batch for training and predict_on_batch for prediction. Save the prediction for each batch and feed it back in as input. If your batch size is 1, you can feed y(t-1) back directly. You just have to loop through your dataset.

Twig answered 30/6, 2020 at 9:37 Comment(0)
