Batch-wise beam search in PyTorch

I'm trying to implement a beam search decoding strategy in a text generation model. This is the function that I am using to decode the output probabilities.

import torch

def beam_search_decoder(data, k):
    sequences = [[list(), 0.0]]
    # walk over each step in the sequence
    for row in data:
        all_candidates = list()
        # expand every current sequence with every token in the vocabulary
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                candidate = [seq + [j], score - torch.log(row[j])]
                all_candidates.append(candidate)
        # sort candidates by accumulated negative log-likelihood and keep the best k
        ordered = sorted(all_candidates, key=lambda tup: tup[1])
        sequences = ordered[:k]
    return sequences

Now you can see this function is implemented with batch_size 1 in mind. Adding another loop over the batch dimension would make the algorithm O(n^4), and it is already slow as it is. Is there any way to improve the speed of this function? My model output is usually of size (32, 150, 9907), which follows the format (batch_size, max_len, vocab_size).
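
For reference, a minimal sketch of calling the function above on a single example (toy sizes rather than the real (32, 150, 9907) output, so the pure-Python loops finish quickly):

import torch

# one example sliced from a (batch_size, max_len, vocab_size) output; sizes are illustrative
probs = torch.softmax(torch.randn(20, 1000), dim=-1)  # (max_len, vocab_size)
best = beam_search_decoder(probs, k=3)
for tokens, score in best:
    print(score.item(), tokens)  # token indices and accumulated negative log-likelihood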

Polite answered 14/10, 2020 at 15:42 Comment(2)
Beam search only makes sense at test time. Can't you keep batch_size=1 and parallelize the processing of the test examples? – Chinatown
You may also have a look at the beam search implementation and code in this repo for image captioning using a modified Transformer. The implementation makes use of PyTorch's register_buffer to cache the inputs of the previous timestep, so that only the new input is fed in at the current timestep, which makes it considerably faster. – Indemnity

Below is my implementation, which may be a little bit faster than the for loop implementation.

import torch


def beam_search_decoder(post, k):
    """Beam Search Decoder

    Parameters:

        post (Tensor) – the posterior of the network.
        k (int) – beam size of the decoder.

    Outputs:

        indices (Tensor) – a beam of index sequences.
        log_prob (Tensor) – the log likelihood of each sequence in the beam.

    Shape:

        post: (batch_size, seq_length, vocab_size).
        indices: (batch_size, beam_size, seq_length).
        log_prob: (batch_size, beam_size).

    Examples:

        >>> post = torch.softmax(torch.randn([32, 20, 1000]), -1)
        >>> indices, log_prob = beam_search_decoder(post, 3)

    """

    batch_size, seq_length, _ = post.shape
    log_post = post.log()
    # initialize the beam with the k most likely tokens of the first step
    log_prob, indices = log_post[:, 0, :].topk(k, sorted=True)
    indices = indices.unsqueeze(-1)
    for i in range(1, seq_length):
        # add the current step's log-probabilities to every beam and keep the k best candidates
        log_prob = log_prob.unsqueeze(-1) + log_post[:, i, :].unsqueeze(1).repeat(1, k, 1)
        log_prob, index = log_prob.view(batch_size, -1).topk(k, sorted=True)
        indices = torch.cat([indices, index.unsqueeze(-1)], dim=-1)
    return indices, log_prob
Hairsplitter answered 18/2, 2021 at 10:20 Comment(2)
Shouldn't the indices be between 0 and vocab_size? In this implementation, they are between 0 and vocab_size*k. – Ciapha
You are right! This returns indices into the flattened array (cf. log_prob.view), which can exceed the vocabulary size. This version also overlooks sequences that share a portion of tokens at their start. – Celibacy
/!\ The most upvoted answer doesn't perform a correct beam-search!

Based on the version proposed by 防暴队大盾, I decided to implement a version of the beam-search algorithm that does not overlook sequences sharing their initial tokens. This is done by recovering the correct indices from the indices of the flattened array:

import torch


def beam_search(prediction, k=10):
    batch_size, seq_length, vocab_size = prediction.shape
    log_prob, indices = prediction[:, 0, :].topk(k, sorted=True)
    indices = indices.unsqueeze(-1)
    for n1 in range(1, seq_length):
        # add the scores of the current step to every beam
        log_prob_temp = log_prob.unsqueeze(-1) + prediction[:, n1, :].unsqueeze(1).repeat(1, k, 1)
        # keep the k best candidates over the flattened (beam, vocab) dimension
        log_prob, index_temp = log_prob_temp.view(batch_size, -1).topk(k, sorted=True)
        idx_begin = index_temp // vocab_size  # retrieve index of the start sequence (which beam)
        idx_concat = index_temp % vocab_size  # retrieve index of the new token
        new_indices = torch.zeros((batch_size, k, n1 + 1), dtype=torch.int64, device=prediction.device)
        for n2 in range(batch_size):
            new_indices[n2, :, :-1] = indices[n2][idx_begin[n2]]
            new_indices[n2, :, -1] = idx_concat[n2]
        indices = new_indices
    return indices, log_prob

This version assumes that prediction corresponds to cross-entropy (log) scores, not probabilities, so there is no need to take the log here.
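
For example, a minimal usage sketch (shapes taken from the question, variable names illustrative) that applies log_softmax to raw model scores before calling beam_search:

import torch
import torch.nn.functional as F

logits = torch.randn(32, 150, 9907)               # (batch_size, max_len, vocab_size) raw model scores
log_probs = F.log_softmax(logits, dim=-1)         # convert them to log-probabilities
indices, log_prob = beam_search(log_probs, k=5)   # indices: (32, 5, 150), log_prob: (32, 5)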

If someone knows how to avoid the inner-most loop with some fancy indexing, this could probably be made even faster.
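
For what it's worth, here is a possible sketch of the same update with the per-example loop replaced by torch.gather (the name beam_search_gather is just illustrative, not part of the answer above): idx_begin is expanded over the sequence dimension so the prefixes of all selected beams are gathered at once.

import torch


def beam_search_gather(prediction, k=10):
    batch_size, seq_length, vocab_size = prediction.shape
    log_prob, indices = prediction[:, 0, :].topk(k, sorted=True)
    indices = indices.unsqueeze(-1)
    for n1 in range(1, seq_length):
        log_prob_temp = log_prob.unsqueeze(-1) + prediction[:, n1, :].unsqueeze(1).repeat(1, k, 1)
        log_prob, index_temp = log_prob_temp.view(batch_size, -1).topk(k, sorted=True)
        idx_begin = index_temp // vocab_size  # which existing beam each candidate extends
        idx_concat = index_temp % vocab_size  # the new token appended to that beam
        # gather the prefixes of the selected beams in one shot instead of looping over the batch
        prefixes = torch.gather(indices, 1, idx_begin.unsqueeze(-1).expand(-1, -1, n1))
        indices = torch.cat([prefixes, idx_concat.unsqueeze(-1)], dim=-1)
    return indices, log_prob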

Celibacy answered 11/7, 2023 at 11:23 Comment(0)

You can use this library:

https://pypi.org/project/pytorch-beam-search/

It implements Beam Search, Greedy Search and sampling for PyTorch sequence models.

The following snippet builds and trains a Transformer seq2seq model and then uses it to generate predictions.

#pip install pytorch-beam-search
from pytorch_beam_search import seq2seq

# Create vocabularies
# Tokenize the way you need
source = [list("abcdefghijkl"), list("mnopqrstwxyz")]
target = [list("ABCDEFGHIJKL"), list("MNOPQRSTWXYZ")]
# An Index object represents a mapping from the vocabulary
# to integers (indices) to feed into the models
source_index = seq2seq.Index(source)
target_index = seq2seq.Index(target)

# Create tensors
X = source_index.text2tensor(source)
Y = target_index.text2tensor(target)
# X.shape == (n_source_examples, len_source_examples) == (2, 11)
# Y.shape == (n_target_examples, len_target_examples) == (2, 12)

# Create and train the model
model = seq2seq.Transformer(source_index, target_index)    # just a PyTorch model
model.fit(X, Y, epochs = 100)    # basic method included

# Generate new predictions
new_source = [list("new first in"), list("new second in")]
new_target = [list("new first out"), list("new second out")]
X_new = source_index.text2tensor(new_source)
Y_new = target_index.text2tensor(new_target)
loss, error_rate = model.evaluate(X_new, Y_new)    # basic method included
predictions, log_probabilities = seq2seq.beam_search(model, X_new) 
output = [target_index.tensor2text(p) for p in predictions]
output
Salivation answered 18/10, 2021 at 17:12 Comment(0)
