Implementing attention in Keras classification

I would like to add attention to a trained image classification CNN model. For example, there are 30 classes, and with the Keras CNN I obtain the predicted class for each image. However, I also want to visualize the important features/locations behind the predicted result, so I want to add soft attention after the FC layer. I read "Show, Attend and Tell: Neural Image Caption Generation with Visual Attention" hoping to obtain similar results, but I could not understand how the authors implemented it, because my problem is not image captioning or a text seq2seq problem.
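For reference, the soft (deterministic) attention in that paper operates on a set of L annotation vectors a_1, ..., a_L taken from a convolutional feature map (one D-dimensional vector per spatial location; L = 14 x 14 = 196 in the paper), not on a single pooled vector. With the previous LSTM hidden state h_{t-1} and an attention MLP f_att, the weights and the context vector are:

```latex
e_{ti} = f_{\mathrm{att}}(\mathbf{a}_i, \mathbf{h}_{t-1}), \qquad
\alpha_{ti} = \frac{\exp(e_{ti})}{\sum_{k=1}^{L} \exp(e_{tk})}, \qquad
\hat{\mathbf{z}}_t = \sum_{i=1}^{L} \alpha_{ti}\, \mathbf{a}_i
```

Reshaping the weights \alpha_{t1}, ..., \alpha_{tL} back onto the feature-map grid gives the map that is overlaid on the image.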

I have an image classification CNN and would like to extract the features and feed them into an LSTM to visualize soft attention, but I get stuck every time.

The steps I took:

  1. Load CNN model (I already trained the CNN earlier for predictions)
  2. Extract features from a single image (however, the LSTM will check the same image with some removed patches in the image)
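Checking the same image with patches removed, as in step 2, is close to occlusion sensitivity analysis (a different technique from learned soft attention). A minimal plain-Python sketch, assuming the image is a list of pixel rows and that a "removed" patch means zero-filled; the helper name `occluded_copies` and the patch size are hypothetical choices for illustration:

```python
def occluded_copies(image, patch):
    """Yield (top, left, copy) for every non-overlapping patch x patch block.

    `image` is a list of rows of pixel values. In each copy, one block is
    set to 0 (the hypothetical "removed" value); the input is not modified.
    """
    height, width = len(image), len(image[0])
    for top in range(0, height, patch):
        for left in range(0, width, patch):
            copy = [list(row) for row in image]  # copy every row
            for r in range(top, min(top + patch, height)):
                for c in range(left, min(left + patch, width)):
                    copy[r][c] = 0
            yield top, left, copy


# Usage: a 4x4 image of ones with 2x2 patches gives 4 occluded copies.
img = [[1] * 4 for _ in range(4)]
variants = list(occluded_copies(img, 2))
print(len(variants))  # 4
```

Comparing the class score on each copy with the score on the unmodified image indicates which regions the prediction depends on.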

I get stuck at the following steps:

  1. Create LSTM with soft attention
  2. Obtain a single output
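The two steps above can be illustrated without any framework. The plain-Python sketch below assumes a given LSTM hidden state `hidden` and a small set of spatial feature vectors. It scores each vector with an additive (Bahdanau-style) MLP, one common choice for f_att and an assumption here, normalizes the scores with a softmax, forms the weighted context vector, and maps it to a single class with a linear layer. All weights are hypothetical placeholders; in Keras they would be trainable Dense layers.

```python
import math


def matvec(m, v):
    """Multiply matrix m (list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def soft_attention(features, hidden, w_a, w_h, w):
    """score_i = w . tanh(W_a a_i + W_h h); alpha = softmax(scores);
    context = sum_i alpha_i * a_i. Returns (alpha, context)."""
    proj_h = matvec(w_h, hidden)
    scores = []
    for a in features:
        proj_a = matvec(w_a, a)
        scores.append(sum(wk * math.tanh(pa + ph)
                          for wk, pa, ph in zip(w, proj_a, proj_h)))
    alpha = softmax(scores)
    dim = len(features[0])
    context = [sum(alpha[i] * features[i][d] for i in range(len(features)))
               for d in range(dim)]
    return alpha, context


def classify(context, w_out, b_out):
    """Single output: class probabilities computed from the context vector."""
    logits = [z + b for z, b in zip(matvec(w_out, context), b_out)]
    return softmax(logits)


# Toy sizes: L = 4 locations (a 2x2 grid), D = 2, hidden size 2, 3 classes.
features = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, 0.5]]
hidden = [0.5, -0.5]
w_a = [[1.0, 0.0], [0.0, 1.0]]
w_h = [[0.5, 0.0], [0.0, 0.5]]
w = [1.0, 1.0]
w_out = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
b_out = [0.0, 0.0, 0.0]

alpha, context = soft_attention(features, hidden, w_a, w_h, w)
probs = classify(context, w_out, b_out)
predicted_class = probs.index(max(probs))
print("attention weights:", alpha)
print("predicted class:", predicted_class)
```

Here `alpha` holds one weight per location, so it can be reshaped to the 2x2 grid and visualized. In Keras, the multiply-and-sum step would typically be a `Multiply` layer followed by a sum over the location axis.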

I am using Keras with the TensorFlow backend. CNN features are extracted using ResNet50. The images are 224x224, and the layer I extract features from (global_average_pooling2d_3, which I call the FC layer) has an output shape of 2048.
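Something worth checking with these numbers: ResNet50 downsamples its input by a total stride of 32, so a 224x224 image yields a 7x7 grid of 2048-dimensional vectors at the last convolutional block. Global average pooling then collapses that grid into the single 2048-vector, so the pooled vector on its own no longer carries per-location information that an attention map could be drawn over. A plain-Python sketch of the pooling step (the helper names are hypothetical):

```python
def feature_grid_size(image_size, total_stride=32):
    """Spatial side length of ResNet50's last conv feature map."""
    return image_size // total_stride


def global_average_pool(feature_map):
    """Average an H x W grid of D-dim vectors into one D-dim vector."""
    cells = [vec for row in feature_map for vec in row]
    dim = len(cells[0])
    return [sum(vec[d] for vec in cells) / len(cells) for d in range(dim)]


side = feature_grid_size(224)  # 7, i.e. a 7 x 7 grid before pooling

# Two different spatial layouts of the same values (D = 2 for brevity)
# pool to the same vector, so the location information is lost.
layout_a = [[[4.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 4.0]]]
layout_b = [[[0.0, 0.0], [4.0, 0.0]], [[0.0, 4.0], [0.0, 0.0]]]
print(side, global_average_pool(layout_a), global_average_pool(layout_b))
```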

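#Extract CNN features:

from keras.models import Model, load_model

# weight_file, custom_mae and train_gen are defined earlier in my script (not shown)
base_model = load_model(weight_file, custom_objects={'custom_mae': custom_mae})
# Pooled ResNet50 features after the last conv block: output shape (None, 2048)
pool_layer = base_model.get_layer("global_average_pooling2d_3")
cnn_model = Model(inputs=base_model.input, outputs=pool_layer.output)
cnn_model.trainable = False  # keep the already trained CNN weights frozen
bottleneck_features_train_v2 = cnn_model.predict(train_gen.images)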


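#Create LSTM:

from keras.layers import (Input, TimeDistributed, LSTM, Dropout, Dense,
                          Flatten, Activation, RepeatVector, Permute)

units = 2048  # LSTM size, equal to the CNN feature size

seq_input = Input(shape=(1, 224, 224, 3))              # (None, 1, 224, 224, 3): a sequence of one image
encoded_frame = TimeDistributed(cnn_model)(seq_input)  # (None, 1, 2048)
encoded_vid = LSTM(units)(encoded_frame)               # (None, 2048): last hidden state only
lstm = Dropout(0.5)(encoded_vid)

#Add soft attention

attention = Dense(1, activation='tanh')(lstm)  # (None, 1)
attention = Flatten()(attention)               # (None, 1)
attention = Activation('softmax')(attention)   # (None, 1): softmax over a single score
attention = RepeatVector(units)(attention)     # (None, 2048, 1)
attention = Permute([2, 1])(attention)         # (None, 1, 2048)


#output 101 classes
predictions = Dense(101, activation='softmax', name='pred_age')(attention)  # (None, 1, 101)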
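For comparison, the `Dense(1, tanh) -> Flatten -> softmax -> RepeatVector -> Permute` lines match the first half of a soft-attention pattern often seen in Keras examples. In that pattern the LSTM is built with `return_sequences=True`, so there is one output per time step (or location), the softmax runs over those steps, and the permuted weights are then multiplied element-wise with the LSTM outputs (e.g. with a `Multiply` layer) and summed over the step axis before the final Dense. The plain-Python sketch below mirrors those shapes to show what the repeat/permute/multiply/sum sequence computes. It illustrates that pattern under these assumptions and is not a drop-in fix; the helper names are hypothetical.

```python
def repeat_vector(weights, units):
    """Like RepeatVector(units) on a (T,) vector: result has shape (units, T)."""
    return [list(weights) for _ in range(units)]


def permute_2_1(matrix):
    """Like Permute([2, 1]): swap the two non-batch axes, (units, T) -> (T, units)."""
    return [list(col) for col in zip(*matrix)]


def attend(lstm_outputs, weights):
    """Multiply (T, units) outputs by the permuted weights, then sum over T."""
    units = len(lstm_outputs[0])
    w = permute_2_1(repeat_vector(weights, units))  # (T, units)
    weighted = [[o * a for o, a in zip(out_row, w_row)]
                for out_row, w_row in zip(lstm_outputs, w)]
    return [sum(row[u] for row in weighted) for u in range(units)]  # (units,)


# T = 3 time steps, units = 2; the weights already sum to 1 (softmax output).
outputs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
weights = [0.5, 0.25, 0.25]
print(attend(outputs, weights))  # [2.5, 3.5]
```

With a single time step, as with `Input(shape=(1, ...))`, the softmax weight is always 1.0, so this weighted sum returns the LSTM output unchanged.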

What I expect: the image features from the last FC layer are fed into an LSTM with soft attention, the attention weights are trained, and the output gives a class. I would then like to visualize the soft attention to see where the network is looking when it makes the prediction (similar to the soft attention visualizations in the paper).
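To draw an attention map over the 224x224 input, the weights for a 7x7 grid (one per ResNet50 feature location, an assumption about where attention is applied) have to be reshaped to the grid and scaled up by 32 in each direction. A minimal nearest-neighbour version in plain Python; smoother interpolation such as bilinear is a common alternative, and the helper names are hypothetical:

```python
def to_grid(weights, side):
    """Reshape a flat list of side*side attention weights into rows."""
    return [weights[r * side:(r + 1) * side] for r in range(side)]


def upsample_nearest(grid, factor):
    """Nearest-neighbour upsampling of a 2-D list by an integer factor."""
    out = []
    for row in grid:
        wide = [value for value in row for _ in range(factor)]
        out.extend(list(wide) for _ in range(factor))
    return out


# Hypothetical weights: all attention on grid cell (row 1, column 1).
side = 7
weights = [0.0] * (side * side)
weights[1 * side + 1] = 1.0
heatmap = upsample_nearest(to_grid(weights, side), 224 // side)  # 224 x 224
print(len(heatmap), len(heatmap[0]))  # 224 224
```

The resulting `heatmap` can be overlaid on the input image, for example as a semi-transparent colour layer.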

As I am new to attention mechanisms, I did a lot of research but could not find a solution to my problem or understand how to approach it. I would like to know whether I am doing this right.

Colon answered 16/7, 2019 at 14:15
