How does Beam Search operate on the output of The Transformer?

According to my understanding (please correct me if I'm wrong), Beam Search is a BFS that only explores the "graph" of possibilities down the b most likely options at each step, where b is the beam size.

To calculate/score each option, especially for the work that I'm doing which is in the field of NLP, we basically calculate the score of a possibility by calculating the probability of a token, given everything that comes before it.
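To make the scoring concrete, here is a minimal sketch of how a candidate sequence is usually scored: the log-probabilities of its tokens, each conditioned on everything before it, are summed. The function name and the probability values are made up for illustration and do not come from any model.

```python
import math

def sequence_log_prob(step_probs):
    """Sum log P(token_t | tokens before t) over the whole sequence."""
    return sum(math.log(p) for p in step_probs)

# Hypothetical per-token conditional probabilities for two candidates.
candidate_a = [0.5, 0.4, 0.9]
candidate_b = [0.6, 0.2, 0.9]

score_a = sequence_log_prob(candidate_a)
score_b = sequence_log_prob(candidate_b)

# The candidate with the higher summed log-probability wins.
best = "a" if score_a > score_b else "b"
```

Summing logs is equivalent to multiplying the probabilities, but it avoids numerical underflow on long sequences.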

This makes sense in a recurrent architecture, where you simply run the model you have with your decoder through the best b first tokens, to get the probabilities of the second tokens, for each of the first tokens. Eventually, you get sequences with probabilities and you just pick the one with the highest probability.

However, in the Transformer architecture, where the model doesn't have that recurrence, the output is a full probability distribution over the vocabulary for each position in the sequence, with shape (batch size, max sequence length, vocab size). How do I interpret this output for Beam Search? I can get the encodings for the input sequence, but since there isn't that recurrence of using the previous output as input for the next token's decoding, how do I go about calculating the probability of all the possible sequences that stem from the best b tokens?

Bare answered 19/6, 2019 at 20:51 Comment(0)

Beam search works exactly the same way as with recurrent models. The decoder is not recurrent (it's self-attentive), but it is still auto-regressive, i.e., generating a token is conditioned on previously generated tokens.

At training time, the self-attention is masked so that it only attends to words to the left of the word that is currently being generated. This simulates the setup you have at inference time, when you indeed only have the left context (because the right context has not been generated yet).
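As an illustration of that masking, here is a small sketch of a causal (look-ahead) mask in plain Python. The helper name is hypothetical, and the convention used here (1 means "may attend") is an assumption; some implementations use the inverse and mark blocked positions with 1.

```python
def causal_mask(length):
    """mask[i][j] is True when position i may attend to position j (j <= i)."""
    return [[j <= i for j in range(length)] for i in range(length)]

mask = causal_mask(4)
for row in mask:
    # Print 1 for allowed positions and 0 for blocked (future) positions.
    print(" ".join("1" if allowed else "0" for allowed in row))
```

Each row is one position being generated: row 0 sees only itself, and the last row sees the whole left context.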

The only difference is that in the RNN decoder, you only use the last RNN state in every beam search step. With the Transformer, you always need to keep the entire hypothesis and do the self-attention over the entire left context.

Flump answered 20/6, 2019 at 8:59 Comment(2)
Thank you for your answer! Is there any resource that explains this in greater detail that you could point me to? - Bare
I guess what I really want to ask is that, with an RNN architecture, in the decoder, I can feed it the b tokens that are highest in probability, to get the conditional probabilities of subsequent tokens. However, as I understand, from this tutorial here: tensorflow.org/beta/tutorials/text/…, I can't really do that for the Transformer architecture. Is that right? The decoder takes in the encoder outputs, the 2 masks and the target -- what would I input in for the parameter target? - Bare

Adding more information for your follow-up question and for people who have the same question:

I guess what I really want to ask is that, with an RNN architecture, in the decoder, I can feed it the b tokens that are highest in probability, to get the conditional probabilities of subsequent tokens. However, as I understand, from this tutorial here: tensorflow.org/beta/tutorials/text/…, I can't really do that for the Transformer architecture. Is that right? The decoder takes in the encoder outputs, the 2 masks and the target -- what would I input in for the parameter target?

The tutorial you mentioned uses teacher forcing in the training stage. Beam search can still be applied to the Transformer decoder in the testing stage.

Using beam search during training is uncommon for modern architectures such as Transformers (check this link for more info). Teacher forcing, as used in the tutorial's training stage, lets you compute all positions in parallel and speeds up training, which matters when you are dealing with a large-vocabulary task.

As for testing such a decoder, you could use the following steps to do beam search (this is just one possibility based on my understanding; there may be better solutions):

First, instead of feeding the entire ground-truth sequence to the decoder, provide only "[SOS]" and pad the remaining positions. The decoder output still has shape [batch_size, max_sequence_len, vocab_size], but only the slice at position 0, i.e. [batch_size, 0, vocab_size], is useful: it is the distribution over the first token your model generates. Select the top b tokens and append each one to its own copy of the "[SOS]" sequence. Now you have b sequences: "[SOS] token(1,1)", ..., "[SOS] token(1,b)".

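Second, use these sequences as input to the decoder and, for each one, read the distribution at the last non-padded position. Among the resulting b * vocab_size candidate continuations, keep the top b by cumulative score and append each chosen token to its corresponding sequence. Repeat until the sequences meet a stopping condition (max_output_length or [EOS]).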

P.S.: 1) [SOS] and [EOS] mark the start and the end of the sequence. 2) token(i,j) means the j-th of the top b tokens for the i-th position in the sequence.
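The steps above can be sketched roughly as follows. This is not the tutorial's code: `toy_next_token_probs` is a made-up stand-in for a real decoder call (which would take the whole prefix plus the encoder outputs and return the distribution at the last filled position), and its probabilities are invented. Scores are summed log-probabilities without length normalization, which is a simplifying assumption.

```python
import math

SOS, EOS = "[SOS]", "[EOS]"

def toy_next_token_probs(prefix):
    """Hypothetical stand-in for running the decoder on the full prefix."""
    if prefix[-1] == SOS:
        return {"a": 0.6, "b": 0.4}
    if prefix[-1] == "a":
        return {"a": 0.1, "b": 0.4, EOS: 0.5}
    return {"a": 0.1, "b": 0.1, EOS: 0.8}  # after "b"

def beam_search(next_token_probs, beam_size, max_len):
    beams = [([SOS], 0.0)]  # (tokens, summed log-probability)
    finished = []
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            # Re-run the model on the entire hypothesis, not just a state.
            for token, prob in next_token_probs(tokens).items():
                candidates.append((tokens + [token], score + math.log(prob)))
        # Keep the best beam_size among up to beam_size * vocab_size options.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_size]:
            if tokens[-1] == EOS:
                finished.append((tokens, score))
            else:
                beams.append((tokens, score))
        if not beams:
            break
    finished.extend(beams)  # hypotheses cut off by max_len
    return max(finished, key=lambda c: c[1])

best_tokens, best_score = beam_search(toy_next_token_probs, beam_size=2, max_len=5)
greedy_tokens, greedy_score = beam_search(toy_next_token_probs, beam_size=1, max_len=5)
```

With these toy numbers, a beam of size 1 (greedy decoding) commits to "a" first and ends with probability 0.30, while a beam of size 2 also keeps "b" alive and finds the sequence with probability 0.32.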

Calciferous answered 6/4, 2022 at 9:37 Comment(2)
Thanks for this comment! Just to clarify, following the tutorial, how would you select the top b tokens? As in... at which point exactly are the tokens retrieved so I could order them and get the top b only? Would they be in the 'Decoder' class or in the individual decoder Layers? - Pasteurize
The output of the decoder provides a distribution per token. I guess once you have received the distribution for the next token, you could implement the beam search and construct b new sequences for the next iteration. - Calciferous
