Resuming Training from a Checkpoint in PyTorch
When the number of epochs is large and training stops partway through for some reason, there are two ways to recover: (1) retrain the model from scratch, or (2) resume training from the last saved checkpoint. The second is usually preferable.
I. Saving and Loading Models
1. Save the entire Module
torch.save(net, path)
2. Save only the model parameters
state_dict = net.state_dict()
torch.save(state_dict, path)
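A minimal sketch contrasting the two approaches, using a tiny `nn.Linear` stand-in model (the model and file names here are hypothetical). Note that saving the whole Module pickles the class itself, so loading breaks if the class definition moves or changes; saving only the `state_dict` is the recommended approach, but you must rebuild the model object first.

```python
import os
import tempfile
import torch
import torch.nn as nn

tmpdir = tempfile.mkdtemp()
net = nn.Linear(4, 2)  # tiny stand-in model, for illustration only

# (1) Save the whole Module: convenient, but fragile across code changes.
whole_path = os.path.join(tmpdir, "net_whole.pth")
torch.save(net, whole_path)
# weights_only=False is required to unpickle full modules on recent PyTorch versions
net_loaded = torch.load(whole_path, weights_only=False)

# (2) Save only the parameters (recommended): rebuild the model, then load.
params_path = os.path.join(tmpdir, "net_params.pth")
torch.save(net.state_dict(), params_path)
net2 = nn.Linear(4, 2)
net2.load_state_dict(torch.load(params_path))
```

Either way, the restored weights are identical to the originals; the difference is only in how much of the Python object is serialized.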
II. Saving Checkpoints During Training
checkpoint = {
    "net": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "epoch": epoch,
}
# os.makedirs creates parent directories too (plain os.mkdir fails if ./models does not exist)
os.makedirs("./models/checkpoint", exist_ok=True)
torch.save(checkpoint, './models/checkpoint/ckpt_best_%s.pth' % (str(epoch)))
# Save the network state every few epochs during training so it can be restored later; this makes the program more robust.
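The periodic saving described above can be sketched as follows; the model, optimizer, and `save_interval` value are assumptions for illustration, and the real training step is elided.

```python
import os
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
save_interval = 2        # assumed: save a checkpoint every 2 epochs

ckpt_dir = "./models/checkpoint"
os.makedirs(ckpt_dir, exist_ok=True)  # creates parent dirs as needed

for epoch in range(4):
    # ... one epoch of training would run here ...
    if (epoch + 1) % save_interval == 0:
        checkpoint = {
            "net": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
        }
        torch.save(checkpoint, "%s/ckpt_best_%s.pth" % (ckpt_dir, epoch))
```

Saving the epoch number alongside the weights is what later allows training to resume from the right place rather than from epoch 0.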
III. Resuming Training from a Checkpoint
if RESUME:
    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # checkpoint path
    checkpoint = torch.load(path_checkpoint)  # load the checkpoint
    model.load_state_dict(checkpoint['net'])  # restore the learnable parameters
    optimizer.load_state_dict(checkpoint['optimizer'])  # restore the optimizer state
    start_epoch = checkpoint['epoch']  # set the starting epoch
# If a learning-rate scheduler is used, also put 'lr_schedule': lr_schedule.state_dict() into the saved checkpoint dict so that it can be restored as well.
IV. Restoring the Epoch (key point)
start_epoch = -1
if RESUME:
    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # checkpoint path
    checkpoint = torch.load(path_checkpoint)  # load the checkpoint
    model.load_state_dict(checkpoint['net'])  # restore the learnable parameters
    optimizer.load_state_dict(checkpoint['optimizer'])  # restore the optimizer state
    start_epoch = checkpoint['epoch']  # set the starting epoch
    lr_schedule.load_state_dict(checkpoint['lr_schedule'])  # restore the lr scheduler

for epoch in range(start_epoch + 1, EPOCH):
    # print('EPOCH:', epoch)
    for step, (b_img, b_label) in enumerate(train_loader):
        train_output = model(b_img)
        loss = loss_func(train_output, b_label)
        # losses.append(loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    lr_schedule.step()  # step the scheduler once per epoch, after optimizer.step()
    print('learning rate:', optimizer.state_dict()['param_groups'][0]['lr'])
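A self-contained round trip of the save/resume flow above can be sketched as follows. The model, optimizer, scheduler hyperparameters, and file path are all hypothetical; the point is that after restoring the three state dicts and the epoch number, training can continue from `start_epoch + 1`.

```python
import os
import tempfile
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_schedule = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

# Simulate a crash after epoch 1: save a full checkpoint.
ckpt_path = os.path.join(tempfile.mkdtemp(), "ckpt_demo.pth")
torch.save({
    "net": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "lr_schedule": lr_schedule.state_dict(),
    "epoch": 1,
}, ckpt_path)

# Resume: rebuild the objects, then restore every piece of state.
model2 = nn.Linear(4, 2)
optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.1)
lr_schedule2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=2, gamma=0.5)

checkpoint = torch.load(ckpt_path)
model2.load_state_dict(checkpoint["net"])
optimizer2.load_state_dict(checkpoint["optimizer"])
lr_schedule2.load_state_dict(checkpoint["lr_schedule"])
start_epoch = checkpoint["epoch"]
# the training loop would now run: for epoch in range(start_epoch + 1, EPOCH): ...
```

Restoring the optimizer and scheduler matters as much as the weights: momentum buffers and the learning-rate schedule would otherwise restart from their initial values.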
V. Initializing the Random Seed
import torch
import random
import numpy as np

def set_random_seed(seed=10, deterministic=False, benchmark=False):
    random.seed(seed)
    np.random.seed(seed)  # fixed: np.random(seed) is not callable and raises TypeError
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
    if benchmark:
        torch.backends.cudnn.benchmark = True

# benchmark mode speeds up computation, but because of internal nondeterminism the results can differ slightly between runs. To avoid that fluctuation, set deterministic=True to fix the internal randomness.
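A quick check that the seeding helper actually makes random draws reproducible. The helper is repeated here (with the `np.random.seed` fix) so the snippet is self-contained; with the same seed, two successive draws are identical.

```python
import random
import numpy as np
import torch

def set_random_seed(seed=10, deterministic=False, benchmark=False):
    # mirrors the helper defined above
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
    if benchmark:
        torch.backends.cudnn.benchmark = True

set_random_seed(10)
a = torch.randn(3)      # first draw after seeding
set_random_seed(10)
b = torch.randn(3)      # re-seed, then draw again: identical tensor
```

This per-library seeding is what makes a resumed run comparable to an uninterrupted one, since data shuffling and weight initialization both depend on these generators.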