How to compute mean/max of HuggingFace Transformers BERT token embeddings with attention mask?

I'm using the HuggingFace Transformers BERT model, and I want to compute a summary vector (a.k.a. embedding) over the tokens in a sentence, using either the mean or max function. The complication is that some tokens are [PAD], so I want to ignore the vectors for those tokens when computing the average or max.

Here's an example. I initially instantiate a BertTokenizer and a BertModel:

import torch
import transformers
from transformers import AutoTokenizer, AutoModel

transformer_name = 'bert-base-uncased'

tokenizer = AutoTokenizer.from_pretrained(transformer_name, use_fast=True)

model = AutoModel.from_pretrained(transformer_name)

I then input some sentences into the tokenizer and get out input_ids and attention_mask. Notably, an attention_mask value of 0 means that the token was a [PAD] that I can ignore.

sentences = ['Deep learning is difficult yet very rewarding.',
             'Deep learning is not easy.',
             'But is rewarding if done right.']
tokenizer_result = tokenizer(sentences, max_length=32, padding=True, return_attention_mask=True, return_tensors='pt')

input_ids = tokenizer_result.input_ids
attention_mask = tokenizer_result.attention_mask

print(input_ids.shape) # torch.Size([3, 11])

print(input_ids)
# tensor([[  101,  2784,  4083,  2003,  3697,  2664,  2200, 10377,  2075,  1012,  102],
#         [  101,  2784,  4083,  2003,  2025,  3733,  1012,   102,     0,     0,    0],
#         [  101,  2021,  2003, 10377,  2075,  2065,  2589,  2157,  1012,   102,   0]])

print(attention_mask.shape) # torch.Size([3, 11])

print(attention_mask)
# tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]])

Now, I call the BERT model to get the 768-D token embeddings (the top-layer hidden states).

model_result = model(input_ids, attention_mask=attention_mask, return_dict=True)

token_embeddings = model_result.last_hidden_state
print(token_embeddings.shape) # torch.Size([3, 11, 768])

So at this point, I have:

  1. token embeddings in a [3, 11, 768] matrix: 3 sentences, 11 tokens, 768-D vector for each token.
  2. attention mask in a [3, 11] matrix: 3 sentences, 11 tokens. A 1 value indicates non-[PAD].

How do I compute the mean / max over the vectors for the valid, non-[PAD] tokens?

I tried using the attention mask as a mask and then called torch.max(), but I don't get the right dimensions:

masked_token_embeddings = token_embeddings[attention_mask==1]
print(masked_token_embeddings.shape) # torch.Size([29, 768]) <-- WRONG. SHOULD BE [3, 11, 768]

pooled = torch.max(masked_token_embeddings, 1)
print(pooled.values.shape) # torch.Size([29]) <-- WRONG. SHOULD BE [3, 768]

What I really want is a tensor of shape [3, 768]. That is, a 768-D vector for each of the 3 sentences.

Balthazar answered 1/12, 2020 at 1:38 Comment(0)
A
8

For max, you can multiply with attention_mask:

pooled = torch.max((token_embeddings * attention_mask.unsqueeze(-1)), axis=1)

For mean, you can sum along the token axis and divide by the number of non-padding tokens, i.e. the sum of attention_mask along that axis:

mean_pooled = token_embeddings.sum(axis=1) / attention_mask.sum(axis=-1).unsqueeze(-1)
Activity answered 1/12, 2020 at 3:27 Comment(4)
Thank you. Couldn't mean_pooled be implemented analogously to the max pooling, except you'd use torch.mean? Like this: mean_pooled = torch.mean((token_embeddings * attention_mask.unsqueeze(-1)), axis=1)? – Balthazar
Using token_embeddings * attention_mask.unsqueeze(-1) is pretty slick. I didn't think of that. – Balthazar
This seems incorrect: PAD tokens are not embedded as 0-vectors; they represent context and position and thus have real-valued elements. Therefore a sum over the token-wise axis will include the PAD vectors, which probably isn't what we want. One workaround I just came up with was setting the embeddings to 0 wherever the attention mask is 0, then summing, then dividing by the number of tokens. In code, state[inputs["attention_mask"] == 0] = 0. This feels inefficient though, so I'm on the lookout for more elegant solutions. – Falchion
@AlexL: You wrote "This seems incorrect". What does "This" refer to? – Balthazar
C
4

In addition to @Quang's answer, you can have a look at the Pooling layer in sentence_transformers.

For max pooling, they do this:

input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
token_embeddings[input_mask_expanded == 0] = -1e9  # Set padding tokens to large negative value
pooled = torch.max(token_embeddings, 1)[0]

And for mean pooling they do the following:


input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
sum_mask = input_mask_expanded.sum(1)
sum_mask = torch.clamp(sum_mask, min=1e-9)
pooled = sum_embeddings / sum_mask
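
To make the masked mean concrete, here is a minimal sketch on toy tensors, so no model download is needed. The helper name masked_mean and the toy values are made up for illustration and are not part of sentence_transformers. The clamp keeps the denominator away from zero, so a row that is entirely padding yields zeros rather than NaN.

```python
import torch

def masked_mean(token_embeddings, attention_mask):
    """Mean over the token axis, counting only positions where attention_mask == 1."""
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)  # [B, T, 1]
    summed = (token_embeddings * mask).sum(dim=1)                   # [B, H], padding contributes 0
    counts = mask.sum(dim=1).clamp(min=1e-9)                        # [B, 1], never exactly 0
    return summed / counts

# Toy data (hypothetical values): 2 sentences, 3 tokens, 2-D "embeddings".
# The last token of the first sentence is padding with a deliberately large vector.
toy_embeddings = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],
                               [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]])
toy_mask = torch.tensor([[1, 1, 0],
                         [1, 1, 1]])

toy_mean = masked_mean(toy_embeddings, toy_mask)
print(toy_mean)  # rows are [2., 3.] and [7., 8.]; the padding vector is ignored
```

With the real model, the same function would be called as masked_mean(token_embeddings, attention_mask) and should return a [3, 768] tensor.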

The max pooling presented in the accepted answer gives wrong results when the true max is negative, because the zeroed padding positions then win the max. The implementation from sentence transformers modifies token_embeddings in place, which throws an error when you want to use the embeddings for backpropagation: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

If you're interested in anything back-prop related, you can do something like this:

input_mask_expanded = torch.where(attention_mask == 0, -1e9, 0.).unsqueeze(-1).expand(token_embeddings.size()).float()  # -1e9 at padding positions, 0 elsewhere
pooled = torch.max(token_embeddings + input_mask_expanded, 1)[0]  # out-of-place add pushes padding tokens to a large negative value

It's the same idea of making all masked tokens very negative, but it doesn't modify token_embeddings in place while doing so.
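
The sketch below contrasts the two max-pooling approaches on a toy tensor whose real-token values are all negative. It swaps in Tensor.masked_fill with -inf (out of place) instead of the additive mask above; the helper name masked_max and the toy values are made up for illustration. Note that with -inf, a row consisting only of padding would pool to -inf.

```python
import torch

def masked_max(token_embeddings, attention_mask):
    """Max over the token axis, ignoring padding, without modifying the input."""
    pad = (attention_mask == 0).unsqueeze(-1)                  # [B, T, 1], True at padding
    filled = token_embeddings.masked_fill(pad, float("-inf"))  # returns a new tensor
    return filled.max(dim=1).values                            # [B, H]

# Toy data (hypothetical values): 1 sentence, 3 tokens, 2-D "embeddings".
# Both real tokens are negative; the third token is padding.
toy_embeddings = torch.tensor([[[-1.0, -2.0], [-3.0, -0.5], [9.0, 9.0]]],
                              requires_grad=True)
toy_mask = torch.tensor([[1, 1, 0]])

# Multiplying by the mask zeroes the padding, and 0 then beats every negative value.
naive = (toy_embeddings * toy_mask.unsqueeze(-1)).max(dim=1).values
robust = masked_max(toy_embeddings, toy_mask)

print(naive)   # [[0., 0.]], which is not the max over the real tokens
print(robust)  # [[-1.0, -0.5]]

# Backpropagation works because nothing was modified in place.
robust.sum().backward()
print(toy_embeddings.grad[0, 2])  # zeros: the padding position receives no gradient
```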

Chadchadabe answered 7/9, 2022 at 17:39 Comment(1)
Why do we need torch.clamp(sum_mask, min=1e-9)? There must be a precision problem somewhere. But I don't really know. – Huddleston
R
0

Alex is right. Look at the hidden states for the strings that go into the tokenizer: for different strings, the padding positions get different embeddings.

So, in order to properly pool embeddings, you need to ignore those padding vectors.

Let's say you want to get embeddings out of the last 4 layers of BERT (a combination often reported to work well for classification):

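# Assumes tokens and hidden_states come from something like:
#   tokens = tokenizer(sentences, padding=True, return_tensors='pt')
#   hidden_states = model(**tokens, output_hidden_states=True).hidden_states
# For bert-base, hidden_states has 13 entries (embedding output + 12 layers),
# so indices 9..12 are the last 4 layers.

# iterate over the last 4 layers and get embeddings for
# strings without having embeddings from PAD tokens
m = []
for i in range(len(hidden_states[0])):  # one iteration per sentence in the batch
    m.append([hidden_states[j+9][i,:,:][tokens["attention_mask"][i] != 0] for j in range(4)])

# average over all token embeddings: [4, n_tokens, 768] -> [4, 768] per sentence
means = []
for i in range(len(hidden_states[0])):
    means.append(torch.stack(m[i]).mean(dim=1))

# stack embeddings for all strings and concatenate the 4 layers (4 * 768 = 3072)
pooled = torch.stack(means).reshape(-1,1,3072)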
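
The same pooling can also be written without Python loops over sentences. This is only a sketch: it uses random stand-in tensors instead of real BERT outputs, the helper name pool_last_four is made up for illustration, and it returns shape [batch, 4 * hidden] rather than the [batch, 1, 3072] produced above.

```python
import torch

def pool_last_four(hidden_states, attention_mask):
    """Masked mean over tokens for each of the last 4 layers, concatenated per sentence."""
    mask = attention_mask.unsqueeze(-1).to(hidden_states[0].dtype)  # [B, T, 1]
    counts = mask.sum(dim=1).clamp(min=1e-9)                        # [B, 1]
    per_layer = [(h * mask).sum(dim=1) / counts                     # 4 tensors of shape [B, H]
                 for h in hidden_states[-4:]]
    return torch.cat(per_layer, dim=-1)                             # [B, 4 * H]

# Random stand-ins (not real model outputs): 13 "layers", batch of 2,
# 5 tokens, hidden size 8 instead of 768.
torch.manual_seed(0)
fake_hidden_states = tuple(torch.randn(2, 5, 8) for _ in range(13))
fake_mask = torch.tensor([[1, 1, 1, 0, 0],
                          [1, 1, 1, 1, 1]])

fake_pooled = pool_last_four(fake_hidden_states, fake_mask)
print(fake_pooled.shape)  # torch.Size([2, 32])
```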
Resurrectionism answered 16/1, 2022 at 21:29 Comment(0)
