How to use Exponential Moving Average in Tensorflow
Asked Answered
D

1

7

The problem

Tensorflow includes the function tf.train.ExponentialMovingAverage which allows us to apply a moving average to the parameters, which I've found to be great to stabilize the testing of the model.

With that said, I've found it somewhat irritatingly hard to apply this to general models. My so far most successful approach (shown below) has been to write a function decorator and then put my whole NN inside a function.

This does however have several downsides. For one, it duplicates the whole graph, and second, I need to define my NN inside a function.

Is there a better way to do this?

Current Implementation

def ema_wrapper(is_training, decay=0.99):
    """Decorator factory: evaluate the wrapped network with an Exponential
    Moving Average of its parameters at test time.

    The decorated function is built twice under the same variable scope:
    once with the raw variables (training path) and once, with
    ``reuse=True`` and a custom getter, substituting each variable's EMA
    shadow (testing path). ``tf.cond`` selects between the two results.

    Parameters
    ----------
    is_training : bool or `tf.Tensor` of type bool
        EMA is applied if ``is_training`` is False.
    decay : float
        Decay rate for `tf.train.ExponentialMovingAverage`.
    """
    def function(fun):
        @functools.wraps(fun)
        def fun_wrapper(*args, **kwargs):
            # Regular call: builds the graph with the raw variables.
            with tf.variable_scope('ema_wrapper', reuse=False) as scope:
                result_train = fun(*args, **kwargs)

            # Set up exponential moving average over the trainable
            # variables created inside the scope above.
            ema = tf.train.ExponentialMovingAverage(decay=decay)
            var_class = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope.name)
            ema_op = ema.apply(var_class)

            # Add to collection so the shadow variables are updated
            # whenever UPDATE_OPS are run alongside the train step.
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema_op)

            # Getter for the variables with EMA applied
            def ema_getter(getter, name, *args, **kwargs):
                var = getter(name, *args, **kwargs)
                ema_var = ema.average(var)
                # FIX: `ema.average` returns None for variables it does not
                # track; truth-testing a tf.Variable is ambiguous (and raises
                # in graph mode), so compare against None explicitly.
                return ema_var if ema_var is not None else var

            # Second call, reusing the same scope but reading the EMA
            # shadow variables through the custom getter.
            with tf.variable_scope('ema_wrapper', reuse=True,
                                   custom_getter=ema_getter):
                result_test = fun(*args, **kwargs)

            # Return the correct version depending on if we're training or not
            return tf.cond(is_training,
                           lambda: result_train, lambda: result_test)
        return fun_wrapper
    return function

Example usage:

@ema_wrapper(is_training)
def neural_network(x):
    """Toy model: a single learned scalar times the input.

    When ``is_training`` is False, the decorator transparently reads the
    EMA shadow of the scalar instead of the raw variable.
    """
    scale = tf.get_variable('a', [], tf.float32)
    return scale * x
Disarm answered 7/3, 2018 at 9:9 Comment(2)
Not sure if you think that's a valid solution but you could have an op that copies the EMA values to the original variables and run it after the training is done.Teena
Sure that sounds valid. Is there some way to standardize it?Disarm
T
15

You can have an op that transfers the value from the EMA variables to the original ones:

import tensorflow as tf

# Make model... (sketch: `minimize_op` is your optimizer's minimize step)
minimize_op = ...
model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
# Make EMA object and update internal variables after optimization step.
# The control dependency ensures the shadow update runs after each step.
ema = tf.train.ExponentialMovingAverage(decay=decay)
with tf.control_dependencies([minimize_op]):
    train_op = ema.apply(model_vars)

# Transfer EMA values to original variables (one-way: this overwrites the
# raw weights with their shadows; keep a checkpoint if you need them back).
retrieve_ema_weights_op = tf.group(
    [tf.assign(var, ema.average(var)) for var in model_vars])

with tf.Session() as sess:
    # Do training
    while ...:
        sess.run(train_op, ...)
    # Copy EMA values to weights
    sess.run(retrieve_ema_weights_op)
    # Test model with EMA weights
    # ...

EDIT:

I made a longer version with the ability to switch between train and test mode with variable backups:

import tensorflow as tf

# Make model... (sketch: `minimize_op` is your optimizer's minimize step)
minimize_op = ...
model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

# Graph-level flag tracking the current mode; guards the switch ops below
# so they are no-ops when already in the requested mode.
is_training = tf.get_variable('is_training', shape=(), dtype=tf.bool,
                              initializer=tf.constant_initializer(True, dtype=tf.bool))

# Make EMA object and update internal variables after optimization step
ema = tf.train.ExponentialMovingAverage(decay=decay)
with tf.control_dependencies([minimize_op]):
    train_op = ema.apply(model_vars)
# Make backup variables (non-trainable copies used to restore the raw
# weights after they have been overwritten with EMA values for testing).
with tf.variable_scope('BackupVariables'):
    backup_vars = [tf.get_variable(var.op.name, dtype=var.value().dtype, trainable=False,
                                   initializer=var.initialized_value())
                   for var in model_vars]

# Copy each EMA shadow value into its original variable.
def ema_to_weights():
    return tf.group(*(tf.assign(var, ema.average(var).read_value())
                     for var in model_vars))
# Snapshot the current raw weights into the backup variables.
def save_weight_backups():
    return tf.group(*(tf.assign(bck, var.read_value())
                     for var, bck in zip(model_vars, backup_vars)))
# Restore the raw weights from the backup variables.
def restore_weight_backups():
    return tf.group(*(tf.assign(var, bck.read_value())
                     for var, bck in zip(model_vars, backup_vars)))

# Flip the flag first, then restore the raw weights.
def to_training():
    with tf.control_dependencies([tf.assign(is_training, True)]):
        return restore_weight_backups()

# Flip the flag, back up the raw weights, then overwrite with EMA values.
def to_testing():
    with tf.control_dependencies([tf.assign(is_training, False)]):
        with tf.control_dependencies([save_weight_backups()]):
            return ema_to_weights()

# tf.cond makes each switch idempotent: running it while already in the
# target mode does nothing (so backups are never clobbered twice).
switch_to_train_mode_op = tf.cond(is_training, lambda: tf.group(), to_training)
switch_to_test_mode_op = tf.cond(is_training, to_testing, lambda: tf.group())

init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    # Unnecessary, since it begins in training mode, but harmless
    sess.run(switch_to_train_mode_op)
    # Do training
    while ...:
        sess.run(train_op, ...)
    # To test mode
    sess.run(switch_to_test_mode_op)
    # Switching multiple times should not overwrite backups
    sess.run(switch_to_test_mode_op)
    # Test model with EMA weights
    # ...
    # Back to training mode
    sess.run(switch_to_train_mode_op)
    # Keep training...
Teena answered 7/3, 2018 at 12:33 Comment(4)
How would I then reset the model "back" to the non-EMA weights?Disarm
@JonasAdler I can think of two ways: 1) Within TensorFlow, creating another set of shadow variables for backup. 2) Out of TensorFlow, reading the variable values and storing it in Python (NumPy) objects, then putting them back with tf.assign ops or with the load method of the variables. I can extend the answer if you need help with any of those.Teena
@JonasAdler I've updated the answer with a "sketch" of option 2).Teena
@jdehesa nice answer.Choirmaster

© 2022 - 2024 — McMap. All rights reserved.