What is the proper way to apply weight decay with the Adam optimizer?

Since the Adam optimizer keeps a pair of running averages (mean/variance) of the gradients, I wonder how it should properly handle weight decay. I have seen two ways of implementing it.

  1. Update mean/variance only from the gradients of the objective loss, and decay the weights explicitly at each mini-batch. (the following code is taken from https://github.com/dmlc/mxnet/blob/v0.7.0/python/mxnet/optimizer.py)

    weight[:] -= lr*mean/(sqrt(variance) + self.epsilon)
    
    wd = self._get_wd(index)
    if wd > 0.:
        weight[:] -= (lr * wd) * weight
    
  2. Update mean/variance from the gradients based on the objective loss + regularization loss, and update weights like usual. (the following code is taken from https://github.com/dmlc/mxnet/blob/master/src/operator/optimizer_op-inl.h#L210)

    grad = scalar<DType>(param.rescale_grad) * grad +
    scalar<DType>(param.wd) * weight;
    // stuff
    Assign(out, req[0],
       weight -
       scalar<DType>(param.lr) * mean /
       (F<square_root>(var) + scalar<DType>(param.epsilon)));
    

These two approaches sometimes show a significant difference in training results. I actually think the first one makes more sense (and I find it gives better results from time to time). Caffe and old versions of mxnet follow the first approach, while torch, tensorflow and new versions of mxnet follow the second one.
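
For illustration, here is a minimal sketch of the two approaches in plain Python on a single scalar weight. It is not taken from either library: the toy loss 0.5 * (w - 3)^2, the hyperparameter values, the names adam_run and toy_grad, and the bias-correction step (from the Adam paper) are all assumptions made for the example.

```python
import math


def adam_run(grad_fn, w, steps, lr, wd, decoupled,
             beta1=0.9, beta2=0.999, eps=1e-8):
    """Run Adam on a single scalar weight with one of the two decay styles."""
    mean, variance = 0.0, 0.0
    for t in range(1, steps + 1):
        grad = grad_fn(w)
        if not decoupled:
            # Approach 2: the decay term is folded into the gradient, so it
            # also flows into the running mean/variance.
            grad += wd * w
        mean = beta1 * mean + (1.0 - beta1) * grad
        variance = beta2 * variance + (1.0 - beta2) * grad * grad
        # Bias correction from the Adam paper (an assumption; the snippets
        # above do not show whether or how it is applied).
        mean_hat = mean / (1.0 - beta1 ** t)
        var_hat = variance / (1.0 - beta2 ** t)
        w -= lr * mean_hat / (math.sqrt(var_hat) + eps)
        if decoupled:
            # Approach 1: shrink the weight directly, bypassing mean/variance.
            w -= lr * wd * w
    return w


def toy_grad(w):
    # Gradient of the toy loss 0.5 * (w - 3)^2 (illustrative assumption).
    return w - 3.0


# One step: both approaches take practically the same Adam step of size lr;
# only approach 1 then shrinks the weight.
first_1 = adam_run(toy_grad, 1.0, steps=1, lr=0.01, wd=0.1, decoupled=True)
first_2 = adam_run(toy_grad, 1.0, steps=1, lr=0.01, wd=0.1, decoupled=False)
# A longer run: the gap between the two trajectories widens.
later_1 = adam_run(toy_grad, 1.0, steps=100, lr=0.01, wd=0.1, decoupled=True)
later_2 = adam_run(toy_grad, 1.0, steps=100, lr=0.01, wd=0.1, decoupled=False)
print(first_1, first_2)
print(later_1, later_2)
```

On the very first step, approach 2 moves the weight by almost exactly lr whatever the value of wd, because Adam divides the gradient magnitude away. Approach 1 takes the same step and then additionally shrinks the weight by lr * wd * w.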

Really appreciate your help!

Indiscrimination answered 9/6, 2017 at 8:8 Comment(2)
Note that the difference between the two is huge for low-bit-width training; I guess weight regularization becomes hurtful in that case (and this may also apply to other similar cases). - Indiscrimination
Are you sure tensorflow supports weight decay in their AdamOptimizer? I just checked the code and didn't see anything about weight decay. github.com/tensorflow/tensorflow/blob/… - Formerly
M
27

Edit: see also this PR which just got merged into TF.

When using pure SGD (without momentum) as an optimizer, weight decay is the same thing as adding an L2-regularization term to the loss. When using any other optimizer, this is not true.

Weight decay (don't know how to TeX here, so excuse my pseudo-notation):

w[t+1] = w[t] - learning_rate * dw - weight_decay * w[t]

L2-regularization:

loss = actual_loss + lambda * 1/2 * sum(||w||_2^2 for w in network_params)

Computing the gradient of the extra term in L2-regularization gives lambda * w and thus inserting it into the SGD update equation

dloss_dw = dactual_loss_dw + lambda * w[t]
w[t+1] = w[t] - learning_rate * dloss_dw
       = w[t] - learning_rate * dactual_loss_dw - learning_rate * lambda * w[t]

gives the same as weight decay, but mixes lambda with the learning_rate. Any other optimizer, even SGD with momentum, gives a different update rule for weight decay than for L2-regularization! See the paper Fixing weight decay in Adam for more details. (Edit: AFAIK, this 1987 Hinton paper introduced "weight decay", literally as "each time the weights are updated, their magnitude is also decremented by 0.4%" at page 10)
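
The claim above can be checked numerically. The following is a minimal sketch in plain Python, assuming a toy one-dimensional loss 0.5 * (w - 3)^2 and made-up hyperparameters; the names grad_actual_loss and run are illustrative and not from any library. It sets lambda = weight_decay / learning_rate, runs both rules with plain SGD and again with classical momentum, and compares the final weights.

```python
def grad_actual_loss(w):
    # Gradient of the toy loss 0.5 * (w - 3)^2 (an assumption for illustration).
    return w - 3.0


def run(steps, learning_rate, weight_decay, momentum, use_l2):
    """Run SGD (optionally with momentum) on one scalar weight.

    use_l2=True  -> the decay enters through the gradient (L2-regularization).
    use_l2=False -> the decay is applied directly to the weight (weight decay).
    """
    lam = weight_decay / learning_rate  # makes the plain-SGD updates coincide
    w, velocity = 1.0, 0.0
    for _ in range(steps):
        grad = grad_actual_loss(w)
        if use_l2:
            grad += lam * w
        velocity = momentum * velocity + grad
        decay = 0.0 if use_l2 else weight_decay * w
        w = w - learning_rate * velocity - decay
    return w


lr, wd = 0.1, 0.01
sgd_l2 = run(500, lr, wd, momentum=0.0, use_l2=True)
sgd_wd = run(500, lr, wd, momentum=0.0, use_l2=False)
mom_l2 = run(500, lr, wd, momentum=0.9, use_l2=True)
mom_wd = run(500, lr, wd, momentum=0.9, use_l2=False)

print(abs(sgd_l2 - sgd_wd))  # essentially zero: the two rules coincide
print(abs(mom_l2 - mom_wd))  # clearly nonzero: momentum breaks the equivalence
```

With momentum, the L2 term is accumulated in the velocity along with the loss gradient, while the direct decay is not, so the two runs settle at different weights.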

That being said, there doesn't seem to be support for "proper" weight decay in TensorFlow yet. There are a few issues discussing it, specifically because of the above paper.

One possible way to implement it is by writing an op that does the decay step manually after every optimizer step. A different way, which is what I'm currently doing, is using an additional SGD optimizer just for the weight decay, and "attaching" it to your train_op. Both of these are just crude work-arounds, though. My current code:

# In the network definition:
with arg_scope([layers.conv2d, layers.dense],
               weights_regularizer=layers.l2_regularizer(weight_decay)):
    ...  # define the network.

loss = ...  # compute the actual loss of your problem.
train_op = optimizer.minimize(loss, global_step=global_step)
if args.weight_decay not in (None, 0):
    # Run the decay step only after the main optimizer step has finished.
    with tf.control_dependencies([train_op]):
        # Plain SGD with learning_rate=1.0 on the summed L2 terms subtracts
        # weight_decay * w from every regularized weight.
        sgd = tf.train.GradientDescentOptimizer(learning_rate=1.0)
        train_op = sgd.minimize(tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)))

This somewhat makes use of TensorFlow's provided bookkeeping. Note that the arg_scope takes care of appending an L2-regularization term for every layer to the REGULARIZATION_LOSSES graph-key; I then sum them all up and optimize the sum using SGD, which, as shown above, corresponds to actual weight decay.

Hope that helps, and if anyone gets a nicer code snippet for this, or TensorFlow implements it better (i.e. in the optimizers), please share.

Mischance answered 10/6, 2018 at 5:35 Comment(6)
I completely agree with you that adding momentum or using adaptive optimizers means that the effective weight-decay term is different than in pure SGD. However, my understanding of the term weight-decay has always (seemingly incorrectly) been that it is simply a practitioner's name for L2-regularization, because when you implement L2 for SGD, it looks like you're just exponentially decaying the weights. It seems like TF's and MXNet's implementations match my understanding as well. But as you pointed out, weight-decay seems to be its own regularization technique. - Reprehensible
I have to add that, based on the 1988 paper on comparing network biases (a.k.a. regularizers), weight-decay is considered an "ad-hoc" way of improving training generalization, and it is shown in that paper to be equivalent to quadratic bias, a.k.a. L2 regularization with pure SGD. Given that the adaptive optimizers are a recent invention, could one perhaps argue that weight-decay is indeed L2 regularization and that the implementations in MXNet and TF are the correct ones? - Reprehensible
In that case, perhaps we need a different name for what's proposed in the AdamW paper :) - Reprehensible
I disagree, we already have two names (weight decay, and L2-regularization) for two different techniques which coincide only in one special case. Unfortunately, many academics have mixed them up. We can go all the way back to Hinton's 1987 paper which AFAIK introduced weight decay, literally as "each time the weights are updated, their magnitude is also decremented by 0.4%" (page 10). - Mischance
I convinced myself that you're right. Have you submitted an issue on TF? There is this issue on MXNet. - Reprehensible
Thanks for agreeing :) Last time I looked, I saw a few issues about it in the TF repos already, but now looking again, it seems that this PR just got merged two days ago! I think it's an unnecessarily complex one, but eh, that's the trend in TF, so be it. - Mischance
E
5

I came across the same question. I think this code that I got from here will work for you. It implements an Adam optimizer with weight decay by inheriting from tf.train.Optimizer. This is the cleanest solution I have found:

import re

import tensorflow as tf  # TF1-style API (tf.train.Optimizer, tf.get_variable)


class AdamWeightDecayOptimizer(tf.train.Optimizer):
  """A basic Adam optimizer that includes "correct" L2 weight decay."""

  def __init__(self,
               learning_rate,
               weight_decay_rate=0.0,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-6,
               exclude_from_weight_decay=None,
               name="AdamWeightDecayOptimizer"):
    """Constructs a AdamWeightDecayOptimizer."""
    super(AdamWeightDecayOptimizer, self).__init__(False, name)

    self.learning_rate = learning_rate
    self.weight_decay_rate = weight_decay_rate
    self.beta_1 = beta_1
    self.beta_2 = beta_2
    self.epsilon = epsilon
    self.exclude_from_weight_decay = exclude_from_weight_decay

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """See base class."""
    assignments = []
    for (grad, param) in grads_and_vars:
      if grad is None or param is None:
        continue

      param_name = self._get_variable_name(param.name)

      m = tf.get_variable(
          name=param_name + "/adam_m",
          shape=param.shape.as_list(),
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())
      v = tf.get_variable(
          name=param_name + "/adam_v",
          shape=param.shape.as_list(),
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())

      # Standard Adam update.
      next_m = (
          tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
      next_v = (
          tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
                                                    tf.square(grad)))

      update = next_m / (tf.sqrt(next_v) + self.epsilon)

      # Just adding the square of the weights to the loss function is *not*
      # the correct way of using L2 regularization/weight decay with Adam,
      # since that will interact with the m and v parameters in strange ways.
      #
      # Instead we want to decay the weights in a manner that doesn't interact
      # with the m/v parameters. This is equivalent to adding the square
      # of the weights to the loss with plain (non-momentum) SGD.
      if self._do_use_weight_decay(param_name):
        update += self.weight_decay_rate * param

      update_with_lr = self.learning_rate * update

      next_param = param - update_with_lr

      assignments.extend(
          [param.assign(next_param),
           m.assign(next_m),
           v.assign(next_v)])
    return tf.group(*assignments, name=name)

  def _do_use_weight_decay(self, param_name):
    """Whether to use L2 weight decay for `param_name`."""
    if not self.weight_decay_rate:
      return False
    if self.exclude_from_weight_decay:
      for r in self.exclude_from_weight_decay:
        if re.search(r, param_name) is not None:
          return False
    return True

  def _get_variable_name(self, param_name):
    """Get the variable name from the tensor name."""
    m = re.match("^(.*):\\d+$", param_name)
    if m is not None:
      param_name = m.group(1)
    return param_name
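
To make the arithmetic of apply_gradients easier to check by hand, here is a small plain-Python sketch of the same update for a single scalar weight. The function name adam_weight_decay_step and the example values are my own assumptions; it only mirrors the formulas in the class above and is not a replacement for it.

```python
import math


def adam_weight_decay_step(param, grad, m, v, learning_rate,
                           weight_decay_rate=0.01, beta_1=0.9,
                           beta_2=0.999, epsilon=1e-6):
    """One scalar update following the arithmetic of apply_gradients above."""
    next_m = beta_1 * m + (1.0 - beta_1) * grad
    next_v = beta_2 * v + (1.0 - beta_2) * grad * grad
    # No bias correction here, matching the class above.
    update = next_m / (math.sqrt(next_v) + epsilon)
    # The decay uses the current weight and never touches m or v.
    update += weight_decay_rate * param
    next_param = param - learning_rate * update
    return next_param, next_m, next_v


# With a zero gradient the Adam part vanishes and only the decay remains:
# the weight shrinks by the factor (1 - learning_rate * weight_decay_rate).
p, m, v = adam_weight_decay_step(param=2.0, grad=0.0, m=0.0, v=0.0,
                                 learning_rate=0.1, weight_decay_rate=0.01)
print(p, m, v)
```

Because the decay term is added after the division by sqrt(next_v), it is scaled only by the learning rate and not by the adaptive denominator.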

You can use it in the following way (I have made some changes to make it useful in a more general context). This function returns a train_op that can be used in the Session:

def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps):
  """Creates an optimizer training op."""
  global_step = tf.train.get_or_create_global_step()

  learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)

  # Implements linear decay of the learning rate.
  learning_rate = tf.train.polynomial_decay(
      learning_rate,
      global_step,
      num_train_steps,
      end_learning_rate=0.0,
      power=1.0,
      cycle=False)

  # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
  # learning rate will be `global_step/num_warmup_steps * init_lr`.
  if num_warmup_steps:
    global_steps_int = tf.cast(global_step, tf.int32)
    warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

    global_steps_float = tf.cast(global_steps_int, tf.float32)
    warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

    warmup_percent_done = global_steps_float / warmup_steps_float
    warmup_learning_rate = init_lr * warmup_percent_done

    is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
    learning_rate = (
        (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)

  # It is recommended that you use this optimizer for fine tuning, since this
  # is how the model was trained (note that the Adam m/v variables are NOT
  # loaded from init_checkpoint.)
  optimizer = AdamWeightDecayOptimizer(
      learning_rate=learning_rate,
      weight_decay_rate=0.01,
      beta_1=0.9,
      beta_2=0.999,
      epsilon=1e-6)


  tvars = tf.trainable_variables()
  grads = tf.gradients(loss, tvars)

  # You can clip gradients at this step if you need to (in general it is not necessary).
  # (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)

  train_op = optimizer.apply_gradients(
      zip(grads, tvars), global_step=global_step)

  # Normally the global step update is done inside of `apply_gradients`.
  # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
  # a different optimizer, you should probably take this line out.
  new_global_step = global_step + 1
  train_op = tf.group(train_op, [global_step.assign(new_global_step)])
  return train_op
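
To see what learning rate the function above feeds to the optimizer at a given step, the following plain-Python sketch mirrors the schedule arithmetic (linear warmup, then linear decay to zero). The helper name scheduled_learning_rate and the example numbers are assumptions for illustration only.

```python
def scheduled_learning_rate(step, init_lr, num_train_steps, num_warmup_steps):
    """Plain-Python mirror of the learning-rate schedule built above."""
    # Linear decay from init_lr to 0.0 over num_train_steps
    # (polynomial decay with power=1.0, end_learning_rate=0.0, cycle=False).
    clipped_step = min(step, num_train_steps)
    decayed = init_lr * (1.0 - clipped_step / num_train_steps)
    # During warmup the linear ramp replaces the decayed value.
    if num_warmup_steps and step < num_warmup_steps:
        return init_lr * step / num_warmup_steps
    return decayed


for step in (0, 50, 100, 550, 1000):
    print(step, scheduled_learning_rate(step, 1e-3, 1000, 100))
```

Note that the decay is measured from step 0, not from the end of warmup, so the rate never actually reaches init_lr: it jumps from the warmup ramp onto the already-decayed line when warmup ends.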

Erne answered 13/2, 2019 at 23:49 Comment(2)
Cleanest implementation I have seen! - Boudicca
Note this is compatible only with TF1 (e.g. the tf.get_variable() method). I would suggest updating the code to TF2 or using TensorFlow-Addons (tfa), which has it implemented as tfa.optimizers.AdamW. - Articulation
