What is the difference between attn_mask and key_padding_mask in MultiheadAttention?
What is the difference between attn_mask and key_padding_mask in MultiheadAttention of PyTorch?

key_padding_mask – if provided, specified padding elements in the key will be ignored by the attention. When given a binary mask and a value is True, the corresponding value on the attention layer will be ignored. When given a byte mask and a value is non-zero, the corresponding value on the attention layer will be ignored

attn_mask – 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch.

Thanks in advance.

Fumble answered 29/6, 2020 at 0:31 Comment(0)
The key_padding_mask is used to mask out positions that are padding, i.e., positions after the end of the input sequence. It is always specific to the input batch and depends on how long the sequences in the batch are compared to the longest one. It is a 2D tensor of shape batch size × input length.
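For illustration, here is a minimal sketch (the lengths are made up) of how such a padding mask is typically built from the sequence lengths in a batch:

import torch

# a batch of 3 sequences padded to the length of the longest one (5)
lengths = torch.tensor([5, 3, 2])   # true length of each sequence
max_len = int(lengths.max())        # 5

# key_padding_mask: shape (batch_size, src_len), True marks padding positions
key_padding_mask = torch.arange(max_len)[None, :] >= lengths[:, None]
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True],
#         [False, False,  True,  True,  True]])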

On the other hand, attn_mask says which key-value pairs are valid. In a Transformer decoder, a triangular mask is used to simulate inference time and prevent attending to the "future" positions. This is what attn_mask is usually used for. If it is a 2D tensor, its shape is input length × input length. You can also have a mask that is specific to every item in a batch; in that case, you can use a 3D tensor of shape (batch size × num heads) × input length × input length. (So, in theory, you can simulate key_padding_mask with a 3D attn_mask.)
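A short runnable sketch of both uses, assuming the default (seq_len, batch, embed_dim) layout of nn.MultiheadAttention; the sizes are made up, and the last part simulates the padding mask with a 3D attn_mask as described above:

import torch
import torch.nn as nn

seq_len, batch_size, embed_dim, num_heads = 4, 2, 8, 2
mha = nn.MultiheadAttention(embed_dim, num_heads)   # dropout is 0 by default
x = torch.randn(seq_len, batch_size, embed_dim)

# 2D causal attn_mask: True above the diagonal forbids attending to future positions
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# 2D key_padding_mask: (batch_size, seq_len), True marks padded positions
key_padding_mask = torch.tensor([[False, False, False, True],
                                 [False, False, True,  True]])

out1, _ = mha(x, x, x, attn_mask=causal_mask, key_padding_mask=key_padding_mask)

# the same padding expressed as a 3D attn_mask of shape
# (batch_size * num_heads, seq_len, seq_len): repeat it per head and per query position
pad_3d = key_padding_mask[:, None, None, :] \
    .expand(batch_size, num_heads, seq_len, seq_len) \
    .reshape(batch_size * num_heads, seq_len, seq_len)
out2, _ = mha(x, x, x, attn_mask=causal_mask.logical_or(pad_3d))

print(torch.allclose(out1, out2))   # True: both formulations mask the same positions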

Decidua answered 29/6, 2020 at 7:51 Comment(3)
What would be the purpose of having a mask that is specific to every item in the batch? Curious. – Gutter
There could be padding at different positions for each item in the batch. For example, if the input is a series of sentences and they are padded at the beginning or end, we need to apply an individual mask to each sentence. This mask will be a combination of attn_mask and key_padding_mask in the case of a decoder (referring to encoder inputs for keys and values). – Cabinet
When passing masks for each item in a batch, does the module use sequential items along the 0 dimension for each attention head? I.e., when batch_size=32 and num_heads=4, are attn_mask[:4,:,:] the masks for item 1 (for heads 1, 2, 3 and 4)? – Wildermuth
I think they work the same way: both masks define which query-key attention pairs will not be used. The only difference between the two is the shape in which you find it more convenient to provide the mask.

According to the code, the two masks are merged (their union is taken), so they play the same role: deciding which query-key attention pairs will not be used. Because the union is taken, the two masks can hold different values if you actually need both, or you can supply your mask through whichever argument's required shape is more convenient. Here is part of the original code from pytorch/functional.py, around line 5227, in the function multi_head_attention_forward():

...
# merge key padding and attention masks
if key_padding_mask is not None:
    assert key_padding_mask.shape == (bsz, src_len), \
        f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
    key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).   \
        expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
    if attn_mask is None:
        attn_mask = key_padding_mask
    elif attn_mask.dtype == torch.bool:
        attn_mask = attn_mask.logical_or(key_padding_mask)
    else:
        attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))
...
# so here only the merged/unioned mask is used to actually compute the attention
attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
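As a quick illustration of that merge, here is a small standalone sketch (toy shapes, not the library's code) that reproduces the bool branch above and prints the mask the attention computation actually receives:

import torch

bsz, num_heads, src_len = 2, 2, 3

# a 2D bool causal mask (True = this position may not be attended to)
attn_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool), diagonal=1)

# padding mask: last position of item 0 and last two positions of item 1 are padding
key_padding_mask = torch.tensor([[False, False, True],
                                 [False, True,  True]])

# same reshape as in multi_head_attention_forward above
kpm = key_padding_mask.view(bsz, 1, 1, src_len) \
    .expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)

merged = attn_mask.logical_or(kpm)   # shape (bsz * num_heads, src_len, src_len)
print(merged[0])   # head 0 of item 0: the causal pattern plus an all-True last column

Since only this merged mask reaches the attention computation, passing the padding information through key_padding_mask or through an equivalent 3D attn_mask leads to the same result.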

Please correct me if you have a different opinion or if I am wrong.

Dewdrop answered 6/12, 2021 at 21:9 Comment(0)
