How do I implement this attention layer in PyTorch?

I have already implemented the CNN part and everything seems to be working just fine. Afterwards I started to implement the LSTM part and, if I understood it right, its output shape should be (batch_size, 256) (because it is bidirectional, with 1 layer and 128 hidden units). I guess that part is alright too.

But what I am trying to figure out is how to implement the attention layer. To my understanding, it is basically a weight tensor that gets multiplied by the LSTM output; a softmax is then applied and the result is fed to the final linear layer. My questions are:

  • Did I understand it right? Is it really that simple?

  • What is the size of the weight tensor? (128), (256), (2, 128) or something else?

  • How do I do the tensor multiplication properly? On my first try I created the weight tensor as a torch Linear with equal values for in_features and out_features (256). After that I applied torch.mul to the input (the LSTM output) and the weights. Is that right?

[Figure: Model architecture (1)]
[Figure: Model architecture (2)]
[Figure: Attention layer info]

Here is the code snippet of the Attention Layer I tried to implement:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention_Layer(nn.Module):
    def __init__(self, n_feats: int) -> None:
        super().__init__()
        # Learned weights: a square linear map over the LSTM output features
        self.w = nn.Linear(
            in_features=n_feats,
            out_features=n_feats
        )
    
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # X: LSTM output of shape (batch_size, n_feats)
        w = self.w(X)
        # Element-wise product of the input and the weights, then softmax
        output = F.softmax(torch.mul(X, w), dim=1)
        return output

And the code snippet of the full model architecture:

# Note: FLB, Extract_LSTM_Output and weight_init are custom helpers defined elsewhere in my code.

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.in_channels = 5
        self.linear_input_features = 1103872
        
        self.cnn = nn.Sequential(
            FLB(
                input_channels=self.in_channels,
                output_channels=64,
                kernel_size=(2, 2)
            ),
            nn.MaxPool2d(kernel_size=(2, 2)),
            FLB(
                input_channels=64,
                output_channels=128,
                kernel_size=(2, 2)
            ),
            FLB(
                input_channels=128,
                output_channels=256,
                kernel_size=(2, 2)
            ),
            nn.MaxPool2d(kernel_size=(2, 2)),
            FLB(
                input_channels=256,
                output_channels=512,
                kernel_size=(2, 2)
            ),
            FLB(
                input_channels=512,
                output_channels=512,
                kernel_size=(2, 2)
            ),
            nn.Flatten(),
            nn.Linear(
                in_features=self.linear_input_features,
                out_features=128
            )
        )
        
        self.lstm = nn.Sequential(
            nn.LSTM(
                input_size=128,
                hidden_size=128,
                num_layers=1,
                batch_first=True,
                bidirectional=True
            ),
            Extract_LSTM_Output()
        )
        
        self.model = nn.Sequential(
            self.cnn,
            self.lstm,
            Attention_Layer(256),
            nn.Linear(
                in_features=256,
                out_features=3
            )
        )
        self.model.apply(weight_init)
    
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return self.model(X)
Fixate answered 9/7, 2023 at 17:07

Your understanding of the attention mechanism is on the right track. In the context of your LSTM model, the attention layer is indeed about assigning weights to the LSTM output before feeding it into the final linear layer. The weights in this context are learned parameters that help the model focus on more relevant parts of the input sequence.
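
In its simplest form, that weighting can be more or less what you describe: a learned scoring layer over the LSTM outputs followed by a softmax. Here is a minimal sketch of that idea (my own illustration, not the exact layer from your figures, and it assumes the LSTM output keeps its sequence dimension, i.e. shape (batch_size, seq_len, 256)):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleAttention(nn.Module):
    def __init__(self, feature_size: int) -> None:
        super().__init__()
        # One scalar score per timestep feature vector
        self.score = nn.Linear(feature_size, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, seq_len, feature_size)
        weights = F.softmax(self.score(x), dim=1)   # (batch_size, seq_len, 1)
        # Attention-weighted sum over timesteps -> (batch_size, feature_size)
        return (weights * x).sum(dim=1)

The pooled (batch_size, feature_size) vector can then go straight into your final nn.Linear(256, 3). Note that torch.mul by itself is only an element-wise product; what makes this an attention layer is taking the softmax over the timestep dimension and using it to weight the sequence.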

Regarding the implementation of your attention layer, I've noticed a few aspects that might need adjustment. The attention mechanism typically involves a query-key-value framework, even in self-attention scenarios where these are derived from the same source. Here's a revised version of the attention layer using PyTorch, tailored for self-attention:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttentionLayer(nn.Module):
    def __init__(self, feature_size):
        super(SelfAttentionLayer, self).__init__()
        self.feature_size = feature_size

        # Linear transformations for Q, K, V from the same source
        self.key = nn.Linear(feature_size, feature_size)
        self.query = nn.Linear(feature_size, feature_size)
        self.value = nn.Linear(feature_size, feature_size)

    def forward(self, x, mask=None):
        # Apply linear transformations
        keys = self.key(x)
        queries = self.query(x)
        values = self.value(x)

        # Scaled dot-product attention
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.feature_size, dtype=torch.float32))

        # Apply mask (if provided)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # Apply softmax
        attention_weights = F.softmax(scores, dim=-1)

        # Multiply weights with values
        output = torch.matmul(attention_weights, values)

        return output, attention_weights

This implementation provides a more standard approach to self-attention, which may enhance your model's ability to focus on relevant features within the LSTM output. The feature_size should match the LSTM's output feature dimension, which is 256 in your case (128 hidden units × 2 directions). Keep in mind that this layer attends across timesteps, so it expects the full sequence output of shape (batch_size, seq_len, 256) rather than only a single (batch_size, 256) vector.
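
For reference, here is a usage sketch of how the layer could slot in before your final classifier (the shapes, the dummy tensor and the mean-pooling step are my own assumptions, not the only option):

# Hypothetical wiring; assumes the LSTM stage returns its full sequence output.
attention = SelfAttentionLayer(feature_size=256)
classifier = nn.Linear(in_features=256, out_features=3)

lstm_out = torch.randn(8, 10, 256)            # stand-in for (batch, seq_len, features)
attended, attn_weights = attention(lstm_out)  # attended: (8, 10, 256)
pooled = attended.mean(dim=1)                 # collapse timesteps -> (8, 256)
logits = classifier(pooled)                   # (8, 3)

Because forward returns a tuple (output, attention_weights), the layer cannot be dropped directly into nn.Sequential the way your current model is wired; either unpack the tuple in a custom forward or wrap the layer so that only the output tensor is passed on.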

Brotherhood answered 12/11, 2023 at 14:30
