pytorch技巧 一: 查看模型结构
pytorch技巧 一: 查看模型结构1. torchviz第一步:安装graphviz, 网上教程很多,也可以点这里。 注意记得配置环境变量。第二步:安装torchviz,打开终端输入pip install torchviz第三步:使用import torchfrom torchviz import make_dotclass MLP(torch.nn.Module):def __init__(
·
pytorch技巧 一: 查看模型结构
1. torchviz
第一步:安装graphviz
, 网上教程很多,也可以点这里。 注意记得配置环境变量。
第二步:安装torchviz
,打开终端输入pip install torchviz
第三步:使用
import torch
from torchviz import make_dot
class MLP(torch.nn.Module):
def __init__(self):
super(MLP, self).__init__()
self.linearl = torch.nn.Linear(3, 5)
self.relu = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(5, 2)
def forward(self, x):
x = self.linearl(x)
x = self.relu(x)
x = self.linear2(x)
return x
model = MLP()
x = torch.randn(8, 3)
y = model(x)
vise=make_dot(y, params=dict(model.named_parameters()))
vise.view()
这是一个简单的两层感知机网络,模型查看结果以pdf
保存在工程文件夹下。
2. torchsummary
第一步: 安装torchsummary
,打开终端输入pip install torchsummary
第二步: 使用 (需要使用GPU
,测试过CPU
会报错)
import torch
from torchsummary import summary
class MLP(torch.nn.Module):
def __init__(self):
super(MLP, self).__init__()
self.linearl = torch.nn.Linear(3, 5)
self.relu = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(5, 2)
def forward(self, x):
x = self.linearl(x)
x = self.relu(x)
x = self.linear2(x)
return x
device = torch.device("cuda" )
model = MLP().to(device)
summary(model, (8, 3))
模型查看结果在终端显示:
更多推荐
已为社区贡献2条内容
所有评论(0)