pytorch获取全部权重参数、每一层权重参数

首先需要安装torchsummary
在相应的虚拟环境下pip install torchsummary

1、打印每层参数信息:
summary(net,input_size,batch_size,device),

net:网络模型
input_size:网络输入图片的shape
batch_size:默认参数为-1
device:在gpu上还是cpu上运行,默认是cuda在gpu上运行,若想在cpu上运行,需将参数改为cpu。

eg.vgg16网络
from models import VGG16_torch
model = vgg16()
summary(model,(3,32,32),device=‘cpu’)
在这里插入图片描述
2、根据需要,输出相应层的权重
首先查看每层对应的名称

model = vgg16()
for name in model.state_dict():
  print(name)

在这里插入图片描述
再根据名称输出相应层的权重

 print(model.state_dict()['layers.0.conv2d.weight'])

在这里插入图片描述
3、打印模块名字和参数大小

for name, parameters in model.named_parameters():  
    print(name, ';', parameters.size())

输出结果:
在这里插入图片描述
4、加载模型全部参数

import torch
y = torch.load('vgg16_baseline.t7')
print(y)

在这里插入图片描述

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐