To be clear, I am referring to "self-attention" of the type described in Hierarchical Attention Networks for Document Classification and implemented in many places, for example here. I am not referring to the seq2seq type of attention used in encoder-decoder models (e.g. Bahdanau), although my question might apply to that as well; I am just less familiar with it.
Self-attention basically just computes a weighted average of RNN hidden states (a generalization of mean-pooling, i.e. an unweighted average). When there are variable-length sequences in the same batch, they will typically be zero-padded to the length of the longest sequence in the batch (if using a dynamic RNN). When the attention weights are computed for each sequence, the final step is a softmax, so the attention weights sum to 1.
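A minimal pure-Python sketch of that pooling step, assuming the attention scores have already been produced by some scoring layer (the scoring network itself is omitted, and the helper names `softmax`, `attention_pool` and `mean_pool` are made up for illustration). When all scores are equal, the softmax weights are uniform and attention pooling reduces to mean-pooling:

```python
import math
from typing import List

def softmax(scores: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability;
    # the returned weights are positive and sum to 1.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_pool(hidden: List[List[float]], scores: List[float]) -> List[float]:
    # Weighted average of the hidden states: sum_t a_t * h_t, with a = softmax(scores).
    weights = softmax(scores)
    dim = len(hidden[0])
    return [sum(w * h[d] for w, h in zip(weights, hidden)) for d in range(dim)]

def mean_pool(hidden: List[List[float]]) -> List[float]:
    # Unweighted average of the hidden states.
    n = len(hidden)
    dim = len(hidden[0])
    return [sum(h[d] for h in hidden) / n for d in range(dim)]

# Three hypothetical 2-dimensional hidden states.
hidden = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(mean_pool(hidden))                        # [3.0, 4.0]
print(attention_pool(hidden, [0.0, 0.0, 0.0]))  # uniform weights, so approximately [3.0, 4.0]
```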
However, in every attention implementation I have seen, no care is taken to mask out, or otherwise cancel, the effects of the zero-padding on the attention weights. This seems wrong to me, but I fear I may be missing something, since nobody else seems bothered by this.
For example, consider a sequence of length 2, zero-padded to length 5. Ultimately this leads to the attention weights being computed as the softmax of a similarly zero-padded vector, e.g.:
weights = softmax([0.1, 0.2, 0, 0, 0]) = [0.21, 0.23, 0.19, 0.19, 0.19]
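This arithmetic can be reproduced with a small pure-Python softmax (a sketch; the padded positions are assumed to receive a score of exactly 0, as in the example). It also shows how little of the total weight is left for the two real tokens:

```python
import math

def softmax(scores):
    # Numerically stable softmax over a list of floats.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Two real positions followed by three padded positions whose score is 0.
scores = [0.1, 0.2, 0.0, 0.0, 0.0]
weights = softmax(scores)

print([round(w, 3) for w in weights])  # [0.207, 0.229, 0.188, 0.188, 0.188]
# Total weight that actually lands on the two real tokens:
print(round(sum(weights[:2]), 3))      # 0.437
```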
and because exp(0)=1, the padded positions still receive weight, so the zero-padding in effect "waters down" the attention weights on the real tokens. This can easily be fixed, after the softmax operation, by multiplying the weights by a binary mask, i.e.
mask = [1, 1, 0, 0, 0]
and then re-normalizing the weights to sum to 1, which results in:
weights = [0.48, 0.52, 0, 0, 0]
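A minimal sketch of that fix, again in pure Python with made-up helper names. `mask_and_renormalize` is the post-softmax multiply-and-renormalize step described above; `masked_softmax` is a different but common way of doing the same thing, which sets padded scores to negative infinity before the softmax so that their exponentials are exactly 0 (many implementations use a large negative constant such as -1e9 instead). Both produce the softmax over the real positions only:

```python
import math

def softmax(scores):
    # Numerically stable softmax; exp(-inf) evaluates to 0.0.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def mask_and_renormalize(weights, mask):
    # Post-softmax fix: zero out padded positions, then rescale to sum to 1.
    kept = [w * m for w, m in zip(weights, mask)]
    total = sum(kept)
    return [k / total for k in kept]

def masked_softmax(scores, mask):
    # Pre-softmax fix: padded scores become -inf, so they get exactly zero weight.
    # Assumes at least one position is unmasked (otherwise the result is NaN).
    masked = [s if m else float("-inf") for s, m in zip(scores, mask)]
    return softmax(masked)

scores = [0.1, 0.2, 0.0, 0.0, 0.0]
mask = [1, 1, 0, 0, 0]

post = mask_and_renormalize(softmax(scores), mask)
pre = masked_softmax(scores, mask)

print([round(w, 3) for w in post])  # [0.475, 0.525, 0.0, 0.0, 0.0]
print([round(w, 3) for w in pre])   # [0.475, 0.525, 0.0, 0.0, 0.0]
```

The pre-softmax variant avoids the separate renormalization step; mathematically the two are equivalent.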
When I do this, I almost always see a performance boost (in the accuracy of my models; I am doing document classification/regression). So why does nobody do this?
For a while I considered that perhaps only the relative values of the attention weights (i.e. their ratios) matter, since the gradient doesn't pass through the zero-padding anyway. But then why would we use softmax at all, as opposed to just exp(.), if normalization doesn't matter? (Plus, that wouldn't explain the performance boost...)
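A small sketch of that observation, assuming (as with a dynamic RNN that emits zeros past each sequence's length) that the padded hidden states are zero vectors; the 1-dimensional hidden values below are made up. The ratio between the real tokens' weights is the same with or without padding, but the pooled value computed with padding is scaled down by the total weight left on the real tokens:

```python
import math

def softmax(scores):
    # Numerically stable softmax over a list of floats.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

scores = [0.1, 0.2]                        # scores of the two real tokens
padded_scores = scores + [0.0, 0.0, 0.0]  # same sequence zero-padded to length 5

w_true = softmax(scores)
w_pad = softmax(padded_scores)

# Relative weights of the real tokens are unchanged: both equal exp(0.1 - 0.2).
ratio_true = w_true[0] / w_true[1]
ratio_pad = w_pad[0] / w_pad[1]

# Hypothetical 1-D hidden states; padded steps are assumed to output 0.
h = [1.0, 3.0, 0.0, 0.0, 0.0]
pooled_true = sum(w * x for w, x in zip(w_true, h[:2]))
pooled_pad = sum(w * x for w, x in zip(w_pad, h))

print(round(ratio_true, 4), round(ratio_pad, 4))    # 0.9048 0.9048
print(round(pooled_true, 3), round(pooled_pad, 3))  # 2.05 0.895
# pooled_pad equals pooled_true * (w_pad[0] + w_pad[1]): shrunk by the real-token mass.
```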