原因:代码中需要多次用到一个高维tensor变量,每一个batch都要更新一次它的值,这个值的获得需要数据集过一次网络
心路历程:刚开始以为是传参或者内存释放的问题,去深入研究了python高级用法,用了很多del、gc.collect()语句,发现内存释放不掉,还是随着训练过程逐渐增长
最终解决:因为是把两个代码的方法往一个整合,又仔细看了源码,发现源代码过网络的时候用了with torch.no_grad()这个语句,就试了一下,没想到就是这个问题,困了我两三天啊啊啊
分析:如果没有这个语句的话,因为要反向求导,pytorch就会创建变量保存梯度,跟这个高维tensor“体量”差不多的东西,如果每过一次网络保存一次梯度,这个内存消耗就非常大了,最终cuda memory out

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐