I am trying to implement the attention mechanism described in Luong et al. (2015) in PyTorch myself, but I can't get it to work. My code is below; for now I am only interested in the "general" attention case. Am I missing an obvious error? The code runs, but the model doesn't seem to learn.
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.embedding = nn.Embedding(
            num_embeddings=self.output_size,
            embedding_dim=self.hidden_size
        )
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size, self.hidden_size)
        # hc: [hidden, context]
        self.Whc = nn.Linear(self.hidden_size * 2, self.hidden_size)
        # s: softmax
        self.Ws = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)
        gru_out, hidden = self.gru(embedded, hidden)
        # [0] remove the dimension of directions x layers for now
        attn_prod = torch.mm(self.attn(hidden)[0], encoder_outputs.t())
        attn_weights = F.softmax(attn_prod, dim=1)  # eq. 7/8
        context = torch.mm(attn_weights, encoder_outputs)
        # hc: [hidden, context]
        out_hc = torch.tanh(self.Whc(torch.cat([hidden[0], context], dim=1)))  # eq. 5
        output = F.log_softmax(self.Ws(out_hc), dim=1)  # eq. 6
        return output, hidden, attn_weights
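For reference, these are the equations the comments in the code point to, as I read them in the paper and restricted to the "general" score. The notation is my transcription: h_t is the decoder hidden state (`hidden[0]` above), \bar{h}_s is one row of `encoder_outputs`, and c_t is the context vector. The paper concatenates as [c_t; h_t], while my code uses [hidden, context]; I assume the order does not matter as long as it is consistent.

```latex
\begin{align*}
\tilde{h}_t &= \tanh\left(W_c\,[c_t; h_t]\right) && \text{(eq. 5)} \\
p(y_t \mid y_{<t}, x) &= \operatorname{softmax}\left(W_s \tilde{h}_t\right) && \text{(eq. 6)} \\
a_t(s) &= \frac{\exp\left(\operatorname{score}(h_t, \bar{h}_s)\right)}{\sum_{s'} \exp\left(\operatorname{score}(h_t, \bar{h}_{s'})\right)} && \text{(eq. 7)} \\
\operatorname{score}(h_t, \bar{h}_s) &= h_t^{\top} W_a \bar{h}_s && \text{(eq. 8, general)} \\
c_t &= \sum_{s} a_t(s)\, \bar{h}_s && \text{(context as weighted average)}
\end{align*}
```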
I have studied the attention implemented in
https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
and
https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation.ipynb
- The first one isn't the exact attention mechanism I am looking for. A major disadvantage is that its attention layer depends on the maximum sequence length (`self.attn = nn.Linear(self.hidden_size * 2, self.max_length)`), which could be expensive for long sequences.
- The second one is closer to what's described in the paper, but still not the same, as there is no `tanh`. Besides, it is really slow after updating it to the latest version of PyTorch (ref). Also, I don't know why it takes the last context (ref).
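To make the two points above concrete, here is a minimal standalone sketch of what I understand by "general" attention followed by the `tanh` step. It is not taken from either tutorial. The names `general_attention`, `W_a`, `W_c` and the sizes are hypothetical, chosen only for illustration, and I assume an unbatched decoder state of shape `(1, hidden_size)` and encoder outputs of shape `(src_len, hidden_size)`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden_size = 8  # hypothetical size, chosen only for this sketch

# "general" score: a bilinear form between decoder state and encoder outputs.
# Applying the linear layer to the decoder state (instead of the encoder
# outputs) gives the same form with the weight matrix transposed.
# The weights are hidden_size x hidden_size, independent of source length.
W_a = nn.Linear(hidden_size, hidden_size, bias=False)
# Combines [context; hidden] before the tanh.
W_c = nn.Linear(hidden_size * 2, hidden_size, bias=False)


def general_attention(hidden, encoder_outputs):
    """hidden: (1, hidden_size); encoder_outputs: (src_len, hidden_size)."""
    scores = torch.mm(W_a(hidden), encoder_outputs.t())  # (1, src_len)
    weights = F.softmax(scores, dim=1)                    # (1, src_len)
    context = torch.mm(weights, encoder_outputs)          # (1, hidden_size)
    attentional = torch.tanh(W_c(torch.cat([context, hidden], dim=1)))
    return attentional, weights


hidden = torch.randn(1, hidden_size)
results = {}
for src_len in (3, 50):  # the same layers handle both source lengths
    encoder_outputs = torch.randn(src_len, hidden_size)
    attentional, weights = general_attention(hidden, encoder_outputs)
    results[src_len] = (attentional.shape, weights.shape, weights.sum().item())

# Parameter count of the scoring layer; it does not change with src_len.
n_params = sum(p.numel() for p in W_a.parameters())
```

In this sketch the scoring layer has a fixed number of parameters, and only the size of the score vector grows with the source length, in contrast to a layer whose output dimension is `max_length`.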