torch.summary是pytorch的一个包,可以打印模型的每一层组成,参数量,总的参数量。

官方网址

使用summary函数,首先安装torchsummary包:

pip install torchsummary

然后导入包:

from torchsummary import summary

接着在运行时进行计算

summary(model, input_size, batch_size, device)

遇到的一些bug:

1:出现这种错误是因为输入维度高,即不需要指定batchsize,默认为-1.

 2:出现这种情况是因为输入格式错误,当有多个输入要用 [ ],将输入括起来。在官方文档中有案例。

注意:当多个输入没有报错,他运行的的结果也是错误的。

他的inputsize和total非常大,需要对torchsummary源码进行修改,

原始未修改的:

原始的未修改

 修改后的:

 修改过程参考:修改流程

首先:加上如下代码找到源码位置

import torchsummary
print(torchsummary.__file__)

然后:根据目录一步一步寻找,在_init_.py同级目录下的torchsummary.py文件中

/home/xh/.local/lib/python3.7/site-packages/torchsummary/__init__.py

最后:将第一百行:

total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))

替换为:

total_input_size = abs(np.sum([np.prod(in_tuple) for in_tuple in input_size]) * batch_size * 4. / (1024 ** 2.))

修改完后注意保存。

 3:当出现这种情况是因为没有指定device,加上device='cpu'就是正确的。

 

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐