关于pytorch直接加载resnet50模型及模型参数
1.由于与resnet50的分类数不一样,所以在调用时,要使用num_classes=分类数model = torchvision.models.resnet50(pretrained=True,num_classes=5000)#pretrained=True 既要加载网络模型结构,又要加载模型参数如果需要加载模型本身的参数,需要使用pretrained=True2.由于最后一层的分类数不一样,
·
1.由于与resnet50的分类数不一样,所以在调用时,要使用num_classes=分类数
model = torchvision.models.resnet50(pretrained=True,num_classes=5000) #pretrained=True 既要加载网络模型结构,又要加载模型参数
如果需要加载模型本身的参数,需要使用pretrained=True
2.由于最后一层的分类数不一样,所以最后一层的参数数目也就不一样,所以在加载模型参数时要去掉最后一层
def _resnet(
arch: str,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
pretrained: bool,
progress: bool,
**kwargs: Any
) -> ResNet:
model = ResNet(block, layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
for k in list(state_dict.keys()): #固定遍历对象
print(k)
if k == "fc.weight" or k == "fc.bias":
state_dict.pop(k) #删除最后一层的模型参数
model.load_state_dict(state_dict,strict=False) #非严格加载模型参数
return model
由于字典中的元素是不固定的,所以在遍历的时候需要使用list,将其变为列表,这样元素位置就固定了,才可以进行后面的pop操作。
由于没有加载最后一层,所以参数中需要加上strict=False
3.总结一下如何调用pytorch框架中已有的模型及其参数(如果是分类器,且最后一层分类数不一样)
a.实例化model
b.点击resnet50,到源文件中去修改去除最后一层参数
更多推荐
已为社区贡献2条内容
所有评论(0)