import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearFC(nn.Module):

    def __init__(self):
        super(DropoutFC, self).__init__()
        self.fc = nn.Linear(3, 2)

    def forward(self, input):
        out = self.fc(input)
        return out

Net = LinearFC()
x = torch.randint(10, (2, 3)).float()  # 随机生成不大于10的整数,转为float, 因为nn.linear需要float类型数据
Net.train()
output = Net(x)
print(output)

# train the Net

创建了一个最简单的LinearFC模型,里面有一个线性函数nn.Linear(3, 2),线性变换公式为:y=xWT+by=x W^T + by=xWT+b

通过Debug,一步一步查看运行情况:

在这里插入图片描述

当前这一步可以看到模型给我们随机初始化了权重W2×3W_{2 \times 3}W2×3和偏置b2×3b_{2 \times 3}b2×3,为什么权重WWW的shape是2×32\times32×3,因为公式里需要转置。

xxx随机生成不大于10的整数,转为float, 因为nn.linear需要float类型数据。
在这里插入图片描述
可以看出使用模型算出来的output,与手动使用公式算出来的结果一致。
在这里插入图片描述

Net.train()的作用

当网络中有 dropout,Batch Normalization 的时候。训练的要记得 Net.train(), 测试 要记得 Net.eval()。

在训练模型时会在前面加上:

Net.train()

在测试模型时在前面使用:

model.eval()

同时发现,如果不写这两个程序也可以运行,这是因为这两个方法是针对在网络训练和测试时采用不同方式的情况,比如Batch Normalization 和 Dropout。

Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐