Define custom LSTM Cell in Keras?

I use Keras with TensorFlow as the backend. How can I modify an LSTM cell, for example by "removing" the output gate? It is a multiplicative gate, so I assume I would have to fix it at a constant value so that whatever it multiplies is left unchanged.

Brachio answered 17/1, 2019 at 8:3 Comment(0)

First, define your own custom layer. For a sense of how to implement a cell, see the LSTMCell source in the Keras repository. For example, a minimal custom cell could look like this:

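from tensorflow import keras
from tensorflow.keras import backend as K  # K.dot is part of the TF 2.x Keras backend


class MinimalRNNCell(keras.layers.Layer):

    def __init__(self, units, **kwargs):
        # Initialize the base Layer before setting attributes on it.
        super(MinimalRNNCell, self).__init__(**kwargs)
        self.units = units
        self.state_size = units  # size of the state carried between time steps

    def build(self, input_shape):
        # input_shape describes a single time step: (batch, features).
        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                      initializer='uniform',
                                      name='kernel')
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer='uniform',
            name='recurrent_kernel')
        self.built = True

    def call(self, inputs, states):
        # states holds one tensor per entry in state_size.
        prev_output = states[0]
        h = K.dot(inputs, self.kernel)
        output = h + K.dot(prev_output, self.recurrent_kernel)
        # Return the step output and the list of new states.
        return output, [output]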

Then wrap the cell in tf.keras.layers.RNN, which runs it over every time step of a sequence:

cell = MinimalRNNCell(32)
x = keras.Input((None, 5))  # (time steps, features); the batch axis is implicit
layer = keras.layers.RNN(cell)
y = layer(x)

# Here's how to use the cell to build a stacked RNN:

cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
x = keras.Input((None, 5))
layer = keras.layers.RNN(cells)  # a list of cells is stacked in order
y = layer(x)
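
To connect this back to the question: the same pattern can be extended to an LSTM-style cell in which the output gate is simply left out. The following is only a sketch under my own assumptions, not code from the Keras repository. The class name NoOutputGateLSTMCell, the fused three-block weight layout, and the choice of initializers are all hypothetical. It uses plain tf ops instead of K so that it does not depend on the backend module.

```python
import numpy as np
import tensorflow as tf


class NoOutputGateLSTMCell(tf.keras.layers.Layer):
    """LSTM-style cell with input and forget gates but no output gate.

    Hypothetical illustration; this class is not part of Keras.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # Two states per sample: hidden state h and cell state c.
        self.state_size = [units, units]
        self.output_size = units

    def build(self, input_shape):
        # Each weight holds three blocks: input gate i, forget gate f,
        # and candidate g. There is no block for an output gate.
        self.kernel = self.add_weight(
            shape=(input_shape[-1], 3 * self.units),
            initializer="glorot_uniform", name="kernel")
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, 3 * self.units),
            initializer="orthogonal", name="recurrent_kernel")
        self.bias = self.add_weight(
            shape=(3 * self.units,), initializer="zeros", name="bias")
        self.built = True

    def call(self, inputs, states):
        h_prev, c_prev = states
        z = (tf.matmul(inputs, self.kernel)
             + tf.matmul(h_prev, self.recurrent_kernel)
             + self.bias)
        i, f, g = tf.split(z, 3, axis=-1)
        c = tf.sigmoid(f) * c_prev + tf.sigmoid(i) * tf.tanh(g)
        # A standard LSTM computes h = sigmoid(o) * tanh(c).
        # Dropping the gate is the same as fixing it at 1.
        h = tf.tanh(c)
        return h, [h, c]


cell = NoOutputGateLSTMCell(8)
layer = tf.keras.layers.RNN(cell)
x = np.random.rand(4, 10, 5).astype("float32")  # (batch, time, features)
y = layer(x)  # last hidden state, shape (4, 8)
```

Because the gate is multiplicative, omitting it has the same effect as holding it at a constant 1, and the cell ends up with three weight blocks instead of four. If you need to keep the stock LSTM weights and only neutralize the gate, this sketch does not cover that case.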
Gasholder answered 17/1, 2019 at 8:35 Comment(3)
Thank you. How does this manual approach differ from managing LSTM state with 'stateful' and the other parameters? Brachio
It would be good to show at least the method arguments: build(self, batch_input_shape) and call(self, inputs, states), plus optionally training and mask, and what these methods return: nothing for build(), and the output and the new states for call(). Unassailable
Hi, I am trying to replicate this code for multiple inputs. I changed the call to y = layer([input_1, input_2]) and also changed the shape in input_shape, but it throws the errors mentioned in #58408606. Any idea how to overcome them? Ziegler
