Concept explanation
Before I answer your question, it is crucial to understand how the entailment approach to zero-shot text classification works. This approach requires a model that was trained for NLI, which means it is able to determine whether a hypothesis is:
- supported (entailment),
- not supported (contradiction), or
- undetermined (neutral)
by a given premise [1]. You can verify this for the model you mentioned with the following code:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
nli_model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
# It will output three logits
print(nli_model.classification_head.out_proj)
# The three logits correspond to the following labels
print(nli_model.config.id2label)
Output:
Linear(in_features=1024, out_features=3, bias=True)
{0: 'contradiction', 1: 'neutral', 2: 'entailment'}
The entailment approach, proposed by Yin et al., utilizes these NLI capabilities by using the text as the premise and formulating a hypothesis for each possible class with the template:
"the text is about {}"
That means that when you have a text and three potential classes, you pass three premise-hypothesis pairs to the NLI model and compare the entailment logits to classify the text.
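For illustration, this is roughly what happens under the hood when a single text is classified against a handful of candidate labels; the example text and label names below are made up:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')
nli_model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')

text = "The stock market rallied after the quarterly earnings report."  # made-up example text
candidate_labels = ["Business", "Sports", "World"]                       # made-up candidate classes

# one premise-hypothesis pair per candidate class
premises = [text] * len(candidate_labels)
hypotheses = [f"the text is about {label}" for label in candidate_labels]

inputs = tokenizer(premises, hypotheses, return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    logits = nli_model(**inputs).logits  # shape: (num_candidate_labels, 3)

# compare the entailment logits (index 2, see id2label above) across the candidates
entailment_logits = logits[:, 2]
print(candidate_labels[entailment_logits.argmax().item()])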
Finetuning
To fine-tune an NLI model on your annotated data, you therefore need to formulate your text classification task as an NLI task! That means you need to generate premise-hypothesis pairs, and the labels need to be either contradiction or entailment. The contradiction label is included so that the model does not only see hypotheses that are entailed by their respective premise (i.e. the model needs to learn contradiction in order to predict a low entailment score for classes that do not apply in the zero-shot text classification task).
The following code shows you an example of how to prepare your dataset:
import random
from datasets import load_dataset
from transformers import AutoTokenizer
your_dataset = load_dataset("ag_news", split="test")
id2labels = ["World", "Sports", "Business", "Sci/Tech"]
your_dataset = your_dataset.map(lambda x: {"class": id2labels[x["label"]]}, remove_columns=["label"])
print(your_dataset[0])
# the relevant part: turn each text into one entailed and one contradicted premise-hypothesis pair
t = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')
template = "This example is {}."
def create_input_sequence(sample):
    # `sample` is a batch of size 1 (see batched=True, batch_size=1 below)
    text = sample["text"]
    label = sample["class"][0]
    # pick a random wrong class for the contradiction example
    contradiction_label = random.choice([x for x in id2labels if x != label])
    # two premise-hypothesis pairs per text: one entailed, one contradicted
    encoded_sequence = t(text * 2, [template.format(label), template.format(contradiction_label)])
    encoded_sequence["labels"] = [2, 0]  # 2 = entailment, 0 = contradiction (see id2label above)
    encoded_sequence["input_sentence"] = t.batch_decode(encoded_sequence.input_ids)
    return encoded_sequence
train_dataset = your_dataset.map(create_input_sequence, batched=True, batch_size=1, remove_columns=["class", "text"])
print(train_dataset[0])
Output:
{'text': "Fears for T N pension after talks Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.",
'class': 'Business'}
{'input_ids': [0, 597, 12541, 13, 255, 234, 4931, 71, 1431, 1890, 2485, 4561, 1138, 23, 6980, 1437, 1437, 188, 1250, 224, 51, 32, 128, 7779, 19051, 108, 71, 1431, 19, 35876, 4095, 933, 1853, 18059, 922, 4, 2, 2, 713, 1246, 16, 2090, 4, 2],
'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
'labels': 2,
'input_sentence': "<s>Fears for T N pension after talks Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.</s></s>This example is Business.</s>"}
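The resulting train_dataset can then be used like any other sequence classification dataset. The following is a minimal training sketch with the Trainer API; the hyperparameters and the output directory bart-large-mnli-finetuned are illustrative placeholders, not tuned values:
from transformers import AutoModelForSequenceClassification, DataCollatorWithPadding, Trainer, TrainingArguments

nli_model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')

training_args = TrainingArguments(
    output_dir="bart-large-mnli-finetuned",  # illustrative path
    per_device_train_batch_size=8,
    num_train_epochs=1,
    learning_rate=2e-5,
)

trainer = Trainer(
    model=nli_model,
    args=training_args,
    train_dataset=train_dataset.remove_columns(["input_sentence"]),  # keep only the model inputs
    data_collator=DataCollatorWithPadding(tokenizer=t),
)
trainer.train()
trainer.save_model()  # writes the fine-tuned model to output_dir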
Robustness
Finetuning will obviously reduce the robustness of your model (i.e. its ability to provide decent results for classes that weren't part of your fine-tuning dataset). To mitigate that, you could try:
- Stopping training before convergence and checking whether the performance is still sufficient for your needs.
- WiSE-FT, proposed by Wortsman et al.; pseudocode is shown in appendix A of their paper, and a rough sketch of the weight interpolation is given below.
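In essence, WiSE-FT linearly interpolates between the weights of the original zero-shot model and the fine-tuned model. A minimal sketch, assuming the fine-tuned checkpoint saved by the training sketch above and a mixing coefficient alpha that you would tune on held-out data:
from transformers import AutoModelForSequenceClassification

alpha = 0.5  # mixing coefficient between 0 and 1, to be tuned on held-out data

zero_shot = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
fine_tuned = AutoModelForSequenceClassification.from_pretrained('bart-large-mnli-finetuned')  # illustrative path

# theta = (1 - alpha) * theta_zero_shot + alpha * theta_fine_tuned, for every parameter
zero_shot_state = zero_shot.state_dict()
fine_tuned_state = fine_tuned.state_dict()
merged_state = {
    name: (1 - alpha) * zero_shot_state[name] + alpha * fine_tuned_state[name]
    for name in zero_shot_state
}
fine_tuned.load_state_dict(merged_state)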