How to extract the cell state from an LSTM at each timestep in Keras?

Is there a way in Keras to retrieve the cell state (i.e., the c vector) of an LSTM layer at every timestep of a given input?

It seems the return_state argument returns the last cell state after the computation is done, but I also need the intermediate ones. Also, I don't want to pass these cell states to the next layer; I only want to be able to access them.

Preferably using TensorFlow as the backend.

Thanks

Beaux answered 27/8, 2018 at 3:9 Comment(3)
Did you manage to find a solution for this? I am looking at the exact same problem at the moment.Brainard
I did not find an easy and intuitive way to do that. However, if you create a model with the LSTM layer as the only layer (just copying the weights) and set return_state=True, you can get the last cell state produced by the sequence. So you can process the sequence only up to a given timestep to get the cell state produced at that timestep.Beaux
For instance, if your sequence originally has 100 timesteps but you want to know the cell state after timestep 40, you just remove the last 60 timesteps and run the shortened sequence through the layer (see the sketch just below these comments). It's a very lame solution, but the only one I think can work. I didn't try it, though, because I changed my approach in the project I was working on.Beaux
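A minimal sketch of that truncation workaround (my own code, not from the comments above; the names are hypothetical, and the weight copy assumes you already have a trained LSTM layer to copy from):

import tensorflow as tf

num_units, num_features = 32, 8

# Standalone LSTM layer used only to probe intermediate cell states
probe_lstm = tf.keras.layers.LSTM(num_units, return_state=True)

# Optionally copy the weights of the layer you actually trained:
# probe_lstm.build((None, None, num_features))
# probe_lstm.set_weights(trained_lstm_layer.get_weights())

x = tf.random.uniform([1, 100, num_features])  # full sequence: 100 timesteps

t = 40
_, h_t, c_t = probe_lstm(x[:, :t, :])  # c_t is the cell state after timestep t
print(c_t.shape)  # (1, 32)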

I was looking for a solution to this issue, and after reading the guidance on creating your own custom RNN cell in tf.keras (https://www.tensorflow.org/api_docs/python/tf/keras/layers/AbstractRNNCell), I believe the following is the most concise and readable way of doing this in TensorFlow 2:

import tensorflow as tf
from tensorflow.keras.layers import LSTMCell

class LSTMCellReturnCellState(LSTMCell):

    def call(self, inputs, states, training=None):
        # Keep only the first `units` input features (the h part, in case the
        # input is a concatenated [h, c] coming from a stacked copy of this cell).
        real_inputs = inputs[:, :self.units]  # decouple [h, c]
        outputs, [h, c] = super().call(real_inputs, states, training=training)
        # Emit [h, c] concatenated as the per-timestep output; keep [h, c] as the state.
        return tf.concat([h, c], axis=1), [h, c]



num_units = 512
test_input = tf.random.uniform([5,100,num_units])

rnn = tf.keras.layers.RNN(LSTMCellReturnCellState(num_units),
                          return_sequences=True, return_state=True)

whole_seq_output, final_memory_state, final_carry_state = rnn(test_input)

print(whole_seq_output.shape)
>>> (5, 100, 1024)

# Hidden state sequence
h_seq = whole_seq_output[:,:,:num_units] # (5,100,512)

# Cell state sequence
c_seq = whole_seq_output[:,:,num_units:] # (5,100,512)

As mentioned in the other answers, the advantage of this approach is that the cell can easily be wrapped in tf.keras.layers.RNN as a drop-in replacement for the normal LSTMCell.
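For instance, here is a short sketch (my own addition, not from the answer or the linked notebook) of using the cell inside a functional-API model, with Lambda layers splitting the per-timestep h and c back apart; it reuses num_units and LSTMCellReturnCellState from the code above:

inp = tf.keras.Input(shape=(100, num_units))
seq = tf.keras.layers.RNN(LSTMCellReturnCellState(num_units),
                          return_sequences=True)(inp)               # (batch, 100, 2*num_units)
h_seq = tf.keras.layers.Lambda(lambda t: t[:, :, :num_units])(seq)  # hidden states
c_seq = tf.keras.layers.Lambda(lambda t: t[:, :, num_units:])(seq)  # cell states
model = tf.keras.Model(inp, [h_seq, c_seq])
model.summary()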

Here is a Colab Notebook with the code running as expected for tensorflow==2.6.0

Contracture answered 28/9, 2021 at 15:29 Comment(0)

I know it's pretty late, but I hope this helps.

What you are asking is technically possible by modifying the LSTM cell's call method. The modified cell below returns a 4-dimensional output instead of a 3-dimensional one when you use it with return_sequences=True.

Code

import numpy as np
import tensorflow as tf
import keras
from keras import backend as K
from keras.layers import Input, RNN
from keras.models import Model
from keras.layers.recurrent import LSTMCell, _generate_dropout_mask
class Mod_LSTMCELL(LSTMCell):
    def call(self, inputs, states, training=None):
        if 0 < self.dropout < 1 and self._dropout_mask is None:
            self._dropout_mask = _generate_dropout_mask(
                K.ones_like(inputs),
                self.dropout,
                training=training,
                count=4)
        if (0 < self.recurrent_dropout < 1 and
                self._recurrent_dropout_mask is None):
            self._recurrent_dropout_mask = _generate_dropout_mask(
                K.ones_like(states[0]),
                self.recurrent_dropout,
                training=training,
                count=4)

        # dropout matrices for input units
        dp_mask = self._dropout_mask
        # dropout matrices for recurrent units
        rec_dp_mask = self._recurrent_dropout_mask

        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state

        if self.implementation == 1:
            if 0 < self.dropout < 1.:
                inputs_i = inputs * dp_mask[0]
                inputs_f = inputs * dp_mask[1]
                inputs_c = inputs * dp_mask[2]
                inputs_o = inputs * dp_mask[3]
            else:
                inputs_i = inputs
                inputs_f = inputs
                inputs_c = inputs
                inputs_o = inputs
            x_i = K.dot(inputs_i, self.kernel_i)
            x_f = K.dot(inputs_f, self.kernel_f)
            x_c = K.dot(inputs_c, self.kernel_c)
            x_o = K.dot(inputs_o, self.kernel_o)
            if self.use_bias:
                x_i = K.bias_add(x_i, self.bias_i)
                x_f = K.bias_add(x_f, self.bias_f)
                x_c = K.bias_add(x_c, self.bias_c)
                x_o = K.bias_add(x_o, self.bias_o)

            if 0 < self.recurrent_dropout < 1.:
                h_tm1_i = h_tm1 * rec_dp_mask[0]
                h_tm1_f = h_tm1 * rec_dp_mask[1]
                h_tm1_c = h_tm1 * rec_dp_mask[2]
                h_tm1_o = h_tm1 * rec_dp_mask[3]
            else:
                h_tm1_i = h_tm1
                h_tm1_f = h_tm1
                h_tm1_c = h_tm1
                h_tm1_o = h_tm1
            i = self.recurrent_activation(x_i + K.dot(h_tm1_i,
                                                      self.recurrent_kernel_i))
            f = self.recurrent_activation(x_f + K.dot(h_tm1_f,
                                                      self.recurrent_kernel_f))
            c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c,
                                                            self.recurrent_kernel_c))
            o = self.recurrent_activation(x_o + K.dot(h_tm1_o,
                                                      self.recurrent_kernel_o))
        else:
            if 0. < self.dropout < 1.:
                inputs *= dp_mask[0]
            z = K.dot(inputs, self.kernel)
            if 0. < self.recurrent_dropout < 1.:
                h_tm1 *= rec_dp_mask[0]
            z += K.dot(h_tm1, self.recurrent_kernel)
            if self.use_bias:
                z = K.bias_add(z, self.bias)

            z0 = z[:, :self.units]
            z1 = z[:, self.units: 2 * self.units]
            z2 = z[:, 2 * self.units: 3 * self.units]
            z3 = z[:, 3 * self.units:]

            i = self.recurrent_activation(z0)
            f = self.recurrent_activation(z1)
            c = f * c_tm1 + i * self.activation(z2)
            o = self.recurrent_activation(z3)

        h = o * self.activation(c)
        if 0 < self.dropout + self.recurrent_dropout:
            if training is None:
                h._uses_learning_phase = True
        # Stack h and c into a single per-timestep output (note: concatenating
        # along axis 0 implicitly assumes a batch size of 1, as in the sample below).
        return tf.expand_dims(tf.concat([h, c], axis=0), 0), [h, c]

Sample code

# create a cell
test = Mod_LSTMCELL(100)

# Input timesteps=10, features=7
in1 = Input(shape=(10,7))
out1 = RNN(test, return_sequences=True)(in1)

M = Model(inputs=[in1],outputs=[out1])
M.compile(keras.optimizers.Adam(),loss='mse')

ans = M.predict(np.arange(7*10,dtype=np.float32).reshape(1, 10, 7))

print(ans.shape)  # (1, 10, 2, 100): (batch, timesteps, [h, c], units)
# state_h at the first timestep
print(ans[0, 0, 0, :])
# state_c at the first timestep
print(ans[0, 0, 1, :])

Caltrop answered 7/3, 2020 at 19:51 Comment(7)
This solution is extremely complicated and unnecessary. See my answer for the correct way to do this, which does not require subclassing LSTMCell.Flightless
Unnecessary? I believe your answer does not integrate with the Keras API, whereas my technique just creates a custom cell, so you can use it with the existing API, for example with model.add() in a Sequential model. Moreover, this method also produces a single tensor that contains both h and c, which you can use when building a large model without adding a concat layer every time, resulting in cleaner and easier-to-read code. It is not complicated: you just copy my class and use it as you would a normal LSTM cell with RNN.Caltrop
You wrote a custom class to return the [h, c] states when LSTMCell literally already does this. Why add unnecessary complexity to implement a feature that LSTMCell already has? Moreover, your code doesn't even return the activations. OK, your code can be used with the Sequential API. My code is the 'correct' way to do this and should be used with the functional API or a subclassed tf.keras.Model.Flightless
I believe you haven't looked at the Keras source; my implementation is based on the Keras code and is correct. Most of the code stays the same; I simply modified a few lines to also return the c state, so the claim about added complexity is wrong because my code is just the backend of what the actual LSTM does. Also, if your model cannot be built with the Keras API, what is the point of using a high-level API like Keras? We might as well implement everything in pure low-level TensorFlow.Caltrop
As for the complexity of your implementation: it does not use the Keras layer class, which makes the code harder to visualize in the computation graph and to debug, whereas Keras's layer classes help with visualizing the graph. My implementation may look messy, but it is better not only for visualization but also because it integrates seamlessly with the other Keras APIs, while your code is a low-level implementation with little scalability and modularity, which hurts reusability; that, I believe, makes your implementation the complicated one, not mine.Caltrop
As for your point about activations: the activation function is already embedded in the code, which I believe you didn't read all the way through. My implementation uses the same API as LSTMCell; the only difference is that it also returns the c state, as the question asked. If all of this is unnecessary, please open a merge request against the Keras GitHub source to make the LSTM implementation more efficient.Caltrop
Have you tested your code? With your sample code I receive AttributeError: 'Mod_LSTMCELL' object has no attribute '_dropout_mask'.Hinda

First, this is not possible to do with tf.keras.layers.LSTM. You have to use LSTMCell instead, or subclass LSTM. Second, there is no need to subclass LSTMCell to get the sequence of cell states: LSTMCell already returns a list of the hidden state (h) and cell state (c) every time you call it. For those not familiar with LSTMCell, it takes in the current [h, c] tensors and the input at the current timestep (it cannot take in a sequence of timesteps) and returns the activations and the updated [h, c]. Here is an example showing how to use LSTMCell to process a sequence of timesteps and return the accumulated cell states.

import numpy as np
import tensorflow as tf

# example inputs
inputs = tf.convert_to_tensor(np.random.rand(3, 4), dtype='float32')  # 3 timesteps, 4 features
h_c = [tf.zeros((1, 2)), tf.zeros((1, 2))]  # must initialize hidden/cell state for lstm cell
h_c = tf.convert_to_tensor(h_c, dtype='float32')
lstm = tf.keras.layers.LSTMCell(2)

# example of how you accumulate cell state over repeated calls to LSTMCell
inputs = tf.unstack(inputs, axis=0)
c_states = []
for cur_inputs in inputs:
    out, h_c = lstm(tf.expand_dims(cur_inputs, axis=0), h_c)
    h, c = h_c
    c_states.append(c)
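
If you want the whole cell-state sequence as a single tensor afterwards, one way (my addition, not part of the original answer) is to stack the per-timestep states along a time axis:

c_seq = tf.stack(c_states, axis=1)  # (batch, timesteps, units) == (1, 3, 2)
print(c_seq.shape)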
Flightless answered 19/9, 2020 at 17:41 Comment(0)

You can access the states of any RNN by setting return_sequences=True in the initializer. You can find more information about this parameter here.
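
For reference, a minimal sketch (my own, not from this answer) of what these flags actually return: return_sequences=True yields the hidden state h at every timestep, while return_state=True only adds the final h and c, so the intermediate cell states are still not exposed:

import tensorflow as tf

x = tf.random.uniform([2, 10, 4])  # (batch, timesteps, features)
lstm = tf.keras.layers.LSTM(8, return_sequences=True, return_state=True)
h_seq, h_last, c_last = lstm(x)
print(h_seq.shape)   # (2, 10, 8) -- hidden state h at every timestep
print(h_last.shape)  # (2, 8)     -- final hidden state
print(c_last.shape)  # (2, 8)     -- final cell state only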

Stanleystanly answered 27/8, 2018 at 5:8 Comment(1)
From what I understood, return_sequences=True returns all the hidden states (i.e., the h vectors). What I want to access, though, is the cell states (the c vectors) of the LSTM.Beaux
