I am learning to apply the Transformer model proposed in "Attention Is All You Need" by following the official TensorFlow tutorial "Transformer model for language understanding".
As the Positional encoding section says:
Since this model doesn't contain any recurrence or convolution, positional encoding is added to give the model some information about the relative position of the words in the sentence.
The positional encoding vector is added to the embedding vector.
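The question does not show the body of the tutorial's positional_encoding helper. Below is a minimal NumPy sketch of the sinusoidal encoding from the paper, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). It assumes the output shape is (1, position, d_model), which is what the slice self.pos_encoding[:, :seq_len, :] implies. It is an illustration, not the tutorial's exact code.

```python
import numpy as np


def positional_encoding(position, d_model):
    """Sinusoidal positional encoding with shape (1, position, d_model)."""
    pos = np.arange(position)[:, np.newaxis]   # (position, 1)
    i = np.arange(d_model)[np.newaxis, :]      # (1, d_model)
    # Dimensions 2k and 2k+1 share the frequency 1 / 10000^(2k / d_model).
    angle_rates = 1.0 / np.power(10000, (2 * (i // 2)) / d_model)
    angles = pos * angle_rates                 # (position, d_model)
    angles[:, 0::2] = np.sin(angles[:, 0::2])  # even dimensions: sine
    angles[:, 1::2] = np.cos(angles[:, 1::2])  # odd dimensions: cosine
    return angles[np.newaxis, ...]             # add a leading batch axis


pe = positional_encoding(50, 512)
print(pe.shape)  # (1, 50, 512)
```

Every entry lies in [-1, 1] no matter how large d_model is. This matters when comparing the encoding's magnitude with that of the embeddings.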
My understanding is that the positional encoding vector is added directly to the embedding vector. However, when I looked at the code, I found that the embedding vector is first multiplied by a constant.
The code in the Encoder section is as follows:
import tensorflow as tf

# EncoderLayer and positional_encoding are defined in earlier sections of the tutorial.
class Encoder(tf.keras.layers.Layer):
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 rate=0.1):
        super(Encoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
        self.pos_encoding = positional_encoding(input_vocab_size, self.d_model)
        self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate)
                           for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        seq_len = tf.shape(x)[1]
        # adding embedding and position encoding.
        x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))  # scale by sqrt(d_model)
        x += self.pos_encoding[:, :seq_len, :]  # add encodings for the first seq_len positions
        x = self.dropout(x, training=training)
        for i in range(self.num_layers):
            x = self.enc_layers[i](x, training, mask)
        return x  # (batch_size, input_seq_len, d_model)
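The first three steps of Encoder.call (embedding lookup, scaling, adding the sliced positional encoding) can be mirrored in NumPy to make the shapes and broadcasting explicit. This is a sketch under stated assumptions. A random table stands in for the learned tf.keras.layers.Embedding weights. All sizes are hypothetical. max_len plays the role that input_vocab_size plays in the tutorial's positional_encoding call.

```python
import numpy as np

# Hypothetical sizes chosen for illustration only.
batch_size, seq_len, d_model, vocab_size, max_len = 2, 7, 16, 100, 50

rng = np.random.default_rng(0)
# Stand-in for the learned embedding table (assumed uniform in [-0.05, 0.05],
# like Keras' default 'uniform' embedding initializer).
embedding_table = rng.uniform(-0.05, 0.05, size=(vocab_size, d_model))

# Sinusoidal positional encoding for max_len positions, shape (1, max_len, d_model).
pos = np.arange(max_len)[:, None]
i = np.arange(d_model)[None, :]
angles = pos / np.power(10000, (2 * (i // 2)) / d_model)
pos_encoding = np.where(i % 2 == 0, np.sin(angles), np.cos(angles))[None, ...]

token_ids = rng.integers(0, vocab_size, size=(batch_size, seq_len))

x = embedding_table[token_ids]          # (batch_size, seq_len, d_model)
x = x * np.sqrt(d_model)                # scale the embeddings first
x = x + pos_encoding[:, :seq_len, :]    # then add; broadcasts over the batch axis
print(x.shape)  # (2, 7, 16)
```

The leading axis of size 1 in pos_encoding is what lets the same encodings be added to every sequence in the batch.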
We can see that x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32)) runs before x += self.pos_encoding[:, :seq_len, :].
So why is the embedding vector multiplied by a constant before the positional encoding is added in the Transformer model?
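A rough numerical illustration of what the scaling changes is shown below. It assumes the embedding weights are still near Keras' default uniform(-0.05, 0.05) initialization. Trained embeddings can end up at a different scale, and the tutorial itself does not give this reasoning. The sketch compares the root-mean-square (RMS) size of the embedding entries, with and without the sqrt(d_model) factor, against the RMS of the positional encoding.

```python
import numpy as np

d_model, seq_len = 512, 100  # d_model = 512 as in the paper's base model
rng = np.random.default_rng(0)

# Assumption: embeddings are still at a uniform(-0.05, 0.05) initialization.
emb = rng.uniform(-0.05, 0.05, size=(seq_len, d_model))

# Sinusoidal positional encoding, shape (seq_len, d_model).
pos = np.arange(seq_len)[:, None]
i = np.arange(d_model)[None, :]
angles = pos / np.power(10000, (2 * (i // 2)) / d_model)
pe = np.where(i % 2 == 0, np.sin(angles), np.cos(angles))


def rms(a):
    """Root-mean-square of all entries."""
    return float(np.sqrt(np.mean(a ** 2)))


print(f"embedding RMS, unscaled: {rms(emb):.3f}")
print(f"embedding RMS, scaled:   {rms(emb * np.sqrt(d_model)):.3f}")
print(f"positional encoding RMS: {rms(pe):.3f}")
```

With these assumptions the unscaled embedding RMS is about 0.03 and the scaled one about 0.65. The positional encoding RMS is exactly sqrt(0.5) ≈ 0.707, because each sine/cosine pair satisfies sin² + cos² = 1. Without the factor, the position signal would be roughly 25 times larger than the token signal at initialization.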
Referring to the normalizing factor in equation (1) in the paper – Transitory
Self-attention comes after the addition of the embedding vector and the positional encoding, so I still can't understand why the embedding vector is multiplied by a constant. – Wind
# Scale embedding by the sqrt of the hidden size. – Wind
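For comparison, equation (1) of the paper is Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V. Its normalizing factor divides the query-key dot products inside every attention layer. That is a separate operation from the sqrt(d_model) embedding scaling in Encoder.call. The following is a minimal NumPy sketch with a single head and no masking; the function name scaled_dot_product_attention is used here only for illustration.

```python
import numpy as np


def scaled_dot_product_attention(q, k, v):
    """softmax(q @ k^T / sqrt(d_k)) @ v for inputs shaped (..., seq_len, depth)."""
    d_k = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)    # (..., seq_q, seq_k)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ v, weights


rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))
k = rng.normal(size=(4, 8))
v = rng.normal(size=(4, 8))
out, attn = scaled_dot_product_attention(q, k, v)
print(out.shape, attn.shape)  # (4, 8) (4, 4)
```

Here the 1/sqrt(d_k) factor touches only the attention scores, never the token representations that the encoder layers receive.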