Inputs to the nn.MultiheadAttention?

I have n vectors that need to be influenced by each other and to produce n output vectors with the same dimensionality d. I believe this is what torch.nn.MultiheadAttention does, but its forward function expects query, key and value as inputs. According to this blog, I need to initialize a random weight matrix of shape (d x d) for each of q, k and v, multiply each of my vectors by these weight matrices, and get three (n x d) matrices. Are the q, k and v expected by torch.nn.MultiheadAttention just these three matrices, or am I mistaken?
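
For concreteness, a sketch of the setup described above (sizes and variable names are just illustrative):

import torch

n, d = 10, 64                          # n vectors of dimensionality d (illustrative sizes)
x = torch.randn(n, d)                  # the n input vectors, stacked into an (n x d) tensor

# one random (d x d) weight matrix per projection, as the blog describes
w_q, w_k, w_v = (torch.randn(d, d) for _ in range(3))
q, k, v = x @ w_q, x @ w_k, x @ w_v    # three (n x d) matrices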

Regalado answered 9/1, 2021 at 12:51 Comment(0)

When you want to use self-attention, just pass your input tensor to torch.nn.MultiheadAttention as the query, the key and the value:


attention = torch.nn.MultiheadAttention(embed_dim, num_heads)  # embed_dim is your d; num_heads must divide it

x, _ = attention(x, x, x)  # self-attention: the same tensor is used as query, key and value

The PyTorch module returns the output states (same shape as the input) and the attention weights used in the process.
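
For example, a minimal runnable sketch (sizes are illustrative; by default the module expects inputs shaped (seq_len, batch, embed_dim), unless you construct it with batch_first=True in recent PyTorch versions):

import torch

embed_dim, num_heads = 64, 4
attention = torch.nn.MultiheadAttention(embed_dim, num_heads)

x = torch.randn(10, 2, embed_dim)      # (seq_len, batch, embed_dim)
out, weights = attention(x, x, x)

print(out.shape)                       # torch.Size([10, 2, 64]) -- same shape as the input
print(weights.shape)                   # torch.Size([2, 10, 10]) -- attention weights, averaged over heads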

Sunstroke answered 9/1, 2021 at 16:34 Comment(6)
Are the inputs x supposed to be sequences of token ids or embeddings? – Poston
x is the embeddings – Amye
This also assumes the k, q, v dimensions are the same – Vaporizer
Slightly cleaner: x, _ = attention(x, x, x, need_weights=False) – Reward
I don't think this is correct. The MultiheadAttention module only has an out_proj layer within it; it doesn't create the Q, K, V matrices implicitly. So passing x 3 times makes Q, K, V all just x. This will run, and probably train/learn well enough, but it's not the correct way of doing self-attention. What makes self-attention different from cross-attention is that in self-attention you use x to generate Q, K, V, whereas in cross-attention you have x, y and use x to generate Q, and y to generate K, V. – Roofer
@SergeyBokhnyak Three years later I'm also not very confident in this answer. So x should be passed through three separate linear layers first in order to generate q, k and v, and then those are passed into the attention function, right? – Sunstroke
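
For reference on the thread above: when embed_dim, kdim and vdim are all equal (the default), the module does register its own input projection parameters in addition to out_proj; a quick inspection sketch (parameter names as in recent PyTorch releases):

import torch

mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=4)
for name, p in mha.named_parameters():
    print(name, tuple(p.shape))
# in_proj_weight (192, 64)   <- the stacked Q, K, V projections: (3 * embed_dim, embed_dim)
# in_proj_bias (192,)
# out_proj.weight (64, 64)
# out_proj.bias (64,)

So passing x as query, key and value still means x is multiplied by learned Q, K and V projection matrices inside the module.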
