If you're using a custom training loop, you can use a collections.deque, which is a "rolling" list: you append to it on the right, and the left-hand items get popped out when the list grows longer than maxlen. Here's the relevant snippet:
# Sliding window holding at most `early_stopping + 1` recent test losses.
loss_history = deque(maxlen=early_stopping + 1)

for epoch in range(epochs):
    fit(epoch)
    loss_history.append(test_loss.result().numpy())
    # Once the window is full, take out its oldest loss and stop if none of
    # the `early_stopping` losses recorded after it is at least as low.
    if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history):
        break
Here's a full example:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras.layers import Dense
from collections import deque
# Load the Iris dataset as (features, label) pairs, together with its metadata.
data, info = tfds.load('iris', split='train', as_supervised=True, with_info=True)

# Keep the features as floats: casting the measurements to an integer dtype
# would truncate their fractional part before they ever reach the model.
data = data.map(lambda x, y: (tf.cast(x, tf.float32), y))

# First 120 examples for training, the following 30 for evaluation,
# both served in batches of 4.
train_dataset = data.take(120).batch(4)
test_dataset = data.skip(120).take(30).batch(4)
# Small fully-connected classifier. The final layer has no activation, so it
# outputs one raw logit per class.
model = tf.keras.Sequential(
    [
        Dense(8, activation='relu'),
        Dense(16, activation='relu'),
        Dense(info.features['label'].num_classes),
    ]
)

# The loss consumes raw logits, matching the activation-free output layer.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Running means of the per-batch loss, one for each data split.
train_loss = tf.keras.metrics.Mean()
test_loss = tf.keras.metrics.Mean()

# Accuracy trackers, one for each data split.
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()

opt = tf.keras.optimizers.Adam(learning_rate=0.001)
@tf.function
def train_step(inputs, labels):
    """Run one optimisation step on a batch and record the training metrics.

    Args:
        inputs: Batch of features fed to ``model``.
        labels: Labels for the batch, used by the loss and the accuracy metric.
    """
    # Record the forward pass so the loss can be differentiated afterwards.
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        batch_loss = loss_object(labels, predictions)

    # Look the variables up only after the forward pass has run.
    weights = model.trainable_variables
    grads = tape.gradient(batch_loss, weights)
    opt.apply_gradients(zip(grads, weights))

    train_loss.update_state(batch_loss)
    train_acc.update_state(labels, predictions)
@tf.function
def test_step(inputs, labels):
    """Evaluate one batch without updating the model and record test metrics.

    Args:
        inputs: Batch of features fed to ``model``.
        labels: Labels for the batch, used by the loss and the accuracy metric.
    """
    predictions = model(inputs, training=False)
    batch_loss = loss_object(labels, predictions)

    test_loss.update_state(batch_loss)
    test_acc.update_state(labels, predictions)
def fit(epoch):
    """Train for one epoch, evaluate on the test set, and print the metrics.

    Args:
        epoch: Zero-based epoch index; it is printed as ``epoch + 1``.
    """
    template = 'Epoch {:>2} Train Loss {:.3f} Test Loss {:.3f} ' \
               'Train Acc {:.2f} Test Acc {:.2f}'

    # Start every epoch from clean metric state. ``reset_state`` is the
    # current Keras name; older TensorFlow releases only offer
    # ``reset_states``, so fall back to it when the new name is missing.
    for metric in (train_loss, test_loss, train_acc, test_acc):
        reset = getattr(metric, 'reset_state', None) or metric.reset_states
        reset()

    for X_train, y_train in train_dataset:
        train_step(X_train, y_train)

    for X_test, y_test in test_dataset:
        test_step(X_test, y_test)

    print(template.format(
        epoch + 1,
        train_loss.result(),
        test_loss.result(),
        train_acc.result(),
        test_acc.result()
    ))
def main(epochs=50, early_stopping=10):
    """Train for up to ``epochs`` epochs, stopping early on a stalled test loss.

    Args:
        epochs: Maximum number of epochs to run.
        early_stopping: Number of most recent epochs that are compared
            against the loss recorded just before them.
    """
    recent_losses = deque(maxlen=early_stopping + 1)

    for epoch in range(epochs):
        fit(epoch)
        recent_losses.append(test_loss.result().numpy())

        # Nothing to compare until more than `early_stopping` losses exist.
        if len(recent_losses) <= early_stopping:
            continue

        # Take out the oldest loss; if every later one is strictly higher,
        # the last `early_stopping` epochs brought no improvement.
        oldest = recent_losses.popleft()
        if oldest < min(recent_losses):
            print(f'\nEarly stopping. No validation loss '
                  f'improvement in {early_stopping} epochs.')
            break


# Run the demo only when the file is executed as a script.
if __name__ == '__main__':
    main(epochs=250, early_stopping=10)
Epoch 1 Train Loss 1.730 Test Loss 1.449 Train Acc 0.33 Test Acc 0.33
Epoch 2 Train Loss 1.405 Test Loss 1.220 Train Acc 0.33 Test Acc 0.33
Epoch 3 Train Loss 1.173 Test Loss 1.054 Train Acc 0.33 Test Acc 0.33
Epoch 4 Train Loss 1.006 Test Loss 0.935 Train Acc 0.33 Test Acc 0.33
Epoch 5 Train Loss 0.885 Test Loss 0.846 Train Acc 0.33 Test Acc 0.33
...
Epoch 89 Train Loss 0.196 Test Loss 0.240 Train Acc 0.89 Test Acc 0.87
Epoch 90 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 91 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 92 Train Loss 0.194 Test Loss 0.239 Train Acc 0.90 Test Acc 0.87
Early stopping. No validation loss improvement in 10 epochs.