Why take the first hidden state for sequence classification (DistilBertForSequenceClassification) in HuggingFace?

In the last few layers of HuggingFace's sequence classification model, they take the first hidden state along the sequence-length dimension of the transformer output and use it for classification.

hidden_state = distilbert_output[0]  # (bs, seq_len, dim) <-- transformer output
pooled_output = hidden_state[:, 0]  # (bs, dim)           <-- first hidden state
pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)
pooled_output = self.dropout(pooled_output)  # (bs, dim)
logits = self.classifier(pooled_output)  # (bs, num_labels)

Is there any benefit to taking the first hidden state over the last, average, or even the use of a Flatten layer instead?
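
To make the question concrete, the alternatives I have in mind would look roughly like the sketch below (my own fragment operating on the same hidden_state tensor as above, not code from the library):

first = hidden_state[:, 0]                 # (bs, dim)          first position ([CLS])
last = hidden_state[:, -1]                 # (bs, dim)          last position
mean = hidden_state.mean(dim=1)            # (bs, dim)          average over seq_len
flat = hidden_state.flatten(start_dim=1)   # (bs, seq_len*dim)  Flatten layer instead of pooling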

Hobbism answered 6/2, 2020 at 4:10

Yes, this is directly related to the way that BERT is trained. Specifically, I encourage you to have a look at the original BERT paper, in which the authors introduce the meaning of the [CLS] token:

[CLS] is a special symbol added in front of every input example [...].

Specifically, it is used for classification purposes, and it is therefore the first and simplest choice when fine-tuning for classification tasks. What your code fragment is doing is basically just extracting the hidden state at this [CLS] position.

Unfortunately, the DistilBERT documentation in HuggingFace's library does not mention this explicitly; you instead have to check their BERT documentation, which also highlights some issues with the [CLS] token that are analogous to your concerns:

Alongside MLM, BERT was trained using a next sentence prediction (NSP) objective using the [CLS] token as a sequence approximate. The user may use this token (the first token in a sequence built with special tokens) to get a sequence prediction rather than a token prediction. However, averaging over the sequence may yield better results than using the [CLS] token.
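
For illustration, such averaging over the sequence could look roughly like the sketch below (my own fragment, not library code); hidden_state and attention_mask are assumed to come from the DistilBERT encoder and the tokenizer respectively, and padding positions are excluded from the average:

mask = attention_mask.unsqueeze(-1).float()   # (bs, seq_len, 1), 1 for real tokens, 0 for padding
summed = (hidden_state * mask).sum(dim=1)     # (bs, dim), sum over non-padded positions
counts = mask.sum(dim=1).clamp(min=1e-9)      # (bs, 1), number of non-padded tokens
pooled_output = summed / counts               # (bs, dim), masked average instead of hidden_state[:, 0]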

Roadside answered 20/2, 2020 at 9:3
+1. If averaging over the embeddings of the sequence could yield better results, why didn't the authors adopt this approach? – Zilpah
I presume that the alternative is more compute-intensive and therefore not worth the (maybe only marginal) gains. – Roadside
