How to finetune a zero-shot model for text classification
I need a model that is able to classify text into an unknown number of classes (i.e. the number might grow over time). The entailment approach for zero-shot text classification seems to be the solution to my problem, but the model I tried, facebook/bart-large-mnli, doesn't perform well on my annotated data. Is there a way to fine-tune it without losing the robustness of the model?

My dataset looks like this:

# http://groups.di.unipi.it/~gulli/AG_corpus_of_news_articles.html
World, "Afghan Army Dispatched to Calm Violence KABUL, Afghanistan - Government troops intervened in Afghanistan's latest outbreak of deadly fighting between warlords, flying from the capital to the far west on U.S. and NATO airplanes to retake an air base contested in the violence, officials said Sunday..."
Sports, "Johnson Helps D-Backs End Nine-Game Slide (AP) AP - Randy Johnson took a four-hitter into the ninth inning to help the Arizona Diamondbacks end a nine-game losing streak Sunday, beating Steve Trachsel and the New York Mets 2-0." 
Business, "Retailers Vie for Back-To-School Buyers (Reuters) Reuters - Apparel retailers are hoping their\back-to-school fashions will make the grade among\style-conscious teens and young adults this fall, but it could\be a tough sell, with students and parents keeping a tighter\hold on their wallets."

P.S.: This is an artificial question that was created because this topic came up in the comment section of this post which is related to this post.

Coming answered 9/5, 2023 at 23:1 Comment(0)

Concept explanation

Before I answer your question, it is crucial to understand how the entailment approach for zero-shot text classification works. This approach requires a model that was trained for natural language inference (NLI), which means it is able to determine whether a hypothesis is:

  • supported,
  • not supported,
  • undetermined

by a given premise [1]. You can verify this for the model you mentioned with the following code:

from transformers import AutoModelForSequenceClassification

nli_model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')

# The classification head outputs three logits ...
print(nli_model.classification_head.out_proj)
# ... one for each of the following labels
print(nli_model.config.id2label)

Output:

Linear(in_features=1024, out_features=3, bias=True)
{0: 'contradiction', 1: 'neutral', 2: 'entailment'}

The entailment approach, proposed by Yin et al., utilizes these NLI capabilities by using the text as the premise and formulating a hypothesis for each possible class with the template:

"the text is about {}”

That means that when you have a text and three potential classes, you pass three premise-hypothesis pairs to the NLI model and compare the entailment logits to classify the text.
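
To make this concrete, here is a minimal sketch of that scoring scheme (the example text, candidate labels, and template below are placeholders; the transformers zero-shot-classification pipeline wraps the same logic for you):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')
nli_model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')

text = "Johnson Helps D-Backs End Nine-Game Slide ..."  # placeholder text
candidate_labels = ["World", "Sports", "Business"]       # placeholder classes
template = "This example is {}."

# one premise-hypothesis pair per candidate class
inputs = tokenizer(
    [text] * len(candidate_labels),
    [template.format(label) for label in candidate_labels],
    return_tensors="pt",
    padding=True,
)
with torch.no_grad():
    logits = nli_model(**inputs).logits  # shape: (num_candidates, 3)

# index 2 is the entailment logit (see id2label above);
# the class whose hypothesis is most entailed wins
scores = logits[:, 2].softmax(dim=0)
print(dict(zip(candidate_labels, scores.tolist())))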

Finetuning

To fine-tune an NLI model on your annotated data, you therefore need to formulate your text classification task as an NLI task! That means you need to generate premise-hypothesis pairs, and the labels need to be either contradiction or entailment. The contradiction label is included to avoid the model only seeing hypotheses that are entailed by their respective premise (i.e. the model needs to learn contradiction in order to predict a low entailment score for wrong classes in the zero-shot text classification task).

The following code shows you an example of how to prepare your dataset:

import random
from datasets import load_dataset
from transformers import AutoTokenizer

# stand-in for your own annotated dataset
your_dataset = load_dataset("ag_news", split="test")
id2labels = ["World", "Sports", "Business", "Sci/Tech"]
your_dataset = your_dataset.map(lambda x: {"class": id2labels[x["label"]]}, remove_columns=["label"])

print(your_dataset[0])

# the relevant code
t = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')
template = "This example is {}."

def create_input_sequence(sample):
  # batch_size=1, so sample["text"] is a list containing a single text
  text = sample["text"]
  label = sample["class"][0]
  # pick a random wrong class to build the contradiction example
  contradiction_label = random.choice([x for x in id2labels if x != label])

  # two premise-hypothesis pairs per text: one entailed, one contradicted
  encoded_sequence = t(text * 2, [template.format(label), template.format(contradiction_label)])
  # 2 = entailment, 0 = contradiction (see id2label above)
  encoded_sequence["labels"] = [2, 0]
  # human-readable version of the model input, for inspection only
  encoded_sequence["input_sentence"] = t.batch_decode(encoded_sequence.input_ids)

  return encoded_sequence

train_dataset = your_dataset.map(create_input_sequence, batched=True, batch_size=1, remove_columns=["class", "text"])
print(train_dataset[0])

Output:

{'text': "Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.", 
'class': 'Business'}

{'input_ids': [0, 597, 12541, 13, 255, 234, 4931, 71, 1431, 1890, 2485, 4561, 1138, 23, 6980, 1437, 1437, 188, 1250, 224, 51, 32, 128, 7779, 19051, 108, 71, 1431, 19, 35876, 4095, 933, 1853, 18059, 922, 4, 2, 2, 713, 1246, 16, 2090, 4, 2], 
'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 
'labels': 2, 
'input_sentence': "<s>Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.</s></s>This example is Business.</s>"}
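
From there, fine-tuning works like any other sequence classification task. The following is a minimal training sketch, assuming the Trainer API with DataCollatorWithPadding for dynamic padding (the output path and hyperparameters are placeholders):

from transformers import (AutoModelForSequenceClassification, DataCollatorWithPadding,
                          Trainer, TrainingArguments)

model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')

# pad each batch to a common length at collation time; the tokenized
# sequences have different lengths, so they are stored unpadded
data_collator = DataCollatorWithPadding(tokenizer=t)

training_args = TrainingArguments(
    output_dir="bart-mnli-finetuned",   # placeholder path
    per_device_train_batch_size=8,      # placeholder hyperparameters
    learning_rate=2e-5,
    num_train_epochs=1,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset.remove_columns(["input_sentence"]),  # drop the string column
    data_collator=data_collator,
)
trainer.train()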

Robustness

Finetuning will obviously reduce the robustness of your model (i.e. its ability to provide decent results for classes that weren't part of your fine-tuning dataset). To mitigate that, you could try:

  • To stop training before convergence and check if the performance is still sufficient for your needs.
  • WiSE-FT, proposed by Wortsman et al. Pseudocode is shown in appendix A of their paper; a rough sketch follows below.
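
For reference, WiSE-FT boils down to interpolating the zero-shot and fine-tuned weights in weight space. A rough sketch (not the authors' implementation; the fine-tuned checkpoint path is hypothetical and the mixing coefficient alpha would be tuned on held-out data):

from transformers import AutoModelForSequenceClassification

zero_shot = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
fine_tuned = AutoModelForSequenceClassification.from_pretrained('bart-mnli-finetuned')  # hypothetical path

alpha = 0.5  # 0.0 = keep the zero-shot weights, 1.0 = keep the fine-tuned weights
merged = {
    name: (1 - alpha) * param + alpha * fine_tuned.state_dict()[name]
    for name, param in zero_shot.state_dict().items()
}
fine_tuned.load_state_dict(merged)
fine_tuned.save_pretrained('bart-mnli-wise-ft')  # hypothetical output path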
Coming answered 9/5, 2023 at 23:1 Comment(7)
Shouldn't id2labels be integers? – Bharat
No, text. You don't want to ask the model "This example is 1.", you want to ask "This example is Sports." to stay close to NLI. @DolevMitz – Coming
Alright, thank you. After I implemented this on my dataset as you said, I'm trying to train it again and I'm getting the following error: expected sequence of length 151 at dim 1 (got 214). Do you have any idea why? – Bharat
Did you forget to pad your dataset? @DolevMitz – Coming
OK, I did pad and changed to a proper tokenizer. I did everything according to your answer here, and I'm still getting this error, even with a proper dataset like yours. – Bharat
In the same way you figured out that the pretrained checkpoint doesn't perform well on your dataset (i.e. the zero-shot pipeline returns a list of scores -> 1 if the expected label has the highest score, else 0). @DolevMitz – Coming
I created one here. @Coming – Bharat
