How to freeze batch-norm layers during Transfer-learning

I am following the Transfer learning and fine-tuning guide on the official TensorFlow website. It points out that during fine-tuning, batch normalization layers should be in inference mode:

Important notes about BatchNormalization layer

Many image models contain BatchNormalization layers. That layer is a special case on every imaginable count. Here are a few things to keep in mind.

  • BatchNormalization contains 2 non-trainable weights that get updated during training. These are the variables tracking the mean and variance of the inputs.
  • When you set bn_layer.trainable = False, the BatchNormalization layer will run in inference mode, and will not update its mean & variance statistics. This is not the case for other layers in general, as weight trainability & inference/training modes are two orthogonal concepts. But the two are tied in the case of the BatchNormalization layer.
  • When you unfreeze a model that contains BatchNormalization layers in order to do fine-tuning, you should keep the BatchNormalization layers in inference mode by passing training=False when calling the base model. Otherwise the updates applied to the non-trainable weights will suddenly destroy what the model has learned.

You'll see this pattern in action in the end-to-end example at the end of this guide.
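For reference, the pattern the guide describes looks roughly like this (just a sketch; the base model, input shape, and classification head below are placeholders, not my actual setup):

import tensorflow as tf
from tensorflow import keras

base_model = keras.applications.ResNet50V2(include_top=False, weights='imagenet')
base_model.trainable = True  # unfreeze the weights for fine-tuning

inputs = keras.Input(shape=(224, 224, 3))
# training=False keeps the BatchNormalization layers in inference mode,
# so their moving mean/variance are not updated during fine-tuning.
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(10, activation='softmax')(x)
model = keras.Model(inputs, outputs)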

However, some other sources, for example this article (titled Transfer Learning with ResNet), say something completely different:

for layer in resnet_model.layers:
    if isinstance(layer, BatchNormalization):
        layer.trainable = True
    else:
        layer.trainable = False

Anyway, I know that there is a difference between the training argument and the trainable attribute in TensorFlow.

I am loading my model from a file, like so:

model = tf.keras.models.load_model(path)

And I am unfreezing some of the top layers (or rather, freezing everything below them) in this way:

model.trainable = True

# Keep only the layers from index `idx` onward trainable; freeze everything below.
for layer in model.layers:
    if layer not in model.layers[idx:]:
        layer.trainable = False

Now, about the batch normalization layers: I can either do

for layer in model.layers:
    if isinstance(layer, keras.layers.BatchNormalization):
        layer.trainable = False

or

for layer in model.layers:
    if layer.name.startswith('bn'):
        layer.call(layer.input, training=False)

Which one should I do? And, in the end, is it better to freeze the batch norm layers or not?

Weightless answered 8/6, 2021 at 11:10

Not sure about the training vs trainable difference, but personally I've gotten good results setting trainable = False.

Now, as to whether to freeze them in the first place: I've had good results with not freezing them. The reasoning is simple: the batch norm layer learns the moving averages (mean and variance) of the initial training data. This may be cats, dogs, humans, cars, etc. But when you're transfer learning, you could be moving to a completely different domain. The moving averages of this new domain of images are far different from those of the prior dataset.

By unfreezing those layers and freezing the CNN layers, my model saw a 6-7% increase in accuracy (roughly 82% -> 89%). My dataset was far different from the initial ImageNet dataset that EfficientNet was trained on.

P.S. Depending on how you plan on running the model post-training, I would advise you to freeze the batch norm layers once the model is trained. For some reason, if you ran the model online (1 image at a time), the batch norm would get all funky and give irregular results. Freezing them post-training fixed the issue for me.
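A sketch of that post-training freeze (assuming model is your fine-tuned, flat Keras model like the one in the question; the optimizer and loss are placeholders):

from tensorflow import keras

# After fine-tuning, put every BatchNormalization layer back in inference mode
# so single-image (online) predictions use the stored moving statistics.
for layer in model.layers:
    if isinstance(layer, keras.layers.BatchNormalization):
        layer.trainable = False

# Re-compile so the change is picked up if you keep training or evaluating.
model.compile(optimizer=keras.optimizers.Adam(1e-5),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])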

Luminary answered 9/10, 2021 at 5:47

Just to add to @luciano-dourado's answer:

In my case, I started by following the Transfer Learning guide as is, that is, freezing BN layers throughout the entire training (classifier + fine-tuning). What I saw is that training the classifier worked without problems, but as soon as I started fine-tuning, the loss went to NaN after a few batches.

After running the usual checks (input data without NaNs, loss function yielding correct values, etc.), I checked whether the BN layers were running in inference mode (trainable = False).

But in my case, the dataset was so different from ImageNet that I needed to do the opposite: set the trainable attribute of all BN layers to True. I found this empirically, just as @zwang commented. Just remember to freeze them after training, before you deploy the model for inference.

By the way, just as an informative note: ResNet50V2, for example, has a total of 49 BN layers, of which only 16 are pre-activation BNs. This means the remaining 33 layers were updating their mean and variance values.
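If you want to verify that count yourself, here is a quick sketch (instantiating ResNet50V2 without weights is enough to inspect the layer names and types):

from tensorflow import keras

base = keras.applications.ResNet50V2(weights=None, include_top=False)
bn_layers = [l for l in base.layers
             if isinstance(l, keras.layers.BatchNormalization)]
preact_bn = [l for l in bn_layers if 'preact_bn' in l.name]
print(len(bn_layers), len(preact_bn))  # total BN layers vs. pre-activation BNs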

Yet another case where one has to run several empirical tests to find out why the "standard" approach does not work in their case. I guess this further reinforces the importance of data in deep learning :)

Anyway answered 8/4, 2022 at 14:28

Use the code below to see whether the batch norm layers are frozen or not. It prints not only the layer names but also whether each one is trainable.

def print_layer_trainable(conv_model):
    for layer in conv_model.layers:
        print("{0}:\t{1}".format(layer.trainable, layer.name))

In my case I tested your method, but it did not freeze my model's batch norm layers:

for layer in model.layers:
    if isinstance(layer, keras.layers.BatchNormalization):
        layer.trainable = False

The code below worked nicely for me. In my case the model is a ResNetV2 and the batch norm layers are named with the suffix "preact_bn". By using the code above for printing layers you can see how the batch norm layers are named and configure the check as you want.

for layer in new_model.layers:
    if 'preact_bn' in layer.name:
        layer.trainable = False
    else:
        layer.trainable = True

Ulotrichous answered 15/2, 2022 at 18:49