Early stopping in BERT Trainer instances

I am fine-tuning a BERT model for a multiclass classification task. My problem is that I don't know how to add "early stopping" to my Trainer instances. Any ideas?

Coverture answered 7/9, 2021 at 11:2 Comment(0)

There are a couple of modifications you need to make before EarlyStoppingCallback() will work correctly.

from transformers import EarlyStoppingCallback, IntervalStrategy
...
...
# Defining the TrainingArguments() arguments
args = TrainingArguments(
    "training_with_callbacks",
    evaluation_strategy = IntervalStrategy.STEPS, # "steps"
    eval_steps = 50, # Evaluation happens every 50 steps
    save_strategy = IntervalStrategy.STEPS, # must match the evaluation strategy for load_best_model_at_end
    save_steps = 50, # Save happens every 50 steps
    save_total_limit = 5, # Only the last 5 checkpoints are kept. Older ones are deleted.
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01,
    push_to_hub=False,
    metric_for_best_model = 'f1',
    load_best_model_at_end=True)

You need to:

  1. Use load_best_model_at_end = True (EarlyStoppingCallback() requires this to be True).
  2. Use evaluation_strategy = 'steps' or IntervalStrategy.STEPS instead of 'epoch', and make sure the save strategy matches it (with 'steps', save_steps must be a round multiple of eval_steps).
  3. Set eval_steps = 50 (evaluate the metrics every N steps).
  4. Set metric_for_best_model = 'f1' (the metric must be one of the keys returned by compute_metrics()).

In your Trainer():

trainer = Trainer(
    model,
    args,
    ...
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]
)
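
Note that EarlyStoppingCallback() also accepts an early_stopping_threshold argument, in case you want to require a minimum improvement in the tracked metric before an evaluation counts as progress. A minimal sketch (the 0.01 value is purely illustrative):

callbacks = [EarlyStoppingCallback(
    early_stopping_patience=3,      # stop after 3 consecutive evaluations without sufficient improvement
    early_stopping_threshold=0.01,  # improvements smaller than 0.01 do not reset the patience counter
)]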

As for compute_metrics(), it can, for example, be a function like:

import numpy as np
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score

def compute_metrics(p):
    pred, labels = p
    pred = np.argmax(pred, axis=1)
    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    # this is a multiclass task, so an averaging mode is required (e.g. "macro");
    # the sklearn default ("binary") would raise an error here
    recall = recall_score(y_true=labels, y_pred=pred, average="macro")
    precision = precision_score(y_true=labels, y_pred=pred, average="macro")
    f1 = f1_score(y_true=labels, y_pred=pred, average="macro")
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

compute_metrics() must return a dictionary; you can compute whatever metrics you want inside the function and return them under keys of your choice.
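
If you want to sanity-check compute_metrics() outside the Trainer, you can call it directly on dummy logits and labels (the values below are made up purely for illustration):

import numpy as np

# 4 samples, 3 classes: raw logits and the corresponding true labels
dummy_logits = np.array([[2.0, 0.1, 0.3],
                         [0.2, 1.5, 0.1],
                         [0.3, 0.2, 2.2],
                         [1.8, 0.4, 0.2]])
dummy_labels = np.array([0, 1, 2, 1])

print(compute_metrics((dummy_logits, dummy_labels)))
# {'accuracy': 0.75, 'precision': ..., 'recall': ..., 'f1': ...}

Also note that the Trainer logs these metrics with an eval_ prefix (eval_f1 and so on); metric_for_best_model = 'f1' still works, because the prefix is added automatically when it is missing.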

Note: in newer transformers versions, using the enum IntervalStrategy.STEPS is recommended (see TrainingArguments()) instead of the plain "steps" string, as the latter may be deprecated in the future.
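
Finally, if you prefer to evaluate once per epoch instead (see also the comments below), a minimal sketch of that variant, keeping the save strategy in sync:

args = TrainingArguments(
    "training_with_callbacks",
    evaluation_strategy="epoch",  # evaluate once per epoch
    save_strategy="epoch",        # must match the evaluation strategy for load_best_model_at_end
    learning_rate=2e-5,
    num_train_epochs=5,
    metric_for_best_model="f1",
    load_best_model_at_end=True)

The same EarlyStoppingCallback() works unchanged; patience is then counted in epochs rather than in 50-step intervals.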

Hollenbeck answered 7/9, 2021 at 11:9 Comment(8)
apart from f1, other metrics like bleu can also be used along with early stopping. – Spanker
Yes, definitely, depending on the task at hand. – Hollenbeck
Why does the evaluation strategy need to be steps? Why can't it be epoch? – Dumond
Can't say much about that, it's the way the framework is developed @omermazig – Hollenbeck
But you can easily simulate epochs, since batch_size * number_of_steps == epoch length (with 800 samples and batch_size == 8, one epoch is 100 steps). – Hollenbeck
It looks like load_best_model_at_end=True doesn't work correctly when 8-bit quantization is used. Do you know any solution for this? – Bevan
No, I haven't tried that; you can open an issue on their repo once you've determined that this is the problem. – Hollenbeck
Evaluation strategy can also be epoch: set evaluation_strategy="epoch" and save_strategy="epoch". – Shutout
