作者:hzwer
链接:https://www.zhihu.com/question/375794498/answer/2292320194
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
 

这也是个困扰了我多年的问题:

loss = a * loss1 + b * loss2 + c * loss3 怎么设置 a,b,c?

我的经验是 loss 的尺度一般不太影响性能,除非本来主 loss 是 loss1,但是因为 b,c 设置太大了导致其他 loss 变成了主 loss。

实践上有几个调整方法:

  1. 手动把所有 loss 放缩到差不多的尺度,设 a = 1,b 和 c 取 10^k,k 选完不管了;
  2. 如果有两项 loss,可以 loss = a * loss1 + (1 - a) * loss2,通过控制一个超参数 a 调整 loss;
  3. 我试过的玄学躺平做法 loss = loss1 / loss1.detach() + loss2 / loss2.detach() + loss3 loss3.detach(),分母可能需要加 eps,相当于在每一个 iteration 选定超参数 a, b, c,使得多个 loss 尺度完全一致;进一步更科学一点就 loss = loss1 + loss2 / (loss2 / loss1).detach() + loss3 / (loss3 / loss1).detach(),感觉比 loss 向 1 对齐合理

可以根据自己训练的情况调整三个loss的权重,谁高了可以加大一些权重,意思就是如果某个分支loss高了,那么网络的注意力都会去这个高loss的分支去,从而对其他支路的Loss没有贡献。这里说的“增大权重”就是将loss的量级减少,最好是三个loss都在一个量级为好

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐