How to use repeat() function when building data in Keras?

I am training a binary classifier on a dataset of cats and dogs:
Total Dataset: 10000 images
Training Dataset: 8000 images
Validation/Test Dataset: 2000 images

The Jupyter notebook code:

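# Part 2 - Fitting the CNN to the images
# (`model` is the CNN built and compiled in Part 1, not shown here.)
from tensorflow.keras.preprocessing.image import ImageDataGenerator
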
train_datagen = ImageDataGenerator(rescale = 1./255,
                                   shear_range = 0.2,
                                   zoom_range = 0.2,
                                   horizontal_flip = True)

test_datagen = ImageDataGenerator(rescale = 1./255)

training_set = train_datagen.flow_from_directory('dataset/training_set',
                                                 target_size = (64, 64),
                                                 batch_size = 32,
                                                 class_mode = 'binary')

test_set = test_datagen.flow_from_directory('dataset/test_set',
                                            target_size = (64, 64),
                                            batch_size = 32,
                                            class_mode = 'binary')

history = model.fit_generator(training_set,
                              steps_per_epoch=8000,
                              epochs=25,
                              validation_data=test_set,
                              validation_steps=2000)

I trained it on a CPU without a problem, but when I run it on a GPU it throws this error:

Found 8000 images belonging to 2 classes.
Found 2000 images belonging to 2 classes.
WARNING:tensorflow:From <ipython-input-8-140743827a71>:23: Model.fit_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
Please use Model.fit, which supports generators.
WARNING:tensorflow:sample_weight modes were coerced from
  ...
    to  
  ['...']
WARNING:tensorflow:sample_weight modes were coerced from
  ...
    to  
  ['...']
Train for 8000 steps, validate for 2000 steps
Epoch 1/25
 250/8000 [..............................] - ETA: 21:50 - loss: 7.6246 - accuracy: 0.5000
WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 200000 batches). You may need to use the repeat() function when building your dataset.
 250/8000 [..............................] - ETA: 21:52 - loss: 7.6246 - accuracy: 0.5000

How do I use the repeat() function in Keras with TensorFlow 2.0?

Asked by Tannie, 3/3/2020 at 14:28. Comments (5):
Can you show the output of print(type(training_set))? – Frager
batch_size is set to 32, while steps_per_epoch is 8K, which (I guess) means that Keras will expect a total of 8K*32 samples, which you do not have. Try setting steps_per_epoch = 8K/32. – Giglio
print(type(training_set)) shows this output: <class 'keras_preprocessing.image.directory_iterator.DirectoryIterator'> – Tannie
@AlexKreimer Thanks, I could do that, but I would like to show all my training images to the model before moving to the next epoch, which is intuitively easy to understand. Also, how do I use the repeat() function? – Tannie
@Tannie steps_per_epoch is measured in batches (stats.stackexchange.com/questions/153531/…), so your model will go over all examples in a single epoch. repeat() belongs to a different world, tf.data (tensorflow.org/api_docs/python/tf/data/Dataset). Since your code does not seem to construct a tf.data dataset, I don't see how it would be useful in this case. – Giglio
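To illustrate what repeat() does in tf.data, here is a minimal sketch assuming TensorFlow 2.1 or later. The toy Dataset.range source is hypothetical and stands in for real data; it is not the question's image folders.

```python
import tensorflow as tf

# Toy source (not the question's images): integers 0..9, batched by 4.
# One pass yields 3 batches of sizes 4, 4 and 2.
ds = tf.data.Dataset.range(10).batch(4)

# Without repeat(), iteration ends after a single pass.
one_pass = list(ds.as_numpy_iterator())
print(len(one_pass))  # 3

# repeat() with no argument restarts the dataset indefinitely;
# take(7) caps how many batches are drawn here.
looped = list(ds.repeat().take(7).as_numpy_iterator())
print([len(b) for b in looped])  # [4, 4, 2, 4, 4, 2, 4]
```

Calling batch() before repeat() keeps a short batch at the end of every pass. Reversing the order (repeat().batch(4)) would instead produce full batches that straddle pass boundaries.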

Your problem stems from the fact that steps_per_epoch and validation_steps are counted in batches, not samples, so they need to equal the total number of data points divided by the batch_size.
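The figures from the question can be checked with plain Python arithmetic. Every number below comes from the question itself. The 250-batch result is consistent with training stopping at 250/8000 in the log.

```python
import math

# Figures taken from the question.
n_train = 8000          # training images
batch_size = 32         # batch_size passed to flow_from_directory()
epochs = 25
steps_per_epoch = 8000  # value used in the question (an image count, not a batch count)

# steps_per_epoch counts batches, so one "epoch" asked for this many images:
images_requested = steps_per_epoch * batch_size
print(images_requested)  # 256000

# One pass over the training folder only yields this many batches:
batches_per_pass = math.ceil(n_train / batch_size)
print(batches_per_pass)  # 250

# The total quoted in the warning is steps_per_epoch * epochs batches:
print(steps_per_epoch * epochs)  # 200000

# Counting steps in batches covers every image once per epoch:
print(n_train // batch_size)  # 250
```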

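Counting in samples was the Keras 1.X convention (prior to August 2017): fit_generator() then took samples_per_epoch, so passing 8000 would have been correct there.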

Change your model.fit_generator() call to:

batch_size = 32  # must match the batch_size passed to flow_from_directory()

history = model.fit_generator(training_set,
                              steps_per_epoch=int(8000/batch_size),
                              epochs=25,
                              validation_data=test_set,
                              validation_steps=int(2000/batch_size))

As of TensorFlow 2.1, fit_generator() is deprecated. Model.fit() accepts generators, and Keras Sequence objects such as a DirectoryIterator, directly.

TensorFlow >= 2.1 code:

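# training_set and test_set are DirectoryIterators, which have no .repeat() method
# (that method belongs to tf.data.Dataset), so pass them to fit() directly.
history = model.fit(training_set,
                    steps_per_epoch=int(8000/batch_size),
                    epochs=25,
                    validation_data=test_set,
                    validation_steps=int(2000/batch_size))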

Notice that, for positive integers, int(8000/batch_size) is equivalent to 8000 // batch_size (integer division).
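Both forms round down. This matters when the sample count is not a multiple of the batch size, as with the 2000 validation images. The short check below compares truncation, floor division and rounding up, using the question's batch size of 32.

```python
import math

batch_size = 32

for n_images in (8000, 2000):
    truncated = int(n_images / batch_size)         # the answer's form
    floored = n_images // batch_size               # integer division; same for positive ints
    rounded_up = math.ceil(n_images / batch_size)  # also counts a final partial batch
    print(n_images, truncated, floored, rounded_up)
# 8000 250 250 250
# 2000 62 62 63
```

With validation_steps=62, validation covers 62 * 32 = 1984 of the 2000 images, and the last 16 are skipped. Rounding up to 63 includes them. A DirectoryIterator should also report this rounded-up count through len(test_set).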

Answered by Strobile, 3/3/2020 at 14:48. Comments (9):
@TimbusCalin I'm currently facing the same issue. I tried the answer you gave, but the issue persists with the same error, 'Your input ran out of data'. Is it possible to use steps_per_epoch=int(training_set/batch_size)? – Subkingdom
That's another issue. It happens because you use tf.data.Dataset() and forgot to put .repeat() on the training set and the validation set. – Strobile
@TimbusCalin I'm not using tf.data.Dataset(), hence I didn't put .repeat() anywhere. I read some other posts that recommended using .repeat(), but they never mentioned where it should be done. Should this be done when creating the iterators or when fitting the model? Sorry for asking so much, but this question is the closest to what I was searching for and, as you can tell, I'm new to TensorFlow. – Subkingdom
You can call it on the dataset you pass to fit(), e.g. training_set.repeat(), provided training_set is a tf.data.Dataset. – Strobile
@TimbusCalin I am experiencing the same issue, "Your input ran out of data". It can indeed be solved by this answer. I am wondering, though: why does this happen only when I train on the GPU and not on the CPU (as mentioned in the OP)? – Encephalic
You could also omit the steps_per_epoch parameter, and Keras will figure out the number of steps from the batch_size. – Petcock
@Petcock You can omit it if your DataGenerator subclasses the Sequence class. It is optional for Sequence but not for other data generators. – Strobile
For me, generator.repeat() raises AttributeError: 'generator' object has no attribute 'repeat'. – Mcginn
Why do you need to use repeat() here? – Venusian
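Where repeat() does belong is an input built with tf.data.Dataset, the case raised in the comments. It does not apply to a DirectoryIterator or a plain Python generator. The sketch below assumes TensorFlow 2.x. The random arrays, their sizes and the one-layer model are hypothetical stand-ins for the cat/dog images and the question's CNN.

```python
import math

import numpy as np
import tensorflow as tf

# Hypothetical in-memory data: 100 training and 40 validation samples of 8 features.
rng = np.random.default_rng(0)
x_train = rng.random((100, 8)).astype("float32")
y_train = rng.integers(0, 2, size=(100, 1)).astype("float32")
x_val = rng.random((40, 8)).astype("float32")
y_val = rng.integers(0, 2, size=(40, 1)).astype("float32")

batch_size = 32

# repeat() makes each pipeline restart when exhausted, so fit() can draw
# steps_per_epoch batches in every epoch without running out of data.
train_ds = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
            .shuffle(100)
            .batch(batch_size)
            .repeat())
val_ds = (tf.data.Dataset.from_tensor_slices((x_val, y_val))
          .batch(batch_size)
          .repeat())

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Steps are counted in batches: ceil(100 / 32) = 4 and ceil(40 / 32) = 2.
history = model.fit(train_ds,
                    steps_per_epoch=math.ceil(len(x_train) / batch_size),
                    epochs=3,
                    validation_data=val_ds,
                    validation_steps=math.ceil(len(x_val) / batch_size),
                    verbose=0)
```

A repeated dataset never ends on its own. This is why steps_per_epoch and validation_steps are required here: without them, the first epoch would run until the dataset is exhausted, which never happens. A finite Sequence, such as the question's DirectoryIterator, reports its own length instead. That is why Keras can infer the step count for it.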
