We have to go more low-level, as the `pipeline` function is not appropriate for what you are trying to do.
After you pass your sequence to `AutoModelForCausalLM`, the output at the last position of the sequence contains a score (logit) for every token in the vocabulary being the next token; applying a softmax to these scores turns them into probabilities. In the code below, I call it `next_token_candidates_tensor`. After that, you simply need to select the indices of the top-k candidates and decode them back to words.
import torch
from transformers import AutoModelForCausalLM , AutoTokenizer
class LMHeadModel:
    """Wrapper around a causal language model for next-token prediction.

    Loads a model and its tokenizer by name and exposes helpers that return
    the model's scores for the token following a given piece of text.
    """

    def __init__(self, model_name):
        """Load the model and the tokenizer identified by ``model_name``."""
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def get_predictions(self, sentence):
        """Encode ``sentence`` and return the first element of the model output.

        The forward pass runs under ``torch.no_grad()`` so no autograd graph
        is built. The returned tensor is indexed by callers as
        ``[0, -1, :]``, i.e. it is expected to have three dimensions
        (presumably batch, sequence position, vocabulary -- confirm against
        the model in use).
        """
        inputs = self.tokenizer.encode(sentence, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(inputs)
        # Index 0 (rather than an attribute) also works when the model
        # returns a plain tuple.
        return outputs[0]

    def get_next_word_probabilities(self, sentence, top_k=500):
        """Return the most likely next tokens for ``sentence``.

        Args:
            sentence: Text whose continuation should be predicted.
            top_k: Maximum number of candidates to return. Values larger
                than the number of available scores are clamped, so every
                candidate is returned instead of raising.

        Returns:
            A list of ``(token_text, probability)`` tuples ordered from the
            most to the least likely candidate. Token text is decoded one id
            at a time and stripped of surrounding whitespace.
        """
        predictions = self.get_predictions(sentence)

        # Scores for the position after the last input token, first batch item.
        next_token_candidates_tensor = predictions[0, -1, :]

        # torch.topk raises when k exceeds the size of the dimension, which
        # the default of 500 would do for any small-vocabulary model.
        k = min(top_k, next_token_candidates_tensor.size(-1))
        topk_indices = torch.topk(next_token_candidates_tensor, k).indices

        # Normalize over the *whole* vocabulary so the reported values are
        # true next-token probabilities, then pick out the top-k entries.
        all_candidates_probabilities = torch.nn.functional.softmax(
            next_token_candidates_tensor, dim=-1)
        topk_candidates_probabilities = \
            all_candidates_probabilities[topk_indices].tolist()

        # Decode each candidate id back to its text form.
        topk_candidates_tokens = [
            self.tokenizer.decode([idx]).strip()
            for idx in topk_indices.tolist()
        ]

        return list(zip(topk_candidates_tokens, topk_candidates_probabilities))
# Example usage: load the "gpt2" model, then ask for the 500 most likely
# continuations of a short prompt.
model = LMHeadModel("gpt2")
sentence = "I enjoy walking in the"
model.get_next_word_probabilities(sentence, top_k=500)
# Sample result (truncated):
# [('park', 0.15904344618320465),
# ('woods', 0.10028065741062164),
# ('streets', 0.0418376550078392),
# ('dark', 0.03117542900145054),
# ('door', 0.029618268832564354),
# ('street', 0.02388935722410679),
# ('rain', 0.021733922883868217),
# ...