The problem
TensorFlow includes the class tf.train.ExponentialMovingAverage, which maintains a moving average of the model parameters. I've found that evaluating the model with these averaged parameters greatly stabilizes its test performance.
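For reference, the update ExponentialMovingAverage applies to each shadow variable is shadow -= (1 - decay) * (shadow - variable), with the shadow initialized to the variable's initial value. Below is a minimal pure-Python sketch of that rule. The function names are hypothetical, and the optional num_updates adjustment of the decay is left out.

```python
def ema_update(shadow, value, decay=0.99):
    """One EMA step: shadow -= (1 - decay) * (shadow - value)."""
    return shadow - (1.0 - decay) * (shadow - value)


def run_ema(values, decay=0.99):
    """Track the EMA of a sequence of parameter values.

    The shadow starts at the first value, mirroring how the TF shadow
    variable is initialized from the variable's initial value.
    """
    shadow = values[0]
    history = [shadow]
    for v in values[1:]:
        shadow = ema_update(shadow, v, decay)
        history.append(shadow)
    return history


# A parameter that jumps from 0.0 to 1.0 and stays there:
params = [0.0] + [1.0] * 5
print(run_ema(params, decay=0.5))  # [0.0, 0.5, 0.75, 0.875, 0.9375, 0.96875]
```

The averaged value lags behind sudden jumps in the raw parameter, which is what smooths out noisy late-training updates at test time. A decay closer to 1 gives a longer memory.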
That said, I've found it somewhat irritatingly hard to apply this to general models. My most successful approach so far (shown below) has been to write a function decorator and put my whole network inside the decorated function.
This does, however, have two downsides. First, it duplicates the whole graph (it is built once with the raw variables and once with their averages), and second, I need to define my network inside a function.
Is there a better way to do this?
Current Implementation
```python
import functools

import tensorflow as tf  # TensorFlow 1.x API


def ema_wrapper(is_training, decay=0.99):
    """Use Exponential Moving Average of parameters during testing.

    Parameters
    ----------
    is_training : bool or `tf.Tensor` of type bool
        EMA is applied if ``is_training`` is False.
    decay : float
        Decay rate for `tf.train.ExponentialMovingAverage`.
    """
    def function(fun):
        @functools.wraps(fun)
        def fun_wrapper(*args, **kwargs):
            # Regular call: creates the raw (trained) variables
            with tf.variable_scope('ema_wrapper', reuse=False) as scope:
                result_train = fun(*args, **kwargs)

            # Set up exponential moving average of the variables just created
            ema = tf.train.ExponentialMovingAverage(decay=decay)
            train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                           scope.name)
            ema_op = ema.apply(train_vars)

            # Add to UPDATE_OPS; these must be run together with the
            # training op for the averages to be updated
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema_op)

            # Getter that returns the EMA shadow variable when one exists
            # (explicit None check: a Variable should not be used as a bool)
            def ema_getter(getter, name, *args, **kwargs):
                var = getter(name, *args, **kwargs)
                ema_var = ema.average(var)
                return ema_var if ema_var is not None else var

            # Second call: reuse the same variables, but read their averages
            with tf.variable_scope('ema_wrapper', reuse=True,
                                   custom_getter=ema_getter):
                result_test = fun(*args, **kwargs)

            # Return the correct version depending on whether we're training.
            # tf.cond in TF 1.x rejects a plain Python bool, so convert first.
            is_training_t = tf.convert_to_tensor(is_training, dtype=tf.bool)
            return tf.cond(is_training_t,
                           lambda: result_train, lambda: result_test)
        return fun_wrapper
    return function
```
Example usage:
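```python
# Scalar bool fed at session run time
is_training = tf.placeholder(tf.bool, shape=[], name='is_training')

@ema_wrapper(is_training)
def neural_network(x):
    # If is_training is False, the EMA of `a` is used instead
    a = tf.get_variable('a', [], tf.float32)
    return a * x

y = neural_network(tf.placeholder(tf.float32, [None]))
```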
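The custom-getter trick in the decorator can be illustrated without TensorFlow. The sketch below is a hypothetical toy (ToyVariableStore, make_ema_getter and the simplified get_variable are not TensorFlow APIs). The same network function returns the trained result by default and the averaged result when a getter that swaps in the shadow value is passed. It checks for a missing average with `is not None`, so a legitimately zero average is still used.

```python
class ToyVariableStore:
    """Hypothetical stand-in for TF's variable store (not a TF API)."""

    def __init__(self):
        self.variables = {}  # name -> current (trained) value
        self.averages = {}   # name -> EMA shadow value

    def get_variable(self, name, initial=0.0, custom_getter=None):
        def getter(n):
            # Create on first use, reuse afterwards (like reuse=True)
            return self.variables.setdefault(n, initial)
        if custom_getter is None:
            return getter(name)
        return custom_getter(getter, name)


def make_ema_getter(store):
    """Build a getter returning the averaged value if one exists, else the raw one."""
    def ema_getter(getter, name):
        value = getter(name)
        avg = store.averages.get(name)
        return avg if avg is not None else value
    return ema_getter


def neural_network(store, x, custom_getter=None):
    # Same "network" code for both calls; only the getter differs
    a = store.get_variable('a', initial=1.0, custom_getter=custom_getter)
    return a * x


store = ToyVariableStore()
store.variables['a'] = 3.0  # raw, freshly trained value
store.averages['a'] = 2.0   # smoothed EMA value
print(neural_network(store, 10.0))  # 30.0
print(neural_network(store, 10.0, custom_getter=make_ema_getter(store)))  # 20.0
```

Because the substitution happens at variable lookup, the network function itself never needs to know about the EMA. This is the property the decorator relies on when it calls the function a second time inside the reusing scope.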