Warning:
You need to check whether the produced sentence embeddings are meaningful. This is necessary because the model you are using wasn't trained to produce meaningful sentence embeddings (check this StackOverflow answer for further information).
Retrieving sentence embeddings from LLMs is an ongoing research topic. In the following, I will show two different approaches that can be used to retrieve sentence embeddings from Llama 2.
Weighted-Mean-Pooling
Llama is a decoder-only model with left-to-right (causal) attention. The idea behind weighted mean pooling is that tokens at the end of the sentence should contribute more to the sentence embedding than tokens at the beginning, because their hidden states are contextualized by all previous tokens, while the tokens at the beginning have far less context.
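To make the weighting concrete, here is a minimal sketch (my own illustration, not part of the approach itself) of how the position weights look for a toy attention mask; the full Llama 2 example follows below:

import torch

# toy attention mask: first sequence has 3 real tokens, second has 5
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# positions 1..seq_len, zeroed out at padding positions
weights = attention_mask * torch.arange(start=1, end=attention_mask.shape[1] + 1).unsqueeze(0)
print(weights)
# tensor([[1, 2, 3, 0, 0],
#         [1, 2, 3, 4, 5]])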
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-chat-hf"

t = AutoTokenizer.from_pretrained(model_id)
t.pad_token = t.eos_token  # Llama has no pad token, reuse the EOS token for padding
m = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
m.eval()

texts = [
    "this is a test",
    "this is another test case with a different length",
]

t_input = t(texts, padding=True, return_tensors="pt")

with torch.no_grad():
    last_hidden_state = m(**t_input, output_hidden_states=True).hidden_states[-1]

# position-based weights (1, 2, ..., seq_len), zeroed out for padding tokens
weights_for_non_padding = t_input.attention_mask * torch.arange(start=1, end=last_hidden_state.shape[1] + 1).unsqueeze(0)

# weighted sum of the token states, normalized by the sum of the weights
sum_embeddings = torch.sum(last_hidden_state * weights_for_non_padding.unsqueeze(-1), dim=1)
num_of_none_padding_tokens = torch.sum(weights_for_non_padding, dim=-1).unsqueeze(-1)
sentence_embeddings = sum_embeddings / num_of_none_padding_tokens

print(t_input.input_ids)
print(weights_for_non_padding)
print(num_of_none_padding_tokens)
print(sentence_embeddings.shape)
Output:
tensor([[ 1, 445, 338, 263, 1243, 2, 2, 2, 2, 2],
[ 1, 445, 338, 1790, 1243, 1206, 411, 263, 1422, 3309]])
tensor([[ 1, 2, 3, 4, 5, 0, 0, 0, 0, 0],
[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
tensor([[15],
[55]])
torch.Size([2, 4096])
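To run the kind of sanity check mentioned in the warning above, you could, for example, compare the two embeddings with cosine similarity. A minimal sketch, assuming sentence_embeddings from the snippet above is still in scope:

import torch.nn.functional as F

# cosine similarity between the two weighted-mean-pooled sentence embeddings
similarity = F.cosine_similarity(sentence_embeddings[0], sentence_embeddings[1], dim=0)
print(similarity.item())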
Prompt-based last token
An alternative is to use a specific prompt and take the contextualized embedding of the last token. This approach was introduced by Jiang et al. and showed decent results for the OPT model family without fine-tuning. The idea is to use a prompt that forces the model to summarize the sentence in exactly one word. They call this approach PromptEOL and used the following template for their experiments:
"This sentence: {text} means in one word:"
Please check their paper for further results. You can use the following code to apply their approach to Llama 2:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-chat-hf"

t = AutoTokenizer.from_pretrained(model_id)
t.pad_token = t.eos_token  # Llama has no pad token, reuse the EOS token for padding
m = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
m.eval()

texts = [
    "this is a test",
    "this is another test case with a different length",
]

# wrap every sentence in the PromptEOL template
prompt_template = "This sentence: {text} means in one word:"
texts = [prompt_template.format(text=x) for x in texts]

t_input = t(texts, padding=True, return_tensors="pt")

with torch.no_grad():
    last_hidden_state = m(**t_input, output_hidden_states=True, return_dict=True).hidden_states[-1]

# index of the last non-padding token of each sequence (right padding)
idx_of_the_last_non_padding_token = t_input.attention_mask.bool().sum(1) - 1

# take the hidden state of that token as the sentence embedding
sentence_embeddings = last_hidden_state[torch.arange(last_hidden_state.shape[0]), idx_of_the_last_non_padding_token]

print(idx_of_the_last_non_padding_token)
print(sentence_embeddings.shape)
Output:
tensor([12, 17])
torch.Size([2, 4096])
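As with the pooling approach, you should sanity-check these embeddings before relying on them. A minimal sketch that normalizes the embeddings and prints the pairwise cosine similarity matrix, assuming sentence_embeddings from the snippet above is still in scope:

import torch.nn.functional as F

# L2-normalize so the dot product equals the cosine similarity
normalized = F.normalize(sentence_embeddings, p=2.0, dim=1)
similarity_matrix = normalized @ normalized.T
print(similarity_matrix)  # shape: [2, 2]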