early stopping in PyTorch

I tried to implement an early stopping function to prevent my neural network model from overfitting. I'm pretty sure the logic is fine, but for some reason it doesn't work. I want the early stopping function to return True when the validation loss has been greater than the training loss for some number of epochs. But it returns False all the time, even though the validation loss becomes a lot greater than the training loss. Could you see where the problem is, please?

early stopping function

def early_stopping(train_loss, validation_loss, min_delta, tolerance):

    counter = 0
    if (validation_loss - train_loss) > min_delta:
        counter +=1
        if counter >= tolerance:
          return True

calling the function during the training

for i in range(epochs):
    
    print(f"Epoch {i+1}")
    epoch_train_loss, pred = train_one_epoch(model, train_dataloader, loss_func, optimiser, device)
    train_loss.append(epoch_train_loss)

    # validation 

    with torch.no_grad(): 
       epoch_validate_loss = validate_one_epoch(model, validate_dataloader, loss_func, device)
       validation_loss.append(epoch_validate_loss)
    
    # early stopping
    if early_stopping(epoch_train_loss, epoch_validate_loss, min_delta=10, tolerance = 20):
      print("We are at epoch:", i)
      break

EDIT: The train and validation loss: [plots of the training and validation loss curves]

EDIT2:

def train_validate(model, train_dataloader, validate_dataloader, loss_func, optimiser, device, epochs):
    preds = []
    train_loss =  []
    validation_loss = []
    min_delta = 5
    

    for e in range(epochs):
        
        print(f"Epoch {e+1}")
        epoch_train_loss, pred = train_one_epoch(model, train_dataloader, loss_func, optimiser, device)
        train_loss.append(epoch_train_loss)

        # validation 
        with torch.no_grad(): 
           epoch_validate_loss = validate_one_epoch(model, validate_dataloader, loss_func, device)
           validation_loss.append(epoch_validate_loss)
        
        # early stopping
        early_stopping = EarlyStopping(tolerance=2, min_delta=5)
        early_stopping(epoch_train_loss, epoch_validate_loss)
        if early_stopping.early_stop:
            print("We are at epoch:", e)
            break

    return train_loss, validation_loss
Verbal answered 25/4, 2022 at 11:42 Comment(0)

The problem with your implementation is that whenever you call early_stopping(), the counter is re-initialized to 0.
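To see the bug concretely, here is a minimal reproduction of the original function (the loss values are illustrative): because counter is a local variable, it is re-created at 0 on every call, so for any tolerance greater than 1 it can never reach the threshold, no matter how large or persistent the gap is.

```python
def early_stopping(train_loss, validation_loss, min_delta, tolerance):
    counter = 0  # local variable: re-created at 0 on every call
    if (validation_loss - train_loss) > min_delta:
        counter += 1
        if counter >= tolerance:
            return True

# Even with a large gap on every "epoch", the function never returns True
# for tolerance > 1, because counter restarts at 0 each time.
results = [early_stopping(100, 200, min_delta=10, tolerance=20) for _ in range(50)]
print(any(results))  # False
```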

Here is a working solution using an object-oriented approach with __init__() and __call__() instead:

class EarlyStopping:
    def __init__(self, tolerance=5, min_delta=0):

        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0
        self.early_stop = False

    def __call__(self, train_loss, validation_loss):
        if (validation_loss - train_loss) > self.min_delta:
            self.counter += 1
            if self.counter >= self.tolerance:  
                self.early_stop = True

Call it like this:

early_stopping = EarlyStopping(tolerance=5, min_delta=10)

for i in range(epochs):
    
    print(f"Epoch {i+1}")
    epoch_train_loss, pred = train_one_epoch(model, train_dataloader, loss_func, optimiser, device)
    train_loss.append(epoch_train_loss)

    # validation 
    with torch.no_grad(): 
       epoch_validate_loss = validate_one_epoch(model, validate_dataloader, loss_func, device)
       validation_loss.append(epoch_validate_loss)
    
    # early stopping
    early_stopping(epoch_train_loss, epoch_validate_loss)
    if early_stopping.early_stop:
      print("We are at epoch:", i)
      break

Example:

early_stopping = EarlyStopping(tolerance=2, min_delta=5)

train_loss = [
    642.14990234,
    601.29278564,
    561.98400879,
    530.01501465,
    497.1098938,
    466.92709351,
    438.2364502,
    413.76028442,
    391.5090332,
    370.79074097,
]
validate_loss = [
    509.13619995,
    497.3125,
    506.17315674,
    497.68960571,
    505.69918823,
    459.78610229,
    480.25592041,
    418.08630371,
    446.42675781,
    372.09902954,
]

for i in range(len(train_loss)):

    early_stopping(train_loss[i], validate_loss[i])
    print(f"loss: {train_loss[i]} : {validate_loss[i]}")
    if early_stopping.early_stop:
        print("We are at epoch:", i)
        break

Output:

loss: 642.14990234 : 509.13619995
loss: 601.29278564 : 497.3125
loss: 561.98400879 : 506.17315674
loss: 530.01501465 : 497.68960571
loss: 497.1098938 : 505.69918823
loss: 466.92709351 : 459.78610229
loss: 438.2364502 : 480.25592041
We are at epoch: 6
Axillary answered 25/4, 2022 at 12:12 Comment(1)
I edited my post by adding the train and validation loss. Around epoch 50 we can see that the validation loss is increasing. Here, I had tolerance=2 and min_delta=5. It should have ended the training but it continued till the last epoch.Verbal

Although @KarelZe's response solves your problem sufficiently and elegantly, I want to provide an alternative early stopping criterion that is arguably better.

Your early stopping criterion is based on how much (and for how long) the validation loss diverges from the training loss. This will break when the validation loss is indeed decreasing but is generally not close enough to the training loss. The goal of training a model is to encourage the reduction of validation loss and not the reduction in the gap between training loss and validation loss.
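To illustrate this failure mode with made-up numbers (the gap condition below is inlined from the accepted answer for a standalone sketch): the validation loss improves every epoch, yet it always trails the training loss by more than min_delta, so a gap-based check trips on every single epoch.

```python
# Gap-based condition from the accepted answer, inlined for illustration.
def gap_exceeded(train_loss, validation_loss, min_delta=5):
    return (validation_loss - train_loss) > min_delta

# Hypothetical losses: validation improves steadily but trails training by 12.
train = [100, 90, 80, 70, 60]
val = [112, 102, 92, 82, 72]

trips = sum(gap_exceeded(t, v) for t, v in zip(train, val))
print(trips)  # 5 -- every epoch trips the gap check despite steady improvement
```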

Hence, I would argue that a better early stopping criterion would be to watch the trend in the validation loss alone, i.e., if training is not lowering the validation loss, terminate it. Here's an example implementation:

class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

Here's how you'd use it:

early_stopper = EarlyStopper(patience=3, min_delta=10)
for epoch in np.arange(n_epochs):
    train_loss = train_one_epoch(model, train_loader)
    validation_loss = validate_one_epoch(model, validation_loader)
    if early_stopper.early_stop(validation_loss):             
        break
Byebye answered 13/9, 2022 at 14:17 Comment(10)
Thank you very much for your answer. It's a new idea and so amazing. So kind of you!Verbal
Thanks for this solution! I was just wondering why the earlier solutions were checking the gap between train and val? That shouldn't be the criterion, should it? Or am I missing something?Uis
I don't know why, but I believe it could just have resulted from a cognitive bias from seeing the typical training and validation curves shown in texts and blogs, where overfitting is always identified when the validation loss curve starts to deviate from the training loss curve.Byebye
Shouldn't the min_delta be used in deciding a model's improvement? keras.io/api/callbacks/early_stoppingEnvirons
@Environs that's exactly what min_delta here is being used for. See the check ` if validation_loss > (self.min_validation_loss + self.min_delta)`. Or am I missing something here?Byebye
@Byebye For example, self.min_validation_loss= 0.5, validation_loss=0.45 and min_delta=0.1, the class will determine the model has improved which was not supposed to be.Environs
@Environs if the validation loss decreases even a slight amount compared to the min validation loss, then I'd infer that training is successfully moving forward. This is because I find it a bit difficult to judge to what extent each training step should decrease the validation loss. I would only add a threshold for the case that the loss doesn't decrease (i.e., stays flat / increases) compared to the min validation loss. Technically, one can do both, and wouldn't be wrong.Byebye
@Byebye are there any major issues with using a rate of change of error instead of an absolute error metric? E.g.: early_stop if the rate of change of error < 5% for 3 consecutive validation sequences? Or is it preferable to use an absolute metric?Seethe
@Seethe I’d consider a learning rate scheduler to address rate of change and keep the absolute metric for the early stopperByebye
This is brilliant. I second @Byebye; calculating the first derivative of the validation loss, and even the second derivative, could provide better stopping parameters.Churchman

In case it helps someone like myself, I would like to build upon the previous answers.

The two answers provided have different interpretations of the min_delta parameter. In @KarelZe's answer, min_delta is used as the allowed gap between train_loss and validation_loss:

if (validation_loss - train_loss) > self.min_delta:
    self.counter += 1

On the other hand, in @isle_of_gods' answer, min_delta is used to increment the counter when the new validation loss is at least min_delta greater than the current minimum validation loss:

elif validation_loss > (self.min_validation_loss + self.min_delta):
    self.counter += 1

Although neither of these answers is wrong, since it depends on one's needs, I think it is more intuitive to consider min_delta as the minimum change required to consider the model as improving. The documentation from Keras, which is equally popular as PyTorch, defines the min_delta parameter of its early stopping callback as follows:

min_delta: Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta, will count as no improvement.

That means any decrease in the validation loss will not be counted as an improvement unless the decrease is larger than min_delta.
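As a quick numeric sketch of that semantics (the values are made up): with a current best loss of 0.5 and min_delta=0.1, a new loss of 0.45 does not qualify as an improvement, while 0.35 does.

```python
min_delta = 0.1
best = 0.5

# A candidate only counts as an improvement if it beats the current best
# by more than min_delta (the Keras-style interpretation).
improvements = [(loss + min_delta) < best for loss in (0.45, 0.35)]
print(improvements)  # [False, True]
```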

To align with the Keras documentation, @isle_of_gods' code can be modified as follows:

import numpy as np

class ValidationLossEarlyStopping:
    def __init__(self, patience=1, min_delta=0.0):
        self.patience = patience  # number of times to allow for no improvement before stopping the execution
        self.min_delta = min_delta  # the minimum change to be counted as improvement
        self.counter = 0  # count the number of times the validation accuracy not improving
        self.min_validation_loss = np.inf

    # return True when validation loss is not decreased by the `min_delta` for `patience` times 
    def early_stop_check(self, validation_loss):
        if ((validation_loss+self.min_delta) < self.min_validation_loss):
            self.min_validation_loss = validation_loss
            self.counter = 0  # reset the counter if validation loss decreased at least by min_delta
        elif ((validation_loss+self.min_delta) > self.min_validation_loss):
            self.counter += 1 # increase the counter if validation loss is not decreased by the min_delta
            if self.counter >= self.patience:
                return True
        return False
Babbage answered 3/7, 2023 at 7:6 Comment(1)
One downside of this interpretation is that if min_delta=0, the program will never terminate if loss = 0.Remunerative

Yet another implementation (but different from @Nawras's in that multiple epochs of the same loss value count as "no improvement"):

class EarlyStopping:
    def __init__(self, *, min_delta=0.0, patience=0):
        self.min_delta = min_delta
        self.patience = patience
        self.best = float("inf")
        self.wait = 0
        self.done = False

    def step(self, current):
        self.wait += 1

        if current < self.best - self.min_delta:
            self.best = current
            self.wait = 0
        elif self.wait >= self.patience:
            self.done = True

        return self.done

Example usage:

import random

random.seed(1234)

def log():
    print(
        f"{epoch=:03}  "
        f"{loss=:.02f}  "
        f"best={early_stopping.best:.02f}  "
        f"wait={early_stopping.wait}"
    )

early_stopping = EarlyStopping(patience=3)

for epoch in range(1, 1000):
    loss = 5 / epoch + random.random()

    if early_stopping.step(loss):
        log()
        break

    log()

Output:

epoch=001  loss=5.97  best=5.97  wait=0
epoch=002  loss=2.94  best=2.94  wait=0
epoch=003  loss=1.67  best=1.67  wait=0
epoch=004  loss=2.16  best=1.67  wait=1
epoch=005  loss=1.94  best=1.67  wait=2
epoch=006  loss=1.42  best=1.42  wait=0
epoch=007  loss=1.39  best=1.39  wait=0
epoch=008  loss=0.71  best=0.71  wait=0
epoch=009  loss=1.32  best=0.71  wait=1
epoch=010  loss=0.74  best=0.71  wait=2
epoch=011  loss=0.49  best=0.49  wait=0
epoch=012  loss=1.21  best=0.49  wait=1
epoch=013  loss=0.73  best=0.49  wait=2
epoch=014  loss=0.98  best=0.49  wait=3
Remunerative answered 17/7 at 21:23 Comment(0)
