这是Pytorch学习之路的第五篇

遇到问题

虽然已经知道了怎么保存已经训练好的网络模型,但是还是不知道怎么调用。其他博客中讲的有点简略,还需要自己摸索一下:

PyTorch要加载已经训练好的网络模型,需要保留什么代码,增加什么代码?

解决方法(只讨论仅加载参数的方法)

导入的库都不变,且只有测试模型前代码需要做改动:

import torch.nn as nn
import torch.nn.functional as F
#以下为需要保留的代码
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
class CNNNet(nn.Module):
    def __init__(self):
        super(CNNNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3,out_channels=16,kernel_size=5,stride=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2,stride=2)
        self.conv2 = nn.Conv2d(in_channels=16,out_channels=36,kernel_size=3,stride=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2,stride=2)
        #self.aap = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(1296,128)
        self.fc2 = nn.Linear(128,10)
        #self.fc3 = nn.Linear(36,10)
    def forward(self,x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        #x = self.aap(x)
        #x = x.view(x.shape[0],-1)
        #x = self.fc3(x)
        x = x.view(-1,36*6*6)
        #print("x.shape:{}".format(x.shape))
        x = F.relu(self.fc2(F.relu(self.fc1(x))))
        return x

model = CNNNet()

#以下为新增代码
model.load_state_dict(torch.load('./model/model.pth'))#再加载网络的参数
model = model.to(device)
print("load success")

注意

model = torch.load('./model/model.pth')

会报错
在这里插入图片描述
原因未知。

效果

成功
在这里插入图片描述

灵感来源

  1. pytorch:无法加载CNN模型并做预测TypeError:'collections. OrderedDict’对象不可调用(转载)
  2. Pytorch文档阅读(五)如何保存、加载网络模型(转载)
Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐