Early stopping in Bert Trainer instances

I am fine-tuning a BERT model for a multiclass classification task. My problem is that I don't know how to add early stopping to those Trainer instances. Any ideas?
There are a couple of modifications you need to make before EarlyStoppingCallback() can be used correctly.
from transformers import EarlyStoppingCallback, IntervalStrategy
...
...
# Defining the TrainingArguments() arguments
args = TrainingArguments(
f"training_with_callbacks",
evaluation_strategy = IntervalStrategy.STEPS, # "steps"
    eval_steps = 50, # Evaluation and Save happens every 50 steps
    save_steps = 50, # Save at the same interval; load_best_model_at_end needs matching eval/save steps
    save_total_limit = 5, # Only last 5 checkpoints are kept. Older ones are deleted.
learning_rate=2e-5,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
num_train_epochs=5,
weight_decay=0.01,
push_to_hub=False,
metric_for_best_model = 'f1',
load_best_model_at_end=True)
You need to:
- Use load_best_model_at_end = True (EarlyStoppingCallback() requires this to be True).
- Set evaluation_strategy = 'steps' (or IntervalStrategy.STEPS) instead of 'epoch'.
- Set eval_steps = 50 (evaluate the metrics every N steps).
- Set metric_for_best_model = 'f1'.
In your Trainer():
trainer = Trainer(
model,
args,
...
compute_metrics=compute_metrics,
callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]
)
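For intuition, the patience mechanism can be sketched in plain Python. This is only an illustration of the general idea (stop once the best metric has not improved for a given number of evaluations), not transformers' actual implementation, and the metric values are made up:

```python
def should_stop(metric_history, patience=3, greater_is_better=True):
    """Return True once the best value has not improved for `patience` evaluations."""
    best_idx = 0
    for i, value in enumerate(metric_history):
        if greater_is_better:
            improved = value > metric_history[best_idx]
        else:
            improved = value < metric_history[best_idx]
        if improved:
            best_idx = i
    # Number of evaluations since the last strict improvement
    return (len(metric_history) - 1 - best_idx) >= patience

# Toy eval-F1 history: improvement stalls after the 3rd evaluation
history = [0.71, 0.74, 0.78, 0.77, 0.76, 0.78]
print(should_stop(history, patience=3))  # True
```

The real callback additionally accepts an early_stopping_threshold argument, i.e. EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.01), so that tiny fluctuations do not count as improvements.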
Of course, you also need to supply compute_metrics(); for example, it can be a function like:
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def compute_metrics(p):
    pred, labels = p
    pred = np.argmax(pred, axis=1)
    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    # For a multiclass task an explicit averaging mode is required;
    # the default (average='binary') would raise an error here
    recall = recall_score(y_true=labels, y_pred=pred, average='weighted')
    precision = precision_score(y_true=labels, y_pred=pred, average='weighted')
    f1 = f1_score(y_true=labels, y_pred=pred, average='weighted')
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
compute_metrics() should return a dictionary; you can compute whatever metrics you want inside the function and return them.
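As a concrete illustration of the shapes involved (toy numbers; accuracy computed with plain NumPy instead of scikit-learn), the Trainer passes the raw logits and the gold labels, and argmax over the class axis turns logits into predicted class ids:

```python
import numpy as np

# Toy eval output: 4 samples, 3 classes (what the Trainer passes as p[0] and p[1])
logits = np.array([[2.0, 0.1, 0.3],
                   [0.2, 1.5, 0.1],
                   [0.1, 0.2, 3.0],
                   [1.1, 0.9, 0.2]])
labels = np.array([0, 1, 2, 1])

pred = np.argmax(logits, axis=1)           # predicted class per row
accuracy = float((pred == labels).mean())  # 3 of 4 correct
print(pred.tolist(), accuracy)             # [0, 1, 2, 0] 0.75
```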
Note: In newer transformers versions, using the Enum IntervalStrategy.STEPS (see TrainingArguments()) is recommended over the plain "steps" string, the latter being slated for deprecation.
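IntervalStrategy is a string-backed Enum, which is why the enum member and the plain string are interchangeable today. The pattern can be reproduced in plain Python (a sketch of the idea, not transformers' actual class definition):

```python
from enum import Enum

class IntervalStrategy(str, Enum):
    NO = "no"
    STEPS = "steps"
    EPOCH = "epoch"

# Because the enum also inherits from str, members compare equal to plain strings,
# so evaluation_strategy=IntervalStrategy.STEPS and evaluation_strategy="steps"
# behave the same.
print(IntervalStrategy.STEPS == "steps")  # True
```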
Yes, definitely, depending on the task at hand. – Hollenbeck
Why do I need the evaluation strategy to be steps? Why can't it be epoch? – Dumond
Can't say much about that, it's the way the framework is developed @omermazig – Hollenbeck
But you can easily simulate the epoch, since batch_size * number_steps == epoch length (e.g. with 800 samples, you have BS == 8 and 100 steps). – Hollenbeck
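That back-of-the-envelope calculation can be written down directly (dataset size and batch size here are the hypothetical numbers from the comment; math.ceil accounts for a partial final batch):

```python
import math

num_samples = 800  # hypothetical training-set size
batch_size = 8     # hypothetical per-device batch size

# One epoch spans this many optimizer steps, so setting eval_steps to this
# value makes a step-based evaluation strategy fire roughly once per epoch.
steps_per_epoch = math.ceil(num_samples / batch_size)
print(steps_per_epoch)  # 100
```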
It looks like load_best_model_at_end=True doesn't work correctly when 8-bit quantization is used. Do you know any solution for this? – Bevan
No, I haven't tried that; you can open an issue on their repo once you have determined that this is the problem. – Hollenbeck
The evaluation strategy can also be "epoch"; we can set the following arguments: evaluation_strategy="epoch", save_strategy="epoch" – Shutout
Besides f1, other metrics like bleu can also be used along with early stopping. – Spanker