当epochs值比较大,在训练过程由于一些原因模型停止训练,好的解决方法是(1)重新训练模型;(2)接着上次训练的断点继续训练

一、模型的保存与加载

1. 保存整个Module

torch.save(net, path)

2. 保存模型参数

state_dict = net.state_dict()
torch.save(state_dict , path)

二、模型训练过程中保存

checkpoint = {
        "net": model.state_dict(),
        'optimizer':optimizer.state_dict(),
        "epoch": epoch
    }
    if not os.path.isdir("./models/checkpoint"):
        os.mkdir("./models/checkpoint")
    torch.save(checkpoint, './models/checkpoint/ckpt_best_%s.pth' %(str(epoch)))
    # 在训练过程中每个多少个epoch保存一次网络参数,便于恢复,提高程序的鲁棒性

三、模型断点继续训练

if RESUME:
    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # 断点路径
    checkpoint = torch.load(path_checkpoint)  # 加载断点
    model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数
    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数
    start_epoch = checkpoint['epoch']  # 设置开始的epoch
    'lr_schedule': lr_schedule.state_dict()

四、epoch的恢复(重点)

start_epoch = -1
if RESUME:
    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # 断点路径
    checkpoint = torch.load(path_checkpoint)  # 加载断点
    model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数
    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数
    start_epoch = checkpoint['epoch']  # 设置开始的epoch
    lr_schedule.load_state_dict(checkpoint['lr_schedule'])#加载lr_scheduler
for epoch in  range(start_epoch + 1 ,EPOCH):
    # print('EPOCH:',epoch)
    optimizer.zero_grad()
    optimizer.step()
    lr_schedule.step()
    print('learning rate:',optimizer.state_dict()['param_groups'][0]['lr'])
    
    for step, (b_img,b_label) in enumerate(train_loader):
        train_output = model(b_img)
        loss = loss_func(train_output,b_label)
        # losses.append(loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

五、初始化随机数种子

import torch
import random
import numpy as np
def set_random_seed(seed = 10,deterministic=False,benchmark=False):
    random.seed(seed)
    np.random(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
    if benchmark:
        torch.backends.cudnn.benchmark = True
        # benchmark模式会提升计算速度,但是由于计算机中有随机性,每次网络训练结果会存在差异。如果要避免这种结果波动,设置deterministic用来固定内部随机性。
Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐