I'm loading a language model from torch hub (CamemBERT, a French RoBERTa-based model) and using it to embed some French sentences:
import torch
camembert = torch.hub.load('pytorch/fairseq', 'camembert.v0')
camembert.eval() # disable dropout (or leave in train mode to finetune)
def embed(sentence):
    tokens = camembert.encode(sentence)
    # Extract all layers' features (layer 0 is the embedding layer)
    all_layers = camembert.extract_features(tokens, return_all_hiddens=True)
    embeddings = all_layers[0]
    return embeddings
# Here we see that the shape of the embedding vector depends on the number of tokens in the sentence
u = embed(sentence="Bonjour, ça va ?")
u.shape # torch.Size([1, 7, 768])
v = embed(sentence="Salut, comment vas-tu ?")
v.shape # torch.Size([1, 9, 768])
Now suppose that, in order to do some semantic search, I want to compute the cosine similarity between the vectors (tensors in our case) u and v:
cos = torch.nn.CosineSimilarity(dim=1)
cos(u, v) # will throw an error since the shape of `u` is different from the shape of `v`
My question: what is the best way to always get the same embedding shape for a sentence, regardless of the number of tokens it contains?
=> The first solution I'm thinking of is to take the mean over axis=1 (the embedding of a sentence is the mean of the embeddings of its tokens), since axis=0 and axis=2 always have the same size:
cos = torch.nn.CosineSimilarity(dim=1)
cos(u.mean(axis=1), v.mean(axis=1)) # works now and gives 0.7269
But I'm afraid that averaging hurts the sentence embedding, since it gives the same weight to every token (maybe weight each token by its TF-IDF score instead?).
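To make the weighting idea concrete, here is a minimal sketch of a weighted mean over the token axis. It uses random tensors as stand-ins for the CamemBERT outputs, and the helper name `weighted_mean` and the weight values are made up for illustration; real weights (TF-IDF or otherwise) would have to be computed separately, with one entry per token in the tensor, including any special tokens that `encode` adds.

```python
import torch

def weighted_mean(token_embeddings: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Pool a [1, n_tokens, dim] tensor into [1, dim] using per-token weights.

    `weights` has shape [n_tokens]. Hypothetical helper, not part of fairseq.
    """
    # Normalise so the weights sum to 1
    w = weights / weights.sum()
    # Broadcast the weights to [1, n_tokens, 1] and sum over the token axis
    return (token_embeddings * w.view(1, -1, 1)).sum(dim=1)

torch.manual_seed(0)
u = torch.randn(1, 7, 768)  # stand-in for embed("Bonjour, ça va ?")
v = torch.randn(1, 9, 768)  # stand-in for embed("Salut, comment vas-tu ?")

# Made-up positive weights, one per token (assumption: e.g. IDF-like scores)
u_weights = torch.rand(7) + 0.1
v_weights = torch.rand(9) + 0.1

u_vec = weighted_mean(u, u_weights)  # shape [1, 768]
v_vec = weighted_mean(v, v_weights)  # shape [1, 768]

cos = torch.nn.CosineSimilarity(dim=1)
similarity = cos(u_vec, v_vec)       # shape [1]
```

With all weights equal, this reduces to the plain mean over axis=1 from the first solution.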
=> The second solution is to pad the shorter sentences. That means:
- giving a list of sentences to embed at once (instead of embedding sentence by sentence)
- finding the sentence with the most tokens, embedding it, and getting its shape
- embedding each remaining sentence, then zero-padding it to that same shape
(so a shorter sentence has 0 in the remaining positions)
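The padding steps above can be sketched roughly as follows. This is only an illustration under assumptions: random tensors stand in for the per-sentence outputs of `embed()` (with the batch axis of size 1 squeezed out), and `pad_sequence` from `torch.nn.utils.rnn` does the zero-padding. The mask at the end shows one way to keep track of which positions are real tokens.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

torch.manual_seed(0)
# Stand-ins for embed(s).squeeze(0) on three sentences of 7, 9 and 4 tokens
sentence_embeddings = [torch.randn(7, 768), torch.randn(9, 768), torch.randn(4, 768)]
lengths = torch.tensor([e.shape[0] for e in sentence_embeddings])

# Zero-pad every sentence to the length of the longest one: shape [3, 9, 768]
padded = pad_sequence(sentence_embeddings, batch_first=True, padding_value=0.0)

# mask[i, j] is True where token j of sentence i is real, False where it is padding
mask = torch.arange(padded.shape[1]).unsqueeze(0) < lengths.unsqueeze(1)

# Mean over real tokens only, so the padding does not dilute the average
pooled = (padded * mask.unsqueeze(-1)).sum(dim=1) / lengths.unsqueeze(1)  # [3, 768]
```

Note that `padded` still has a token axis, so it would still need to be pooled (as in the last line) or flattened before computing one cosine similarity per sentence pair.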
What are your thoughts? What other techniques would you use and why?
Thanks in advance!