How to dynamically freeze weights after compiling model in Keras?

I would like to train a GAN in Keras. My final target is BEGAN, but I'm starting with the simplest one. Understanding how to freeze weights properly is necessary here, and that's what I'm struggling with.

While the generator is being trained, the discriminator weights must not be updated. I would like to freeze and unfreeze the discriminator alternately, so that the generator and the discriminator are trained in turns. The problem is that setting the trainable attribute to False on the discriminator model, or even on its weights, doesn't stop the model from training (the weights are still updated). On the other hand, when I compile the model after setting trainable to False, the weights can no longer be unfrozen. I can't recompile the model after each iteration, because that defeats the purpose of the whole training loop.

Because of this problem, it seems that many Keras GAN implementations are buggy, or they only work because of some non-intuitive trick in an older version.

Ocreate asked 17/7, 2017 at 21:47

I tried this example code a couple of months ago and it worked: https://github.com/fchollet/keras/blob/master/examples/mnist_acgan.py

It's not the simplest form of GAN, but as far as I remember, it's not too difficult to remove the classification loss and turn the model into a plain GAN.

You don't need to toggle the discriminator's trainable property and recompile. Simply create and compile two model objects: one with trainable=True (discriminator in the code) and another one with trainable=False (combined in the code, which stacks the generator and the discriminator and is compiled after setting discriminator.trainable = False).

When you're updating the discriminator, call discriminator.train_on_batch(). When you're updating the generator, call combined.train_on_batch().
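
A minimal sketch of that two-model setup, assuming tf.keras from TensorFlow 2.x. The toy Dense architectures, the sizes LATENT_DIM and DATA_DIM, and the helpers build_generator, build_discriminator and train_step are hypothetical; only discriminator and combined mirror names in the linked example. It relies on Keras keeping the trainable state that was in effect when each model was compiled, which is how Keras 2.x and tf.keras 2.x behave. On newer Keras releases, check that the discriminator's weights really change after discriminator.train_on_batch() before relying on it.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy sizes; a real GAN would use image-shaped data.
LATENT_DIM = 8
DATA_DIM = 4


def build_generator():
    # Hypothetical toy generator: noise vector -> fake sample.
    z = tf.keras.Input(shape=(LATENT_DIM,))
    h = tf.keras.layers.Dense(16, activation="relu")(z)
    x = tf.keras.layers.Dense(DATA_DIM)(h)
    return tf.keras.Model(z, x, name="generator")


def build_discriminator():
    # Hypothetical toy discriminator: sample -> probability of being real.
    x = tf.keras.Input(shape=(DATA_DIM,))
    h = tf.keras.layers.Dense(16, activation="relu")(x)
    p = tf.keras.layers.Dense(1, activation="sigmoid")(h)
    return tf.keras.Model(x, p, name="discriminator")


generator = build_generator()
discriminator = build_discriminator()

# Model 1: the discriminator on its own, compiled while it is trainable.
discriminator.compile(optimizer="adam", loss="binary_crossentropy")

# Model 2: generator followed by the discriminator. trainable is False when
# this model is compiled, so combined.train_on_batch() only updates the generator.
discriminator.trainable = False
z = tf.keras.Input(shape=(LATENT_DIM,))
combined = tf.keras.Model(z, discriminator(generator(z)), name="combined")
combined.compile(optimizer="adam", loss="binary_crossentropy")


def train_step(real_batch):
    """One alternating GAN update, with no trainable toggling or recompiling."""
    n = real_batch.shape[0]
    noise = np.random.normal(size=(n, LATENT_DIM)).astype("float32")
    fake_batch = generator.predict(noise, verbose=0)

    # Discriminator update: real samples labelled 1, generated ones 0.
    # Assumes Keras uses the trainable state captured at compile time
    # (Keras 2.x / tf.keras 2.x behaviour); verify this on newer versions.
    x = np.concatenate([real_batch, fake_batch])
    y = np.concatenate([np.ones((n, 1)), np.zeros((n, 1))]).astype("float32")
    d_loss = discriminator.train_on_batch(x, y)

    # Generator update: label the fakes as real so the generator learns to
    # fool the discriminator, whose weights stay fixed in this step.
    g_loss = combined.train_on_batch(noise, np.ones((n, 1), dtype="float32"))
    return d_loss, g_loss
```

Calling train_step(real_batch) once per iteration alternates the two updates; the only per-iteration choice is which compiled model receives the batch.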

Basanite answered 23/7, 2017 at 13:28

Can you use tf.stop_gradient to conditionally freeze weights?
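
A minimal illustration, assuming TensorFlow 2.x eager mode with tf.GradientTape (which postdates this thread), of why tf.stop_gradient is not the same as freezing: it turns the stopped tensor into a constant for differentiation, so no gradient at all reaches the variable behind it. The variable w is a hypothetical stand-in for any weight.

```python
import tensorflow as tf

w = tf.Variable(3.0)  # hypothetical stand-in for a model weight

# persistent=True lets the same tape be asked for two gradients.
with tf.GradientTape(persistent=True) as tape:
    y = 2.0 * w                            # ordinary path to w
    y_stopped = 2.0 * tf.stop_gradient(w)  # same forward value, path cut

grad = tape.gradient(y, w)                  # dy/dw = 2.0
grad_stopped = tape.gradient(y_stopped, w)  # None: no gradient reaches w
print(float(y_stopped), float(grad), grad_stopped)  # 6.0 2.0 None
```

The forward value is unchanged, but the gradient is gone rather than merely not applied, which is the distinction the asker draws in the comments.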

Ratable answered 17/7, 2017 at 22:36
tf.stop_gradient stops the gradient from flowing, and that's not what I want to achieve. I would like the gradient to flow and the gradients for the weights to be computed, but without performing the update operation. – Ocreate
Then you might be better off explicitly passing the list of variables you want to update to the TensorFlow update op, instead of freezing and unfreezing weights all the time. – Ratable
You are right, but that's a TensorFlow solution, and Keras doesn't allow it. You have a model abstraction, and you mainly have the fit and train_on_batch methods; that's all. If there is no solution in pure Keras, I'll switch to TensorFlow. – Ocreate
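
Ratable's suggestion of passing an explicit variable list can be sketched with tf.GradientTape in TensorFlow 2.x, which postdates this thread and works directly with Keras models. This is a swapped-in technique, not the pure-Keras train_on_batch route the asker wanted; the toy architectures, the sizes, and the helpers mlp, d_step and g_step are hypothetical. The generator step computes gradients through the discriminator but hands only generator.trainable_variables to the optimizer, so no trainable flag is ever toggled.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy sizes and architectures, for illustration only.
LATENT_DIM, DATA_DIM = 8, 4


def mlp(in_dim, out_dim, out_activation=None, name=None):
    inp = tf.keras.Input(shape=(in_dim,))
    h = tf.keras.layers.Dense(16, activation="relu")(inp)
    out = tf.keras.layers.Dense(out_dim, activation=out_activation)(h)
    return tf.keras.Model(inp, out, name=name)


generator = mlp(LATENT_DIM, DATA_DIM, name="generator")
discriminator = mlp(DATA_DIM, 1, "sigmoid", name="discriminator")
bce = tf.keras.losses.BinaryCrossentropy()

# One optimizer per network, so each one only ever sees its own variables.
g_opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
d_opt = tf.keras.optimizers.Adam(learning_rate=1e-3)


def d_step(real):
    noise = tf.random.normal((real.shape[0], LATENT_DIM))
    with tf.GradientTape() as tape:
        p_real = discriminator(real, training=True)
        p_fake = discriminator(generator(noise, training=True), training=True)
        loss = bce(tf.ones_like(p_real), p_real) + bce(tf.zeros_like(p_fake), p_fake)
    # Differentiate w.r.t. the discriminator's variables only.
    d_vars = discriminator.trainable_variables
    d_opt.apply_gradients(zip(tape.gradient(loss, d_vars), d_vars))
    return loss


def g_step(batch_size):
    noise = tf.random.normal((batch_size, LATENT_DIM))
    with tf.GradientTape() as tape:
        p_fake = discriminator(generator(noise, training=True), training=True)
        loss = bce(tf.ones_like(p_fake), p_fake)
    # Gradients flow back through the discriminator, but only the
    # generator's variables are passed to the optimizer.
    g_vars = generator.trainable_variables
    g_opt.apply_gradients(zip(tape.gradient(loss, g_vars), g_vars))
    return loss


# Usage: alternate the two steps on a (hypothetical) batch of real data.
real_batch = np.random.normal(size=(16, DATA_DIM)).astype("float32")
d_loss = d_step(real_batch)
g_loss = g_step(16)
```

Using a separate optimizer per network keeps each optimizer's slot variables tied to one fixed variable list, which is what makes the alternation safe without any recompiling.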

Maybe your adversarial net (generator plus discriminator) is written as a 'Model'. However, even if you set d.trainable=False, only the standalone d net becomes non-trainable; the d inside the whole adversarial net is still trainable.

Call d_on_g.summary() before and after setting d.trainable=False and you will see what I mean (pay attention to the trainable parameter count).
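
A small sketch of that before/after comparison, assuming tf.keras from TensorFlow 2.x; g, d and d_on_g are toy single-Dense stand-ins and the helper dense_model is hypothetical. In current tf.keras, setting d.trainable = False immediately drops d's weights from d_on_g.trainable_weights, the list behind the trainable count in summary(). What an already compiled d_on_g actually updates still depends on the trainable state at compile time, so the flag only takes effect for training after recompiling.

```python
import tensorflow as tf


def dense_model(in_dim, out_dim, name):
    # Hypothetical helper: a single Dense layer (one kernel + one bias).
    inp = tf.keras.Input(shape=(in_dim,))
    out = tf.keras.layers.Dense(out_dim)(inp)
    return tf.keras.Model(inp, out, name=name)


g = dense_model(8, 4, "g")  # toy generator
d = dense_model(4, 1, "d")  # toy discriminator
z = tf.keras.Input(shape=(8,))
d_on_g = tf.keras.Model(z, d(g(z)), name="d_on_g")

before = len(d_on_g.trainable_weights)  # g's kernel/bias + d's kernel/bias
d.trainable = False
after = len(d_on_g.trainable_weights)   # only g's kernel/bias remain
print(before, after)                    # 4 2
```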

Alburnum answered 5/11, 2017 at 14:50
