参考资料

  1. https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.12_weight-decay

权重衰减(weight decay)

权重衰减等价于L2范数正则化(regularization)。正则化通过为模型损失函数添加惩罚项使学出的模型参数值较小,是应对过拟合的常用手段。

手动实现

def l2_penalty(w):
    return (w**2).sum() / 2

计算loss时加上L2正则化项即可:

l = loss(y_hat, y) + lambd * l2_penalty(w)

简洁实现

直接在定义优化器时加上权重衰减参数即可:

optimizer = torch.optim.SGD(params=net.parameters(), lr=lr, weight_decay=wd)

上述方式对所有参数都进行了权重衰减,如果只想针对某些参数,分别为它们构造一个优化器实例即可:

optimizer_w = torch.optim.SGD(params=[net.weight], lr=lr, weight_decay=wd) # 对权重参数衰减
    optimizer_b = torch.optim.SGD(params=[net.bias], lr=lr)  # 不对偏差参数衰减
Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐