Pytorch在加载的模型基础上继续训练
深度学习网络模型的训练往往会花费挺长时间,这时候万一断电了,机器死机了,那真的气不打一处来,想砸机器的冲动有没有?您先别着急,一般咱们的模型都写有模型参数保存功能,比如这样:if epoch%10 == 1:torch.save(model.state_dict(),'{}/moilenetV2_{}_{}.pth.format('./models',epoch,acc))我们只需要找到这个模型保
深度学习网络模型的训练往往会花费挺长时间,这时候万一断电了,机器死机了,那真的气不打一处来,想砸机器的冲动都来了有没有?
不过也不用太着急,一般咱们的模型都写有模型参数保存功能,比如这样:
if epoch%10 == 1:
torch.save(model.state_dict(),'{}/moilenetV2_{}_{}.pth.format('./models',epoch,acc))
我们只需要找到这个模型保存的位置,然后把最新的这个模型参数加载到我们的model中,就可以接着这个参数进行训练了。要加载的代码一般放在model定义之后(就是确定model的结构了),模型进行训练之前。要加载代码如下:
Resume = True
# Resume = False
if Resume:
path_checkpoint = 'your/new/model/path.pth'
checkpoint = torch.load(path_checkpoint, map_location = torch.device('cpu'))
model.load_state_dict(checkpoint)
变量Resume可以作为开关,如果想在训练好的模型基础上进行finetune(微调)的话,就把它设置为True,从零训练的话就设置为False。当然咱们这种出问题,接着训练的就设置为True就行。
知识扩充
训练模型的保存包括两种:
1、保存整个模型框架以及模型参数(存储文件过大,不推荐)
torch.save(model,path)
2、仅仅保存模型的参数文件(推荐)
torch.save(model.state_dict(),path)
"state_dict"表示state dictionary,即字典类型的参数,模型本身的参数。
其中torch.load()函数可以加载模型参数,为了保证GPU显存够用,推荐令map_location = torch.device(‘cpu’)
假如你就想加载到gpu中,可以令map_location = torch.device(‘cuda’)
最后用model.load_state_dict(checkpoint)把参数加载完成。
好了,快去训练你的模型吧!有问题欢迎留言~
更多推荐
所有评论(0)