PyTorch——L2范数正则化(权重衰减)
参考资料https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.12_weight-decay权重衰减(weight decay)权重衰减等价于L2范数正则化(regularization)。正则化通过为模型损失函数添加惩罚项使学出的模型参数值较小,是应对过拟合的常用手段。手动实现def l2_penalty(w):re
·
参考资料
- https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.12_weight-decay
权重衰减(weight decay)
权重衰减等价于L2范数正则化(regularization)。正则化通过为模型损失函数添加惩罚项使学出的模型参数值较小,是应对过拟合的常用手段。
手动实现
def l2_penalty(w):
return (w**2).sum() / 2
在计算loss时加上L2正则化项即可:
l = loss(y_hat, y) + lambd * l2_penalty(w)
简洁实现
直接在定义优化器时加上权重衰减参数即可:
optimizer = torch.optim.SGD(params=net.parameters(), lr=lr, weight_decay=wd)
上述方式对所有参数都进行了权重衰减,如果只想针对某些参数,分别为它们构造一个优化器实例即可:
optimizer_w = torch.optim.SGD(params=[net.weight], lr=lr, weight_decay=wd) # 对权重参数衰减
optimizer_b = torch.optim.SGD(params=[net.bias], lr=lr) # 不对偏差参数衰减
更多推荐
已为社区贡献3条内容
所有评论(0)