pytorch实现模型的保存和加载
pytorch保存模型 加载模型
·
第一种方式
# 获得模型的参数和buffer量
path = "state_dict_model.pt"
# 保存
torch.save(model.state_dict(),path)
# 加载
model = Network(input_num)
model.load_state_dict(torch.load(path))
# 将内部的training参数 设置为FALSE 这样在直接使用模型进行预测时
# 就不再继续计算梯度值
model.eval()
第二种方式
# 对整个模型进保存和加载
path = "entire_model.pt"
# 保存模型
torch.save(model, path)
# 加载模型
model = torch.load(path)
model.eval()
第三种方式
# 保存checkpoint
path = 'model.pt'
torch.save(
{
'epoch':epochs,
'model_state_dict': model.state_dict(),
'optimizer_state_dict':optimizer.state_dict(),
'loss': loss_fn
},path
)
# 加载
model = Network(input_num)
optimizer = optimizer = torch.optim.SGD(model.parameters(), lr=lr)
checkpoint = torch.load(path)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]
model.eval()
# or
model.train()
更多推荐
已为社区贡献1条内容
所有评论(0)