Updating Unrolled GAN to TF2

I am trying to implement the Unrolled GAN model as described here, with example code. However, it was implemented using TF1, and I have been doing my best to update it, but I am relatively new to Python and TF (I have only been using them for the past ~6 months).

The lines that I cannot seem to make work (for the moment; there may be more) are these:

gen_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, "generator")
disc_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, "discriminator")

These both return empty lists, and I cannot see what I am missing. Even without specifying a scope, get_collection() returns []. Earlier, the generator and discriminator are both defined under variable scopes, like so:

def generator(z, output_dim=2, n_hidden=128, n_layer=2):
    with tf.compat.v1.variable_scope("generator"):
        h = slim.stack(z, slim.fully_connected, [n_hidden] * n_layer, activation_fn=tf.nn.tanh)
        x = slim.fully_connected(h, output_dim, activation_fn=None)
    return x

def discriminator(x, n_hidden=128, n_layer=2, reuse=False):
    with tf.compat.v1.variable_scope("discriminator", reuse=reuse):
        h = slim.stack(x, slim.fully_connected, [n_hidden] * n_layer, activation_fn=tf.nn.tanh)
        log_d = slim.fully_connected(h, 1, activation_fn=None)
    return log_d

Is there a problem with the definition of the scope?

Here is my updated code in full, in case there is maybe something I missed elsewhere:

%pylab inline
from collections import OrderedDict
import tensorflow as tf
import tensorflow_probability as tfp
ds = tfp.distributions
# slim = tf.contrib.slim
import tf_slim as slim

from keras.optimizers import Adam

try:
    from moviepy.video.io.bindings import mplfig_to_npimage
    import moviepy.editor as mpy
    generate_movie = True
except:
    print("Warning: moviepy not found.")
    generate_movie = False


def remove_original_op_attributes(graph):
    """Remove _original_op attribute from all operations in a graph."""
    for op in graph.get_operations():
        op._original_op = None
        
def graph_replace(*args, **kwargs):
    """Monkey patch graph_replace so that it works with TF 1.0"""
    remove_original_op_attributes(tf.get_default_graph())
    return _graph_replace(*args, **kwargs)




def extract_update_dict(update_ops):
    """Extract variables and their new values from Assign and AssignAdd ops.
    
    Args:
        update_ops: list of Assign and AssignAdd ops, typically computed using Keras' opt.get_updates()

    Returns:
        dict mapping from variable values to their updated value
    """
    name_to_var = {v.name: v for v in tf.compat.v1.global_variables()}
    updates = OrderedDict()
    for update in update_ops:
        var_name = update.op.inputs[0].name
        var = name_to_var[var_name]
        value = update.op.inputs[1]
        if update.op.type == 'Assign':
            updates[var.value()] = value
        elif update.op.type == 'AssignAdd':
            updates[var.value()] = var + value
        else:
            raise ValueError("Update op type (%s) must be of type Assign or AssignAdd" % update.op.type)
    return updates



def sample_mog(batch_size, n_mixture=8, std=0.01, radius=1.0):
    thetas = np.linspace(0, 2 * np.pi, n_mixture)
    xs, ys = radius * np.sin(thetas), radius * np.cos(thetas)
    cat = ds.Categorical(tf.zeros(n_mixture))
    comps = [ds.MultivariateNormalDiag([xi, yi], [std, std]) for xi, yi in zip(xs.ravel(), ys.ravel())]
    data = ds.Mixture(cat, comps)
    return data.sample(batch_size)



def generator(z, output_dim=2, n_hidden=128, n_layer=2):
    with tf.compat.v1.variable_scope("generator"):
        h = slim.stack(z, slim.fully_connected, [n_hidden] * n_layer, activation_fn=tf.nn.tanh)
        x = slim.fully_connected(h, output_dim, activation_fn=None)
    return x

def discriminator(x, n_hidden=128, n_layer=2, reuse=False):
    with tf.compat.v1.variable_scope("discriminator", reuse=reuse):
        h = slim.stack(x, slim.fully_connected, [n_hidden] * n_layer, activation_fn=tf.nn.tanh)
        log_d = slim.fully_connected(h, 1, activation_fn=None)
    return log_d



params = dict(
    batch_size=512,
    disc_learning_rate=1e-4,
    gen_learning_rate=1e-3,
    beta1=0.5,
    epsilon=1e-8,
    max_iter=25000,
    viz_every=5000,
    z_dim=256,
    x_dim=2,
    unrolling_steps=5,
)


tf.compat.v1.reset_default_graph()

data = sample_mog(params['batch_size'])

noise = ds.Normal(tf.zeros(params['z_dim']), 
                  tf.ones(params['z_dim'])).sample(params['batch_size'])
# Construct generator and discriminator nets
# with slim.arg_scope([slim.fully_connected], weights_initializer=tf.orthogonal_initializer(gain=1.4)): ## old
with slim.arg_scope([slim.fully_connected], weights_initializer=tf.keras.initializers.Orthogonal(gain=1.4)):
    samples = generator(noise, output_dim=params['x_dim'])
    real_score = discriminator(data)
    fake_score = discriminator(samples, reuse=True)
    
# Saddle objective    
loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=tf.cast(real_score, dtype=tf.float32), labels=tf.cast(tf.ones_like(real_score), dtype=tf.float32)) +
    tf.nn.sigmoid_cross_entropy_with_logits(logits=tf.cast(fake_score, dtype=tf.float32), labels=tf.cast(tf.zeros_like(fake_score), dtype=tf.float32)))

gen_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, "generator")
disc_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, "discriminator")

# Vanilla discriminator update
d_opt = Adam(lr=params['disc_learning_rate'], beta_1=params['beta1'], epsilon=params['epsilon'])
# updates = d_opt.get_updates(disc_vars, [], loss) ## old
updates = d_opt.get_updates(loss, [])
d_train_op = tf.group(*updates, name="d_train_op")

### I HAVE NOT UPDATED BEYOND THIS POINT ###

# Unroll optimization of the discriminator
if params['unrolling_steps'] > 0:
    # Get dictionary mapping from variables to their update value after one optimization step
    update_dict = extract_update_dict(updates)
    cur_update_dict = update_dict
    for i in xrange(params['unrolling_steps'] - 1):
        # Compute variable updates given the previous iteration's updated variable
        cur_update_dict = graph_replace(update_dict, cur_update_dict)
    # Final unrolled loss uses the parameters at the last time step
    unrolled_loss = graph_replace(loss, cur_update_dict)
else:
    unrolled_loss = loss

# Optimize the generator on the unrolled loss
g_train_opt = tf.train.AdamOptimizer(params['gen_learning_rate'], beta1=params['beta1'], epsilon=params['epsilon'])
g_train_op = g_train_opt.minimize(-unrolled_loss, var_list=gen_vars)


sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

Hemistich answered 9/6, 2021 at 10:14

The implementation of get_collection:

def get_collection(key, scope=None):
  """Wrapper for `Graph.get_collection()` using the default graph.

  See `tf.Graph.get_collection`
  for more details.

  Args:
    key: The key for the collection. For example, the `GraphKeys` class contains
      many standard names for collections.
    scope: (Optional.) If supplied, the resulting list is filtered to include
      only items whose `name` attribute matches using `re.match`. Items without
      a `name` attribute are never returned if a scope is supplied and the
      choice or `re.match` means that a `scope` without special tokens filters
      by prefix.

  Returns:
    The list of values in the collection with the given `name`, or
    an empty list if no value has been added to that collection. The
    list contains the values in the order under which they were
    collected.

  @compatibility(eager)
  Collections are not supported when eager execution is enabled.
  @end_compatibility
  """
  return get_default_graph().get_collection(key, scope)

It looks like in this code, the key and scope arguments are swapped. If you provide "generator" or "discriminator" as the key with no scope, i.e.:

gen_vars = tf.compat.v1.get_collection("generator")
disc_vars = tf.compat.v1.get_collection("discriminator")

You should get results (I was able to reproduce this locally with TensorFlow 2.2.0). The only issue I could not quite identify is that when providing a scope, the function returns an empty list again, regardless of the scope value you provide. For example, passing tf.compat.v1.GraphKeys.GLOBAL_VARIABLES as the scope should return everything, but that is not the case:

gen_vars = tf.compat.v1.get_default_graph().get_collection('generator', tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES) # returns []
gen_vars = tf.compat.v1.get_default_graph().get_collection('generator', tf.compat.v1.GraphKeys.GLOBAL_VARIABLES) # returns []
gen_vars = tf.compat.v1.get_collection('generator') # returns a list of tensors

Update

It looks like even creating the variables inside the variable_scope context manager doesn't add them to the graph collection. I had to call tf.compat.v1.add_to_collection('generator', x) and tf.compat.v1.add_to_collection('discriminator', log_d) in the respective functions to get those results.
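
To be concrete, here is a minimal sketch of that workaround; the add_to_collection calls are the only change to the functions from the question:

def generator(z, output_dim=2, n_hidden=128, n_layer=2):
    with tf.compat.v1.variable_scope("generator"):
        h = slim.stack(z, slim.fully_connected, [n_hidden] * n_layer, activation_fn=tf.nn.tanh)
        x = slim.fully_connected(h, output_dim, activation_fn=None)
    # manually register the output tensor under the "generator" key so that
    # tf.compat.v1.get_collection('generator') has something to return
    tf.compat.v1.add_to_collection('generator', x)
    return x

def discriminator(x, n_hidden=128, n_layer=2, reuse=False):
    with tf.compat.v1.variable_scope("discriminator", reuse=reuse):
        h = slim.stack(x, slim.fully_connected, [n_hidden] * n_layer, activation_fn=tf.nn.tanh)
        log_d = slim.fully_connected(h, 1, activation_fn=None)
    # same idea for the discriminator output
    tf.compat.v1.add_to_collection('discriminator', log_d)
    return log_d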

Update #2

I searched around, and it doesn't appear that TensorFlow provides a context manager which adds the variables declared within it to a collection. For the sake of completeness of this answer, though, I have implemented one:

from contextlib import contextmanager

@contextmanager
def collection_scope(collection_name):
    import inspect
    from tensorflow.python.framework.ops import EagerTensor
    collection = tf.compat.v1.get_collection_ref(collection_name)
    yield
    # this is a bit of a hack, but it works...
    f = inspect.currentframe().f_back.f_back
    # only take variables which were declared within the context manager
    tf_variables = set([val.ref() for val in f.f_locals.values() if isinstance(val, EagerTensor)]) - \
                   set([val.ref() for val in f.f_back.f_locals.values() if isinstance(val, EagerTensor)])
    collection.extend(tf_variables)

You can then drop this in your functions in place of the variable scope (tf.compat.v1.variable_scope) context managers. For example, instead of:

def generator(z, output_dim=2, n_hidden=128, n_layer=2):
    with tf.compat.v1.variable_scope('generator'):
        h = slim.stack(z, slim.fully_connected, [n_hidden] * n_layer, activation_fn=tf.nn.tanh)
        x = slim.fully_connected(h, output_dim, activation_fn=None)
    return x

Do the following:

def generator(z, output_dim=2, n_hidden=128, n_layer=2):
    with collection_scope('generator'):
        h = slim.stack(z, slim.fully_connected, [n_hidden] * n_layer, activation_fn=tf.nn.tanh)
        x = slim.fully_connected(h, output_dim, activation_fn=None)
    return x

With this change, all tensors declared within the context manager will be added to the "generator" collection, so tf.compat.v1.get_collection('generator') will return the correct list of tensors.
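
As a quick sanity check (a sketch, assuming the nets have been built with the modified functions), the collections can then be queried directly:

gen_vars = tf.compat.v1.get_collection('generator')    # now a non-empty list of tensors
disc_vars = tf.compat.v1.get_collection('discriminator')  # likewise
print(len(gen_vars), len(disc_vars))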

Metaphor answered 16/6, 2021 at 19:32. Comments (4):
When I use just tf.compat.v1.get_collection("generator"), I still get an empty list. Do you think this is an issue with that line, or with my previous definition of "generator"? (Same with the discriminator.) I'm using TF 2.4.1, if that matters. (Hemistich)
Added an update to my answer. I will investigate a bit further. (Metaphor)
@Hemistich I have added another update with a solution. (Metaphor)
Huzzah, it works! get_collection() now properly returns variables within the appropriate scope! Now to get the rest of this to work... (Hemistich)
