pytorch技巧 一: 查看模型结构

1. torchviz

第一步:安装graphviz, 网上教程很多,也可以点这里。 注意记得配置环境变量。
第二步:安装torchviz,打开终端输入pip install torchviz
第三步:使用

import torch
from torchviz import make_dot

class MLP(torch.nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.linearl = torch.nn.Linear(3, 5)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(5, 2)

    def forward(self, x):
        x = self.linearl(x)
        x = self.relu(x)
        x = self.linear2(x)

        return x

model = MLP()
x = torch.randn(8, 3)
y = model(x)
vise=make_dot(y, params=dict(model.named_parameters()))
vise.view()

这是一个简单的两层感知机网络,模型查看结果以pdf保存在工程文件夹下。在这里插入图片描述

2. torchsummary

第一步: 安装torchsummary,打开终端输入pip install torchsummary
第二步: 使用 (需要使用GPU,测试过CPU会报错)

import torch
from torchsummary import summary

class MLP(torch.nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.linearl = torch.nn.Linear(3, 5)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(5, 2)

    def forward(self, x):
        x = self.linearl(x)
        x = self.relu(x)
        x = self.linear2(x)

        return x

device = torch.device("cuda" )
model = MLP().to(device)
summary(model, (8, 3))

模型查看结果在终端显示:在这里插入图片描述

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐