I am trying to find out exactly how the BatchNormalization layer behaves in TensorFlow. I came up with the following piece of code, which to the best of my knowledge should be a perfectly valid Keras model; however, the mean and variance of BatchNormalization don't appear to be updated.
From the docs (https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization):
in the case of the BatchNormalization layer, setting trainable = False on the layer means that the layer will be subsequently run in inference mode (meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch).
I expect the model to return a different value with each subsequent predict call. What I see, however, are the exact same values returned 10 times. Can anyone explain why the BatchNormalization layer does not update its internal values?
import tensorflow as tf
import numpy as np

if __name__ == '__main__':
    np.random.seed(1)
    x = np.random.randn(3, 5) * 5 + 0.3

    bn = tf.keras.layers.BatchNormalization(trainable=False, epsilon=1e-9)
    z = inputs = tf.keras.layers.Input([5])
    z = bn(z)

    model = tf.keras.Model(inputs=inputs, outputs=z)

    for i in range(10):
        print(x)
        print(model.predict(x))
        print()
I use TensorFlow 2.1.0.
When fit() is called, BatchNormalization behaves similarly to when training=True is set, so the layer learns the mean and the variance of the current batch. When calling predict() or evaluate(), the layer is locked (similar to setting training=False) and the layer normalizes its output using a moving average of the mean and standard deviation of the batches it has seen during training. – Loya
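To make the comment above concrete, here is a minimal NumPy sketch of the two modes. It is not TensorFlow's implementation: `TinyBatchNorm` is a hypothetical toy class, it leaves out the learnable gamma/beta parameters, and it assumes the Keras defaults of moving mean 0, moving variance 1 and momentum 0.99. It only illustrates that inference-mode calls read the moving statistics without writing them, which would explain identical outputs from repeated predict() calls.

```python
import numpy as np


class TinyBatchNorm:
    """Toy stand-in for a batch-normalization layer (hypothetical, NumPy only)."""

    def __init__(self, num_features, momentum=0.99, epsilon=1e-3):
        self.momentum = momentum
        self.epsilon = epsilon
        # Assumed starting point, matching the Keras defaults: mean 0, variance 1.
        self.moving_mean = np.zeros(num_features)
        self.moving_var = np.ones(num_features)

    def __call__(self, x, training=False):
        if training:
            # Training mode: normalize with the statistics of this batch ...
            mean, var = x.mean(axis=0), x.var(axis=0)
            # ... and move the stored statistics a small step toward them.
            m = self.momentum
            self.moving_mean = m * self.moving_mean + (1 - m) * mean
            self.moving_var = m * self.moving_var + (1 - m) * var
        else:
            # Inference mode: read the moving statistics, never write them.
            mean, var = self.moving_mean, self.moving_var
        return (x - mean) / np.sqrt(var + self.epsilon)


rng = np.random.default_rng(1)
x = rng.standard_normal((3, 5)) * 5 + 0.3

bn = TinyBatchNorm(num_features=5, epsilon=1e-9)

# Repeated inference calls leave the state untouched, so the outputs repeat.
first = bn(x, training=False)
second = bn(x, training=False)
print("same output on repeated inference:", np.array_equal(first, second))

# With the initial statistics (mean 0, variance 1) the output is roughly the input.
print("output close to input:", np.allclose(first, x))

# A single training-mode call changes the statistics, and with them the
# result of the next inference call.
bn(x, training=True)
third = bn(x, training=False)
print("output changed after a training step:", not np.array_equal(first, third))
```

In this sketch only the call with training=True changes the stored statistics. In the question's code the layer is never run in training mode (it is created with trainable=False and only predict() is called), so by the same logic the moving mean and variance would stay at their initial values and each predict() call would return the same result.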