tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time
I am trying to tune a CNN model using RandomSearch, but training is very slow and Keras logs this warning: `tensorflow:Callback method on_train_batch_end is slow compared to the batch time`. I am running my code in Google Colab with hardware acceleration set to GPU. This is my code:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.optimizers import Adam
from keras_tuner import RandomSearch

def model_builder(hp):
    model=Sequential([
        Conv2D(filters=hp.Int('conv_1_filter',min_value=32,max_value=128,step=32),
               kernel_size=hp.Int('conv_1_kernel',min_value=2,max_value=3,step=1),
               activation='relu',
               padding='same',
               input_shape=(200,200,3)),
        MaxPooling2D(pool_size=(2,2),strides=(2,2)),
        
        Conv2D(filters=hp.Int('conv_2_filter',min_value=32,max_value=128,step=32),
               kernel_size=hp.Int('conv_2_kernel',min_value=2,max_value=3,step=1),
               padding='same',
               activation='relu'),
        MaxPooling2D(pool_size=(2,2),strides=(2,2)),
        
        Flatten(),
        
        Dense(units=hp.Int('dense_1_units',min_value=32,max_value=512,step=128),
              activation='relu'),
        
        Dense(units=10,
              activation='softmax')
               
    ])
    
    model.compile(optimizer=Adam(hp.Choice('learning_rate',values=[1e-1,1e-3,3e-2])),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

Then I run RandomSearch and fit:

tuner=RandomSearch(model_builder,
                   objective='val_accuracy',
                   max_trials=2,
                   directory='projects',
                   project_name='Hypercars CNN'
                  )
tuner.search(X_train,Y_train,epochs=2,validation_split=0.2)
Domoniquedomph answered 9/2, 2021 at 4:56
This warning is raised when the operations that run at the end of each batch (callbacks, logging, and other bookkeeping) take longer than the batch itself. Usually it means your batches are very small, so any fixed per-batch overhead is slow in comparison to the actual training step.

Increasing the batch size should solve this. Alternatively, you can pass `use_multiprocessing=True` to `model.fit()` and choose an appropriate number of `workers` to generate your training batches more efficiently, but this only applies when the input is a generator or a `keras.utils.Sequence`.

Two threads talking about this issue:

  1. Thread 1
  2. Thread 2
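
The first suggestion (increasing the batch size) is just the `batch_size` argument to `fit`. A minimal sketch with a toy model, where the shapes and sizes are illustrative rather than taken from the question:

```python
import numpy as np
import tensorflow as tf

# Dummy data standing in for real training arrays.
x = np.random.rand(256, 8).astype("float32")
y = np.random.randint(0, 2, size=(256,)).astype("int64")

model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation="relu"),
    tf.keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# With a tiny batch size, each training step is so cheap that the
# per-batch callback overhead can exceed the step time and trigger
# the warning; a larger batch_size makes each step do more work
# relative to that fixed overhead.
history = model.fit(x, y, batch_size=128, epochs=1, verbose=0)
```

The same idea applies to `tuner.search(...)` in the question, since it forwards `batch_size` (and other `fit` keyword arguments) to `model.fit`.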
Launderette answered 9/2, 2021 at 7:1
Comment: use_multiprocessing = True – Fanatical
`use_multiprocessing = True` can remove that warning, but another warning then pops up about using multiprocessing in TF2.

Strow answered 14/7, 2022 at 12:37
