Calling "fit" multiple times in Keras

I'm working on a CNN over several hundred GB of images. I've created a training function that bites off 4 GB chunks of these images and calls fit on each of these pieces. I'm worried that I'm only training on the last piece and not on the entire dataset.

Effectively, my pseudo-code looks like this:

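DS = lazy_load_400GB_Dataset()  # placeholder loader that lazily yields ~4 GB sections
for section in DS:
    X_train = section.images   # images in this section only
    Y_train = section.classes  # labels for those images

    # nb_epoch is the Keras 1.x argument name (renamed to epochs in Keras 2)
    model.fit(X_train, Y_train, batch_size=16, nb_epoch=30)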

I know that the API and the Keras forums say that this will train over the entire dataset, but I can't intuitively understand why the network wouldn't relearn over just the last training chunk.

Some help understanding this would be much appreciated.

Best, Joe

Unabridged asked 1/9, 2016 at 5:2

Comment:
Enlist: Creating a custom generator class and using it with fit_generator would avoid the whole problem of calling fit multiple times.

For datasets that do not fit into memory, there is an answer in the FAQ section of the Keras documentation:

You can do batch training using model.train_on_batch(X, y) and model.test_on_batch(X, y). See the models documentation.

Alternatively, you can write a generator that yields batches of training data and use the method model.fit_generator(data_generator, samples_per_epoch, nb_epoch).

You can see batch training in action in our CIFAR10 example.
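A minimal sketch of such a generator for the chunked setup in the question, assuming a hypothetical `load_chunk(chunk_id)` callable that returns `(images, classes)` for one chunk as parallel sequences. It loops forever because the Keras 1 `fit_generator` quoted above expects the generator to loop over its data indefinitely and ends an epoch once `samples_per_epoch` samples have been seen.

```python
import itertools


def chunk_batch_generator(load_chunk, chunk_ids, batch_size):
    """Yield (X_batch, y_batch) pairs forever, holding one chunk in memory at a time.

    load_chunk is a hypothetical callable: chunk_id -> (images, classes).
    """
    while True:  # fit_generator-style generators are expected to loop indefinitely
        for chunk_id in chunk_ids:
            images, classes = load_chunk(chunk_id)  # only this chunk is loaded
            for start in range(0, len(images), batch_size):
                # The last batch of a chunk may be smaller than batch_size.
                yield images[start:start + batch_size], classes[start:start + batch_size]


if __name__ == "__main__":
    # Two tiny fake "chunks" standing in for the 4 GB pieces.
    fake_chunks = {0: ([1, 2, 3], ["a", "b", "c"]), 1: ([4, 5], ["d", "e"])}
    gen = chunk_batch_generator(fake_chunks.__getitem__, [0, 1], batch_size=2)
    for X_batch, y_batch in itertools.islice(gen, 4):
        print(X_batch, y_batch)
```

With a compiled Keras 1 model this would be passed as something like `model.fit_generator(gen, samples_per_epoch, nb_epoch)`. In recent tf.keras versions `fit_generator` is deprecated and `model.fit` accepts the generator directly together with `steps_per_epoch`.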

So if you want to iterate over your dataset the way you are doing, you should probably use model.train_on_batch and handle the batch sizes and iteration yourself.
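A rough sketch of that manual loop for the chunked dataset in the question. `load_chunk`, `train_in_chunks` and `RecordingModel` are hypothetical: the stub only records the size of each batch it receives, where a real compiled Keras model's `train_on_batch(X, y)` would run one optimization step.

```python
class RecordingModel:
    """Hypothetical stand-in for a compiled Keras model."""

    def __init__(self):
        self.batch_sizes = []

    def train_on_batch(self, X, y):
        # A real model would take one optimization step here and return the loss.
        self.batch_sizes.append(len(X))
        return 0.0


def train_in_chunks(model, load_chunk, chunk_ids, batch_size, epochs):
    for epoch in range(epochs):
        for chunk_id in chunk_ids:
            X, y = load_chunk(chunk_id)  # one chunk in memory at a time
            for start in range(0, len(X), batch_size):
                # The same model object is updated at every step; nothing is reset between chunks.
                model.train_on_batch(X[start:start + batch_size], y[start:start + batch_size])


if __name__ == "__main__":
    chunks = {0: (list(range(40)), [0] * 40), 1: (list(range(24)), [1] * 24)}
    model = RecordingModel()
    train_in_chunks(model, chunks.__getitem__, [0, 1], batch_size=16, epochs=2)
    print(model.batch_sizes)
```

Unlike the loop in the question, which runs 30 epochs on one chunk before moving to the next, this places the epoch loop outermost, so each epoch passes over every chunk once.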

One more thing to note: you should make sure that the order of the samples you train your model on is shuffled after each epoch. The way your example code is written, the dataset does not appear to be shuffled. You can read a bit more about shuffling here and here.
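One way to shuffle when only one chunk fits in memory is a two-level shuffle: permute the chunk order every epoch, then permute the samples inside each loaded chunk. This is a sketch under that assumption, not a full global shuffle (samples from different chunks are never interleaved); `epoch_order` is a hypothetical helper built on the standard library `random` module.

```python
import random


def epoch_order(chunk_sizes, rng):
    """Return (chunk_id, sample_index) pairs for one epoch in shuffled order."""
    chunk_ids = list(chunk_sizes)
    rng.shuffle(chunk_ids)  # new chunk order every epoch
    order = []
    for chunk_id in chunk_ids:
        indices = list(range(chunk_sizes[chunk_id]))
        rng.shuffle(indices)  # new sample order inside the chunk
        order.extend((chunk_id, i) for i in indices)
    return order


if __name__ == "__main__":
    rng = random.Random(0)
    sizes = {0: 5, 1: 3}
    for epoch in range(2):
        print(epoch, epoch_order(sizes, rng))
```

Calling `epoch_order` once per epoch with the same `rng` gives a fresh order each time while still loading every chunk exactly once per epoch.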

Claque answered 1/9, 2016 at 12:47

Comments:
Clarey: I understand we can use train_on_batch, but I still do not understand why the OP's original code would not work. Does fit() also update the model on each iteration over the data?
Barbette: Given the answer by @Rate below, it's unclear what the difference is between multiple calls to model.fit and model.train_on_batch. Is there one?
Claque: model.fit internally manages the inputs and outputs from the dataset you provide, splitting them into batches and training on each batch step by step, while also reporting progress and supporting custom callbacks that run during training. model.train_on_batch, on the other hand, simply takes one batch of inputs and outputs and trains the model for a single step.
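To make the difference Claque describes concrete, here is a toy pure-Python analogue (not Keras itself): a one-weight linear model whose `fit` simply slices the data into batches and calls `train_on_batch` once per batch, for each epoch. With no shuffling, one `fit` call ends with exactly the same weight as the equivalent hand-written `train_on_batch` loop. `ToyModel` and its plain-SGD update are assumptions for illustration; real Keras `fit` additionally handles shuffling, callbacks and progress reporting.

```python
class ToyModel:
    """Hypothetical one-weight model y = w * x trained with plain SGD on squared error."""

    def __init__(self, w=0.0, lr=0.01):
        self.w = w
        self.lr = lr

    def train_on_batch(self, xs, ys):
        # One optimization step on a single batch; returns the batch's mean squared error before the step.
        n = len(xs)
        grad = sum(2 * (self.w * x - y) * x for x, y in zip(xs, ys)) / n
        loss = sum((self.w * x - y) ** 2 for x, y in zip(xs, ys)) / n
        self.w -= self.lr * grad
        return loss

    def fit(self, xs, ys, batch_size, epochs):
        # fit = split into batches, then one train_on_batch step per batch, for each epoch.
        for _ in range(epochs):
            for start in range(0, len(xs), batch_size):
                self.train_on_batch(xs[start:start + batch_size], ys[start:start + batch_size])
        return self


if __name__ == "__main__":
    xs = [0.1 * i for i in range(20)]
    ys = [3.0 * x for x in xs]
    a = ToyModel().fit(xs, ys, batch_size=4, epochs=5)
    b = ToyModel()
    for _ in range(5):
        for start in range(0, len(xs), 4):
            b.train_on_batch(xs[start:start + 4], ys[start:start + 4])
    print(a.w == b.w)
```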

This question was raised in the Keras GitHub repository as Issue #4446, "Quick Question: can a model be fit for multiple times?" It was closed by François Chollet with the following statement:

Yes, successive calls to fit will incrementally train the model.

So, yes, you can call fit multiple times.
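A toy pure-Python illustration of why this holds (a hypothetical stand-in, not Keras): the weight is set once at construction and `fit` only ever updates it from its current value, so a second call starts where the first one stopped. With one epoch per call and chunk sizes that are multiples of the batch size, fitting two chunks in sequence gives exactly the same weight as one pass over the whole dataset.

```python
class ToyModel:
    """Hypothetical one-weight linear model y = w * x, trained by mini-batch SGD."""

    def __init__(self, lr=0.01):
        self.w = 0.0  # set once at construction; fit() never resets it
        self.lr = lr

    def fit(self, xs, ys, batch_size, epochs):
        for _ in range(epochs):
            for s in range(0, len(xs), batch_size):
                bx, by = xs[s:s + batch_size], ys[s:s + batch_size]
                grad = sum(2 * (self.w * x - y) * x for x, y in zip(bx, by)) / len(bx)
                self.w -= self.lr * grad  # each update starts from the current w
        return self


if __name__ == "__main__":
    xs = [0.1 * i for i in range(32)]
    ys = [3.0 * x for x in xs]

    chunked = ToyModel()
    chunked.fit(xs[:16], ys[:16], batch_size=8, epochs=1)  # first "chunk"
    chunked.fit(xs[16:], ys[16:], batch_size=8, epochs=1)  # continues from the first chunk's weight

    whole = ToyModel().fit(xs, ys, batch_size=8, epochs=1)
    print(chunked.w == whole.w)
```

With `nb_epoch=30` per call, as in the question, the updates are still cumulative, but they are ordered as 30 epochs on the first chunk followed by 30 epochs on the next, which is not the same sequence of updates as 30 epochs over the full dataset.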

Rate answered 14/1, 2018 at 18:41

Comments:
Barnyard: The statement "Yes, successive calls to fit will incrementally train the model" seems correct, but it does not match what I see when I train my model with successive calls to fit. The very first call takes a while to get up to my usual val_acc of 0.9x for my dataset, and each subsequent call is faster than that initial one, but every time fit is called, val_acc drops to around 0.05 before going back up to 90%. If the model is being trained incrementally, why is this happening?
Aurelie: I would like to hear an answer to this question too.