When using `load_weights()` and `save_weights()` with a nested model, it's very easy to get an error if the `trainable` settings are not the same. To avoid the error, make sure you freeze the same layers before calling `model.load_weights()`. That is, if the weight file was saved with all layers frozen, the procedure will be:
- Recreate the model
- Freeze all layers in `base_model`
- Load the weights
- Unfreeze those layers you want to train (in this case, `base_model.layers[-26:]`)
For example:
from keras.applications.resnet50 import ResNet50
from keras.models import Sequential
from keras.layers import Flatten, Dense

# Recreate the model
base_model = ResNet50(include_top=False, input_shape=(224, 224, 3))
model = Sequential()
model.add(base_model)
model.add(Flatten())
model.add(Dense(80, activation="softmax"))

# Freeze all layers in base_model, matching the state in which the weights were saved
for layer in base_model.layers:
    layer.trainable = False
model.load_weights('all_layers_freezed.h5')

# Unfreeze the layers you want to fine-tune
for layer in base_model.layers[-26:]:
    layer.trainable = True
The underlying reason:
When you call `model.load_weights()`, the weights for each layer are (roughly) loaded by the following steps (in the function `load_weights_from_hdf5_group()` in `topology.py`; a rough sketch follows the list):
- Call `layer.weights` to get the weight tensors
- Match each weight tensor with its corresponding weight value in the hdf5 file
- Call `K.batch_set_value()` to assign the weight values to the weight tensors
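A minimal sketch of what those three steps amount to for a single layer (the helper name and the `saved_values` argument are made up for illustration; this is not the actual Keras source):

import keras.backend as K

def load_one_layer(layer, saved_values):
    # Step 1: collect the symbolic weight tensors (trainable ones come first)
    symbolic_weights = layer.weights
    # Step 2: pair each tensor with a saved value purely by position
    weight_value_tuples = list(zip(symbolic_weights, saved_values))
    # Step 3: push the values into the tensors
    K.batch_set_value(weight_value_tuples)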
If your model is a nested model, you have to be careful about `trainable` because of Step 1. I'll use an example to explain it. For the same model as above, `model.summary()` gives:
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
resnet50 (Model)             (None, 1, 1, 2048)        23587712
_________________________________________________________________
flatten_10 (Flatten)         (None, 2048)               0
_________________________________________________________________
dense_5 (Dense)              (None, 80)                 163920
=================================================================
Total params: 23,751,632
Trainable params: 11,202,640
Non-trainable params: 12,548,992
_________________________________________________________________
The inner `ResNet50` model is treated as a layer of `model` during weight loading. When loading the layer `resnet50`, in Step 1, calling `layer.weights` is equivalent to calling `base_model.weights`. The list of weight tensors for all layers in the `ResNet50` model will be collected and returned.
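You can see the nesting directly; a quick check, assuming the Sequential model built above (the expected results are shown as comments):

print(model.layers)                   # only three layers: the resnet50 Model, Flatten, Dense
print(model.layers[0] is base_model)  # True, so loading 'resnet50' means walking base_model.weights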
Now the problem is that, when constructing the list of weight tensors, trainable weights will come before non-trainable weights. In the definition of the `Layer` class:
@property
def weights(self):
    return self.trainable_weights + self.non_trainable_weights
If all layers in `base_model` are frozen, the weight tensors will be in the following order:
for layer in base_model.layers:
    layer.trainable = False
print(base_model.weights)
[<tf.Variable 'conv1/kernel:0' shape=(7, 7, 3, 64) dtype=float32_ref>,
<tf.Variable 'conv1/bias:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bn_conv1/gamma:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bn_conv1/beta:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bn_conv1/moving_mean:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bn_conv1/moving_variance:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'res2a_branch2a/kernel:0' shape=(1, 1, 64, 64) dtype=float32_ref>,
<tf.Variable 'res2a_branch2a/bias:0' shape=(64,) dtype=float32_ref>,
...
<tf.Variable 'res5c_branch2c/kernel:0' shape=(1, 1, 512, 2048) dtype=float32_ref>,
<tf.Variable 'res5c_branch2c/bias:0' shape=(2048,) dtype=float32_ref>,
<tf.Variable 'bn5c_branch2c/gamma:0' shape=(2048,) dtype=float32_ref>,
<tf.Variable 'bn5c_branch2c/beta:0' shape=(2048,) dtype=float32_ref>,
<tf.Variable 'bn5c_branch2c/moving_mean:0' shape=(2048,) dtype=float32_ref>,
<tf.Variable 'bn5c_branch2c/moving_variance:0' shape=(2048,) dtype=float32_ref>]
However, if some layers are trainable, the weight tensors of the trainable layers will come before those of the frozen ones:
for layer in base_model.layers[-5:]:
    layer.trainable = True
print(base_model.weights)
[<tf.Variable 'res5c_branch2c/kernel:0' shape=(1, 1, 512, 2048) dtype=float32_ref>,
<tf.Variable 'res5c_branch2c/bias:0' shape=(2048,) dtype=float32_ref>,
<tf.Variable 'bn5c_branch2c/gamma:0' shape=(2048,) dtype=float32_ref>,
<tf.Variable 'bn5c_branch2c/beta:0' shape=(2048,) dtype=float32_ref>,
<tf.Variable 'conv1/kernel:0' shape=(7, 7, 3, 64) dtype=float32_ref>,
<tf.Variable 'conv1/bias:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bn_conv1/gamma:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bn_conv1/beta:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bn_conv1/moving_mean:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bn_conv1/moving_variance:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'res2a_branch2a/kernel:0' shape=(1, 1, 64, 64) dtype=float32_ref>,
<tf.Variable 'res2a_branch2a/bias:0' shape=(64,) dtype=float32_ref>,
...
<tf.Variable 'bn5c_branch2b/moving_mean:0' shape=(512,) dtype=float32_ref>,
<tf.Variable 'bn5c_branch2b/moving_variance:0' shape=(512,) dtype=float32_ref>,
<tf.Variable 'bn5c_branch2c/moving_mean:0' shape=(2048,) dtype=float32_ref>,
<tf.Variable 'bn5c_branch2c/moving_variance:0' shape=(2048,) dtype=float32_ref>]
This change in order is why you got an error about tensor shapes: the weight values saved in the hdf5 file are matched to the wrong weight tensors in Step 2 above. Everything works fine when you freeze all layers because the checkpoint was also saved with all layers frozen, so the order matches.
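To make the mismatch concrete, here is a small illustration with plain NumPy, using the shapes from the two listings above:

import numpy as np

# First few values as saved (all layers frozen) vs. the tensors they now get paired with
saved_shapes   = [(7, 7, 3, 64), (64,), (64,), (64,)]            # conv1/kernel, conv1/bias, bn_conv1/gamma, bn_conv1/beta
current_shapes = [(1, 1, 512, 2048), (2048,), (2048,), (2048,)]  # res5c_branch2c/kernel, .../bias, bn5c_branch2c/gamma, .../beta

# Step 2 pairs them purely by position, so Step 3 attempts assignments like these:
for saved, current in zip(saved_shapes, current_shapes):
    print(saved, "->", current, "OK" if saved == current else "shape mismatch")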
Possibly better solution:
You can avoid a nested model by using the functional API: every ResNet layer then becomes a top-level layer of `model`, so weights are matched layer by layer rather than through one long list whose order depends on the `trainable` flags. For example, the following code should work without error:
from keras.applications.resnet50 import ResNet50
from keras.models import Model
from keras.layers import Flatten, Dense

input_size, input_channels = 224, 3   # as in the example above

# Build the model with the functional API and save with all layers frozen
base_model = ResNet50(include_top=False, weights="imagenet", input_shape=(input_size, input_size, input_channels))
x = Flatten()(base_model.output)
x = Dense(80, activation="softmax")(x)
model = Model(base_model.input, x)
for layer in base_model.layers:
    layer.trainable = False
model.save_weights("all_nontrainable.h5")

# Rebuild the model, freeze a different set of layers, and load the same weights
base_model = ResNet50(include_top=False, weights="imagenet", input_shape=(input_size, input_size, input_channels))
x = Flatten()(base_model.output)
x = Dense(80, activation="softmax")(x)
model = Model(base_model.input, x)
for layer in base_model.layers[:-26]:
    layer.trainable = False
model.load_weights("all_nontrainable.h5")