Why must embed_dim be divisible by num_heads in MultiheadAttention?
I am learning the Transformer. Here is the PyTorch documentation for MultiheadAttention. In their implementation, I saw there is a constraint:

 assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

Why require the constraint that embed_dim must be divisible by num_heads? If we go back to the equation

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^O \quad \text{where} \quad \mathrm{head}_i = \mathrm{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$$

Assume Q, K, V are n x embed_dim matrices, and all the weight matrices W_i are embed_dim x head_dim.

Then the concatenation [head_1, ..., head_h] will be an n x (num_heads*head_dim) matrix;

W^O has size (num_heads*head_dim) x embed_dim;

[head_1, ..., head_h] * W^O will be an n x embed_dim output.

I don't see why embed_dim must be divisible by num_heads.

Let's say we have num_heads=10000; the result is the same, since the matrix-matrix product will absorb this information.
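
To make the shape argument concrete, here is a minimal sketch with made-up dimensions (plain matrix products, not the actual PyTorch module), showing the output still comes back to n x embed_dim even when head_dim is not embed_dim // num_heads:

import torch

n, embed_dim, num_heads, head_dim = 5, 9, 2, 144  # head_dim chosen freely, not embed_dim // num_heads
Q = K = V = torch.rand(n, embed_dim)

W_q = [torch.rand(embed_dim, head_dim) for _ in range(num_heads)]  # per-head projections
W_k = [torch.rand(embed_dim, head_dim) for _ in range(num_heads)]
W_v = [torch.rand(embed_dim, head_dim) for _ in range(num_heads)]
W_o = torch.rand(num_heads * head_dim, embed_dim)  # projects the concat back to embed_dim

heads = []
for i in range(num_heads):
    q, k, v = Q @ W_q[i], K @ W_k[i], V @ W_v[i]              # each n x head_dim
    attn = torch.softmax(q @ k.T / head_dim ** 0.5, dim=-1)   # n x n attention weights
    heads.append(attn @ v)                                    # n x head_dim

out = torch.cat(heads, dim=-1) @ W_o  # n x (num_heads*head_dim) times (num_heads*head_dim) x embed_dim
print(out.shape)  # torch.Size([5, 9])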

Yardage answered 26/2, 2021 at 16:45

From what I understand, it is a simplification they have added to keep things simple. Theoretically, we can implement the model the way you propose (similar to the original paper). The PyTorch documentation mentions it briefly:

Note that `embed_dim` will be split across `num_heads` (i.e. each head will have dimension `embed_dim` // `num_heads`)
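
A quick way to see that split in code; this is a small sketch, and the exact error type may differ across PyTorch versions:

import torch

# embed_dim = 512 split across 8 heads -> each head works on 512 // 8 = 64 features
mha = torch.nn.MultiheadAttention(embed_dim=512, num_heads=8)
print(mha.head_dim)  # 64

# a num_heads value that does not divide embed_dim trips the assert from the question
try:
    torch.nn.MultiheadAttention(embed_dim=512, num_heads=7)
except AssertionError as err:
    print(err)  # embed_dim must be divisible by num_heads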

Also, if you look at the PyTorch implementation, you can see it is a bit different (optimised, in my view) compared to the originally proposed model. For example, they use MatMul instead of Linear, and the Concat layer is dropped. Refer to the graph below, which shows the first encoder (with batch size 32, 10 words, 512 features).

[image: ONNX graph of the first encoder layer]

P.S.: If you need to see the model graph and params (like the image above), this is the code I used.

import torch

# a one-layer encoder/decoder Transformer; change params as necessary
transformer_model = torch.nn.Transformer(d_model=512, nhead=8, num_encoder_layers=1, num_decoder_layers=1, dim_feedforward=11)
src = torch.rand((11, 32, 512))  # (source length, batch size, d_model)
tgt = torch.rand((20, 32, 512))  # (target length, batch size, d_model)

# export the traced graph so it can be visualised
torch.onnx.export(transformer_model, (src, tgt), "transformer_model.onnx")
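The exported transformer_model.onnx file can then be opened in a graph viewer such as Netron to inspect the layers and tensor shapes.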
Blen answered 3/1, 2023 at 12:39

When you have a sequence of seq_len x emb_dim (i.e. 20 x 8) and you want to use num_heads=2, the sequence is split along the emb_dim dimension, so you get two 20 x 4 sequences. Every head should have the same shape, and if emb_dim isn't divisible by num_heads this won't work. Take for example a sequence of 20 x 9 and again num_heads=2: you would get 20 x 4 and 20 x 5, which are not the same dimension.
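
A minimal sketch of that split, assuming the same toy shapes (a plain reshape rather than the real module):

import torch

seq_len, emb_dim, num_heads = 20, 8, 2
x = torch.rand(seq_len, emb_dim)

# split along the embedding dimension: each head gets emb_dim // num_heads features
heads = x.reshape(seq_len, num_heads, emb_dim // num_heads).transpose(0, 1)
print(heads.shape)  # torch.Size([2, 20, 4]) -> two 20 x 4 sequences

# with emb_dim = 9 the same split breaks: 20 * 2 * 4 != 20 * 9
x_bad = torch.rand(seq_len, 9)
# x_bad.reshape(seq_len, num_heads, 9 // num_heads) raises a RuntimeError,
# which is why the implementation asserts divisibility up front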

Briefs answered 27/2, 2021 at 9:49 Comment(1)
Good example, but if seq_len x emb_dim is 20 x 9 and num_heads=2, let's choose head_dim=144; then each head_i is a 20 x 144 matrix, so [head_1, head_2] is 20 x 288. We can still choose W^O to be 288 x 9 and still get the final 20 x 9 output. My point is that we can map emb_dim to any per-head length and use W^O to project it back to emb_dim. Why do we need to divide emb_dim into equal parts? Thanks. Yardage
