I'm trying to implement a beam search decoding strategy in a text generation model. This is the function that I am using to decode the output probabilities.
import torch

def beam_search_decoder(data, k):
    # each entry is [token sequence, cumulative negative log-probability]
    sequences = [[list(), 0.0]]
    # walk over each step in the sequence
    for row in data:
        all_candidates = list()
        # expand every current beam with every vocabulary entry
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                candidate = [seq + [j], score - torch.log(row[j])]
                all_candidates.append(candidate)
        # sort candidates by score (lowest negative log-probability first)
        ordered = sorted(all_candidates, key=lambda tup: tup[1])
        # keep only the k best beams
        sequences = ordered[:k]
    return sequences
Now, as you can see, this function is implemented with batch_size 1 in mind. Adding another loop over the batch dimension would make the algorithm O(n^4), and it is already slow as it is. Is there any way to improve the speed of this function? My model output is usually of size (32, 150, 9907), which follows the format (batch_size, max_len, vocab_size).
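For reference, here is a minimal sketch of how the same search could be vectorized across both the batch and the beams with torch.topk, so the Python loops over candidates disappear. The function name beam_search_decoder_batched is made up for illustration, and it assumes the input tensor already holds log-probabilities (e.g. the output of log_softmax); if it holds raw probabilities, take the log first:

import torch

def beam_search_decoder_batched(post, k):
    # post: (batch_size, seq_len, vocab_size) log-probabilities
    # returns (indices, log_prob):
    #   indices: (batch_size, k, seq_len) token ids of the k best sequences
    #   log_prob: (batch_size, k) their total log-probabilities
    batch_size, seq_len, vocab_size = post.shape
    # initialize the beams with the top-k tokens of the first step
    log_prob, indices = post[:, 0, :].topk(k, sorted=True)        # (B, k)
    indices = indices.unsqueeze(-1)                               # (B, k, 1)
    for t in range(1, seq_len):
        # score every beam extended with every vocabulary entry: (B, k, V)
        log_prob = log_prob.unsqueeze(-1) + post[:, t, :].unsqueeze(1)
        # keep the k best of the k*V candidates
        log_prob, index = log_prob.flatten(start_dim=1).topk(k, sorted=True)
        beam_id = torch.div(index, vocab_size, rounding_mode='floor')  # parent beam
        token_id = index % vocab_size                                  # appended token
        # gather the parent prefixes and append the new token
        prefixes = indices.gather(1, beam_id.unsqueeze(-1).expand(-1, -1, t))
        indices = torch.cat([prefixes, token_id.unsqueeze(-1)], dim=-1)
    return indices, log_prob

For an output of shape (32, 150, 9907) this keeps everything on the GPU and returns the k best sequences and their scores per batch element, e.g. seqs, scores = beam_search_decoder_batched(torch.log_softmax(logits, dim=-1), k=3).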
Can you keep batch_size=1 and parallelize the processing of the test examples? – Chinatown

You can use register_buffer to cache the inputs of the previous timestep, so that only the new input is fed in at the current timestep, which is considerably faster. – Indemnity
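Following up on the register_buffer suggestion, here is a minimal sketch of what such a cache could look like for a hypothetical GRUCell-based decoder (the class and parameter names are made up for illustration, not part of the model in the question):

import torch
import torch.nn as nn

class CachedStepDecoder(nn.Module):
    # hypothetical single-step decoder: the recurrent state of the previous
    # timestep is kept in a buffer, so each forward call only consumes the
    # newest token instead of re-running the whole prefix
    def __init__(self, vocab_size, hidden_size, batch_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRUCell(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, vocab_size)
        # non-trainable cache of the previous timestep's hidden state
        self.register_buffer("hidden", torch.zeros(batch_size, hidden_size))

    def reset_cache(self):
        self.hidden.zero_()

    @torch.no_grad()
    def forward(self, token):                      # token: (batch_size,) newest ids
        h = self.rnn(self.embed(token), self.hidden)
        self.hidden.copy_(h)                       # update the cache in place
        return torch.log_softmax(self.out(h), dim=-1)   # (batch_size, vocab_size)

Each call then only embeds and processes the newest token rather than the whole prefix; reset_cache() would be called at the start of every new sequence.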