1. 理解AdamW

    1. 我们先弄清楚什么是weight decay

    2. 其实是在损失函数求导后,放在正则项前面的系数,比如L2正则,我们看一下weight decay的位置

    3. 我 们 可 以 认 为 λ 就 是 w e i g h t   d e c a y min ⁡ w L 2 ( w ) = min ⁡ w f ( w ) + λ 2 n ∑ i = 1 n w i 2 L 2 ′ ( w ) = f ′ ( w ) + λ n ∑ i = 1 n w i 我们可以认为\lambda就是weight\ decay\\ \min_wL_2(w)=\min_wf(w)+\frac{\lambda}{2n}\sum_{i=1}^nw_i^2\\ L_2^{'}(w)=f^{'}(w)+\frac{\lambda}{n}\sum_{i=1}^nw_i λweight decaywminL2(w)=wminf(w)+2nλi=1nwi2L2(w)=f(w)+nλi=1nwi

      1. AdamW是在Adam+L2正则化的基础上进行改进的算法。使用Adam优化带L2正则的损失并不有效。如果引入L2正则项,在计算梯度的时候会加上对正则项求梯度的结果。
      2. 那么如果本身比较大的一些权重对应的梯度也会比较大,由于Adam计算步骤中减去项会除以梯度平方的累积开根号,使得减去项偏小。按常理说,越大的权重应该惩罚越大,但是在Adam并不是这样。分子分母相互抵消掉了。
      3. 而权重衰减对所有的权重都采用相同的系数进行更新,越大的权重显然惩罚越大。
      4. 在常见的深度学习库中只提供了L2正则,并没有提供权重衰减的实现。
      5. paper地址

在这里插入图片描述

Adam+L2 VS AdamW

图片中红色是传统的Adam+L2 regularization的方式,绿色是Adam + weight decay的方式。可以看出两个方法的区别仅在于"系数乘以上一步参数值"(这一项实际上就是权重乘以L2项的导数,因为 x 2 x^2 x2的导数是本身x。)这一项的位置。

再结合代码来看一下AdamW的具体实现。

以下代码来自https://github.com/macanv/BERT-BiLSTM-CRF-NER/blob/master/bert_base/bert/optimization.py中的AdamWeightDecayOptimizer中的apply_gradients函数中,BERT中的优化器就是使用这个方法。

在代码中也做了一些注释用于对应之前给出的Adam简化版公式,方便理解。可以看出update += self.weight_decay_rate * param这一句是Adam中没有的,也就是Adam中绿色的部分对应的代码,weightdecay这一步是是发生在Adam中需要被更新的参数update计算之后,并且在乘以学习率learning_rate之前,这和图片中的伪代码的计算顺序是完全一致的。总之一句话,如果使用了weightdecay就不必再使用L2正则化了。

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """See base class."""
    assignments = []
    for (grad, param) in grads_and_vars:
      if grad is None or param is None:
        continue

      param_name = self._get_variable_name(param.name)

      m = tf.get_variable(
          name=param_name + "/adam_m",
          shape=param.shape.as_list(),
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())
      v = tf.get_variable(
          name=param_name + "/adam_v",
          shape=param.shape.as_list(),
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())

      # Standard Adam update.
      next_m = (
          tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
      next_v = (
          tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
                                                    tf.square(grad)))

      update = next_m / (tf.sqrt(next_v) + self.epsilon)

      # Just adding the square of the weights to the loss function is *not*
      # the correct way of using L2 regularization/weight decay with Adam,
      # since that will interact with the m and v parameters in strange ways.
      #
      # Instead we want ot decay the weights in a manner that doesn't interact
      # with the m/v parameters. This is equivalent to adding the square
      # of the weights to the loss with plain (non-momentum) SGD.
      if self._do_use_weight_decay(param_name):
        update += self.weight_decay_rate * param

      update_with_lr = self.learning_rate * update

      next_param = param - update_with_lr

      assignments.extend(
          [param.assign(next_param),
           m.assign(next_m),
           v.assign(next_v)])
    return tf.group(*assignments, name=name)

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐