I am following the Transfer learning and fine-tuning guide on the official TensorFlow website. It points out that during fine-tuning, batch normalization layers should be in inference mode:
> **Important notes about `BatchNormalization` layer**
>
> Many image models contain `BatchNormalization` layers. That layer is a special case on every imaginable count. Here are a few things to keep in mind.
>
> - `BatchNormalization` contains 2 non-trainable weights that get updated during training. These are the variables tracking the mean and variance of the inputs.
> - When you set `bn_layer.trainable = False`, the `BatchNormalization` layer will run in inference mode, and will not update its mean & variance statistics. This is not the case for other layers in general, as weight trainability & inference/training modes are two orthogonal concepts. But the two are tied in the case of the `BatchNormalization` layer.
> - When you unfreeze a model that contains `BatchNormalization` layers in order to do fine-tuning, you should keep the `BatchNormalization` layers in inference mode by passing `training=False` when calling the base model. Otherwise the updates applied to the non-trainable weights will suddenly destroy what the model has learned.
>
> You'll see this pattern in action in the end-to-end example at the end of this guide.
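For reference, the end-to-end pattern the guide is describing looks roughly like this (a minimal sketch with a hypothetical ResNet50 base and input shape, not my actual model):

```python
import tensorflow as tf

# Hypothetical base model and input shape, for illustration only.
base_model = tf.keras.applications.ResNet50(include_top=False, weights='imagenet')
base_model.trainable = True  # unfreeze the weights for fine-tuning

inputs = tf.keras.Input(shape=(224, 224, 3))
# training=False keeps the BatchNormalization layers in inference mode:
# they normalize with their stored moving mean/variance and never update it.
x = base_model(inputs, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
model = tf.keras.Model(inputs, outputs)
```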
Even so, some other sources, for example this article (titled *Transfer Learning with ResNet*), say something completely different:
```python
for layer in resnet_model.layers:
    if isinstance(layer, BatchNormalization):
        layer.trainable = True
    else:
        layer.trainable = False
```
Anyway, I know that there is a difference between the `training` and `trainable` parameters in TensorFlow.
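To spell out how I understand that distinction, here is a toy sketch (not my actual code):

```python
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
x = tf.random.normal((8, 4))

# `training` is a call argument: it selects the behavior of a single forward pass.
y = bn(x, training=True)   # normalize with batch statistics, update moving averages
y = bn(x, training=False)  # normalize with the stored moving mean/variance

# `trainable` is a layer attribute: it controls whether the layer's weights
# (gamma/beta here) receive gradient updates during fit(). For most layers the
# two are independent; BatchNormalization is the special case where setting
# trainable = False also forces inference behavior, as the guide says.
bn.trainable = False
```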
I am loading my model from a file, like so:

```python
model = tf.keras.models.load_model(path)
```
And I am unfreezing some of the top layers (or rather, freezing all the others) in this way:

```python
model.trainable = True
for layer in model.layers:
    if layer not in model.layers[idx:]:  # freeze everything below index idx
        layer.trainable = False
```
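As a sanity check, I print the flags afterwards to confirm which layers ended up frozen (`idx` is just whatever cut-off index I chose):

```python
for i, layer in enumerate(model.layers):
    print(i, layer.name, layer.trainable)
```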
Now, about the batch normalization layers, I can either do:
```python
for layer in model.layers:
    if isinstance(layer, keras.layers.BatchNormalization):
        layer.trainable = False
```
or
```python
for layer in model.layers:
    if layer.name.startswith('bn'):
        layer.call(layer.input, training=False)
```
Which one should I do? And, in the end, is it better to freeze the batch norm layers or not?