I'm wondering what the currently available options are for simulating BatchNorm folding during quantization-aware training in TensorFlow 2. TensorFlow 1 has the tf.contrib.quantize.create_training_graph function, which inserts FakeQuantization layers into the graph and takes care of simulating batch normalization folding (according to this white paper).
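For context, by BatchNorm folding I mean merging the BN statistics into the preceding convolution's weights before the fake quantization is applied, roughly as described in the white paper. A small numpy sketch of my understanding (the function and variable names are mine, not from any API):

import numpy as np

def fold_bn_into_conv(W, b, gamma, beta, moving_mean, moving_var, eps=1e-3):
    # BN(conv(x)) == conv_fold(x), with the BN scale absorbed per output channel:
    #   W_fold = gamma * W / sqrt(var + eps)
    #   b_fold = beta + gamma * (b - mean) / sqrt(var + eps)
    scale = gamma / np.sqrt(moving_var + eps)
    W_fold = W * scale              # W: (kh, kw, in, out); scale broadcasts over the out axis
    b_fold = beta + (b - moving_mean) * scale
    return W_fold, b_fold

The fake-quant ops then observe W_fold rather than the raw kernel, which is what create_training_graph handles in TF1.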
TensorFlow 2 has a tutorial on how to use quantization in its recently adopted tf.keras API, but the tutorial doesn't mention anything about batch normalization. I tried the following simple example with a BatchNorm layer:
import tensorflow as tf
import tensorflow_model_optimization as tfmo

l = tf.keras.layers
input_shape = (28, 28, 1)  # e.g. MNIST-sized inputs
num_classes = 10

model = tf.keras.Sequential([
    l.Conv2D(32, 5, padding='same', activation='relu', input_shape=input_shape),
    l.MaxPooling2D((2, 2), (2, 2), padding='same'),
    l.Conv2D(64, 5, padding='same', activation='relu'),
    l.BatchNormalization(),  # BN!
    l.MaxPooling2D((2, 2), (2, 2), padding='same'),
    l.Flatten(),
    l.Dense(1024, activation='relu'),
    l.Dropout(0.4),
    l.Dense(num_classes),
    l.Softmax(),
])

model = tfmo.quantization.keras.quantize_model(model)
However, this gives the following exception:
RuntimeError: Layer batch_normalization:<class 'tensorflow.python.keras.layers.normalization.BatchNormalization'> is not supported. You can quantize this layer by passing a `tfmot.quantization.keras.QuantizeConfig` instance to the `quantize_annotate_layer` API.
which indicates that TF does not know what to do with it.
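From the error message, it sounds like I'm expected to wrap the layer with quantize_annotate_layer and supply a QuantizeConfig. An untested sketch of what that would look like is below (the no-op config is my own guess), but as far as I can tell this only controls how the BatchNorm layer itself is quantized and doesn't fold it into the preceding convolution:

import tensorflow as tf
import tensorflow_model_optimization as tfmot

quant = tfmot.quantization.keras

# A placeholder config that quantizes nothing in the wrapped layer.
# This is only meant to get past the error; it does not fold anything.
class NoOpQuantizeConfig(quant.QuantizeConfig):
    def get_weights_and_quantizers(self, layer):
        return []

    def get_activations_and_quantizers(self, layer):
        return []

    def set_quantize_weights(self, layer, quantize_weights):
        pass

    def set_quantize_activations(self, layer, quantize_activations):
        pass

    def get_output_quantizers(self, layer):
        return []

    def get_config(self):
        return {}

l = tf.keras.layers
annotated = tf.keras.Sequential([
    l.Conv2D(32, 5, padding='same', activation='relu', input_shape=(28, 28, 1)),
    quant.quantize_annotate_layer(l.BatchNormalization(), NoOpQuantizeConfig()),
    l.Flatten(),
    l.Dense(10),
])

# quantize_apply needs the custom config registered in the quantize scope.
with quant.quantize_scope({'NoOpQuantizeConfig': NoOpQuantizeConfig}):
    model = quant.quantize_apply(annotated)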
I also saw this related topic where they apply tf.contrib.quantize.create_training_graph to a Keras-constructed model (a rough sketch of that approach is below). However, they don't use BatchNorm layers, so I'm not sure this will work.
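For reference, the approach from that topic looks roughly like this (an untested sketch under TF 1.x, where the graph rewrite happens behind the Keras session):

import tensorflow as tf  # TF 1.x

l = tf.keras.layers
model = tf.keras.Sequential([
    l.Conv2D(32, 5, padding='same', activation='relu', input_shape=(28, 28, 1)),
    l.BatchNormalization(),
    l.Flatten(),
    l.Dense(10, activation='softmax'),
])

# Rewrite the graph behind the Keras session to insert FakeQuant ops
# (and fold BatchNorm) before compiling and training.
sess = tf.keras.backend.get_session()
tf.contrib.quantize.create_training_graph(input_graph=sess.graph, quant_delay=0)

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])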
So what are the options for using this BatchNorm folding feature in TF2? Can this be done from the Keras API, or should I switch back to the TensorFlow 1 API and define a graph the old way?