How to add Attention layer between two LSTM layers in Keras

I am trying to add an attention layer between the encoder LSTM (many-to-many) and the decoder LSTM (many-to-one).

However, my code seems to build the attention layer for only one decoder LSTM input.

How can I apply the attention layer to all the inputs of the decoder LSTM, so that the output of the attention layer has shape (None, 1440, 984)?

This is the summary of the attention part of my model.

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, 1440, 5)      0
__________________________________________________________________________________________________
bidirectional_1 (Bidirectional) (None, 1440, 984)    1960128     input_1[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 1440, 1)      985         bidirectional_1[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 1440)         0           dense_1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 1440)         0           flatten_1[0][0]
__________________________________________________________________________________________________
repeat_vector_1 (RepeatVector)  (None, 984, 1440)    0           activation_1[0][0]
__________________________________________________________________________________________________
permute_1 (Permute)             (None, 1440, 984)    0           repeat_vector_1[0][0]
__________________________________________________________________________________________________
multiply_1 (Multiply)           (None, 1440, 984)    0           bidirectional_1[0][0]
                                                                 permute_1[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 984)          0           multiply_1[0][0]
==================================================================================================
Total params: 1,961,113
Trainable params: 1,961,113
Non-trainable params: 0
__________________________________________________________________________________________________

Here is my code:

_input = Input(shape=(self.x_seq_len, self.input_x_shape), dtype='float32')
activations = Bidirectional(LSTM(self.hyper_param['decoder_units'], return_sequences=True), input_shape=(self.x_seq_len, self.input_x_shape,))(_input)

# compute importance for each step
attention = Dense(1, activation='tanh')(activations) 
attention = Flatten()(attention)
attention = Activation('softmax')(attention) 
attention = RepeatVector(self.hyper_param['decoder_units']*2)(attention)
attention = Permute([2, 1])(attention)

sent_representation = Multiply()([activations, attention])
sent_representation = Lambda(lambda xin: K.sum(xin, axis=-2), output_shape=(self.hyper_param['decoder_units']*2,))(sent_representation)

attn = Model(inputs=_input, outputs=sent_representation)
model.add(attn)
# decoder
model.add(LSTM(self.hyper_param['encoder_units'], return_sequences=False, input_shape=(None, self.hyper_param['decoder_units'] * 2 )))
Hiss asked 23/12, 2018 at 8:48

Attention iteratively takes one decoder output (the last hidden state) as the 'query' and uses it to 'attend' to all the 'values', which are simply the entire list of encoder outputs.

So input1 = decoder hidden state of the previous timestep: the 'query'

input2 = all encoder hidden states: the 'values'

output = the context: weighted sum of all the encoder hidden states
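
As a rough illustration of the three items above, here is a minimal NumPy sketch of a single attention step. It assumes plain dot-product scoring between the query and each encoder state, which is only one of several scoring choices (the code in the question scores each step with Dense(1, tanh) instead). The names attention_context, encoder_states and decoder_state, and the toy sizes, are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)

def attention_context(query, values):
    """Dot-product attention for one decoder step (illustrative only).

    query:  (hidden,)        previous decoder hidden state
    values: (steps, hidden)  all encoder hidden states
    """
    scores = values @ query      # (steps,) one score per encoder step
    weights = softmax(scores)    # (steps,) non-negative, sums to 1
    context = weights @ values   # (hidden,) weighted sum of encoder states
    return context, weights

rng = np.random.default_rng(0)
encoder_states = rng.normal(size=(6, 4))  # toy sizes: 6 time steps, 4 units
decoder_state = rng.normal(size=(4,))
context, weights = attention_context(decoder_state, encoder_states)
print(context.shape, weights.shape)  # (4,) (6,)
```

The context has the same size as one encoder state, regardless of how many encoder steps there are.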

Use the context, the previous hidden state of the decoder, and the previously translated output to generate the next word and a new hidden state, then repeat the whole process until 'EOS' is encountered.
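
The loop described above might look roughly like the following NumPy sketch. It is a toy under stated assumptions, not a trained model: greedy_decode, step_fn and toy_step are hypothetical names, the weights are random, and dot-product scoring stands in for whatever scoring function the real model uses.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

def greedy_decode(encoder_states, step_fn, init_state, sos_id, eos_id, max_len=20):
    """Greedy decoding loop with attention.

    step_fn is a hypothetical stand-in for a trained decoder cell. It takes
    (context, previous state, previous token id) and returns
    (vocabulary logits, new state).
    """
    state, token, output = init_state, sos_id, []
    for _ in range(max_len):
        weights = softmax(encoder_states @ state)  # attend with the previous state
        context = weights @ encoder_states         # weighted sum of encoder states
        logits, state = step_fn(context, state, token)
        token = int(np.argmax(logits))             # pick the most likely next token
        if token == eos_id:                        # stop once EOS is produced
            break
        output.append(token)
    return output

# Toy, untrained decoder step with random weights (assumption for illustration).
rng = np.random.default_rng(0)
hidden, vocab = 4, 5
W_state = rng.normal(size=(hidden, 2 * hidden))
W_out = rng.normal(size=(vocab, hidden))
embed = rng.normal(size=(vocab, hidden))

def toy_step(context, state, token):
    new_state = np.tanh(W_state @ np.concatenate([context, state]) + embed[token])
    return W_out @ new_state, new_state

tokens = greedy_decode(rng.normal(size=(6, hidden)), toy_step,
                       np.zeros(hidden), sos_id=0, eos_id=1)
print(tokens)
```

The max_len cap is a safety net so the loop still terminates if 'EOS' is never predicted.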

Your attention logic itself looks correct, apart from the last line involving the decoder, and I see no mistake in how you have defined it. The rest of your code is missing, though. If you can share the complete code, I can help you with the error.
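
To see where the time axis disappears in the summary from the question, the shape flow can be mirrored in NumPy (a sketch, not Keras code; the random arrays are placeholders for real activations and scores). The product corresponds to multiply_1 and still has shape (batch, 1440, 984); only the final sum, which corresponds to lambda_1, collapses it to (batch, 984). If the decoder LSTM is meant to receive one weighted vector per time step, one option, stated here as an assumption since the full model is not shown, is to pass the Multiply output onward instead of the summed vector.

```python
import numpy as np

batch, steps, units = 2, 1440, 984  # units = decoder_units * 2 in the question
activations = np.random.rand(batch, steps, units).astype("float32")

# One softmax weight per time step, as produced by Dense(1) -> Flatten -> softmax.
scores = np.random.rand(batch, steps).astype("float32")
weights = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)

# RepeatVector + Permute + Multiply amounts to broadcasting the weights over units.
weighted = activations * weights[:, :, None]  # like multiply_1
context = weighted.sum(axis=-2)               # like lambda_1

print(weighted.shape, context.shape)  # (2, 1440, 984) (2, 984)
```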

For more specific details, please refer to https://towardsdatascience.com/create-your-own-custom-attention-layer-understand-all-flavours-2201b5e8be9e

Feodor answered 16/11, 2020 at 8:25
