1. Neural Network Initialization

I like to do initialization in the network's constructor. For example:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.AvgPool2d(2, 2)
        )

        self.conv21 = nn.Conv2d(64, 64*2, 3, padding=1)
        self.pool2 = nn.AvgPool2d(2, 2)

        self.conv31 = nn.Conv2d(64*2, 10, 1)
        self.pool3 = nn.AvgPool2d(8, 8)

        self.line = nn.Linear(10, 100)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1.0, 0.02)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0.0, 0.05)
                nn.init.zeros_(m.bias)
        
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.conv21(x))
        x = self.pool2(x)
        x = self.conv31(x)
        x = self.pool3(x)
        x = x.view(-1, 10)
        x = self.line(x)
        return x

The for m in self.modules(): loop in the __init__ function above is what assigns initial values to the layers of the network. How does it proceed? It walks the module tree from the outside in. On the first iteration, m is the entire network:

 Net(
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): AvgPool2d(kernel_size=2, stride=2, padding=0)
  )
  (conv21): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (conv31): Conv2d(128, 10, kernel_size=(1, 1), stride=(1, 1))
  (pool3): AvgPool2d(kernel_size=8, stride=8, padding=0)
  (line): Linear(in_features=10, out_features=100, bias=True)
)

Clearly this matches none of the branches. On the second iteration, m is the value assigned to self.conv1, namely:

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): AvgPool2d(kernel_size=2, stride=2, padding=0)
)

Again, none of the branches match. On the third iteration, m is the first element inside the self.conv1 Sequential:

Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

This matches the first if branch. On the next iteration, m is

BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

which matches the elif isinstance(m, nn.BatchNorm2d) branch. The loop continues in this way until the entire self.conv1 Sequential is exhausted, then moves on to the subsequent nodes such as self.conv21 and self.pool2, until every layer in the network has been initialized.
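The traversal order described above can be checked with a minimal sketch (a small stand-in module, not the full Net above; the names here are illustrative):

```python
import torch.nn as nn

# A tiny stand-in module: one Sequential block plus a bare Linear layer.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
        self.fc = nn.Linear(8, 2)

# modules() yields the module itself first, then every submodule,
# recursing depth-first in registration order.
for m in TinyNet().modules():
    print(type(m).__name__)
# TinyNet, Sequential, Conv2d, BatchNorm2d, Linear
```

PyTorch also offers net.apply(fn), which runs a function on every submodule recursively, so the same isinstance-based initialization loop can be factored into a standalone function and reused across networks.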

2. Printing w, b, and Gradients

Next, we use the following example to inspect the initialized weights and biases:

net = Net()
print("conv21.bias = ", net.conv21.bias)
print("conv21.bias.grad = ", net.conv21.bias.grad)
print("conv21.weight = ", net.conv21.weight)
print("conv21.weight.grad = ", net.conv21.weight.grad)

Output:

conv21.bias =  Parameter containing:
tensor([0., 0., ... 0., 0., 0., 0.], requires_grad=True)
conv21.bias.grad =  None
conv21.weight =  Parameter containing:
tensor([[[[ 0.0026,  0.0542, -0.0444], ... [-0.0391, -0.0124,  0.0965]]]], requires_grad=True)
conv21.weight.grad =  None

We can see the biases are indeed zero and the weights are what we wanted, but why are all the gradients None? Because we have not run a backward pass yet. After performing one backward pass with the code below, the printed gradients are no longer None.

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)

inputs = torch.rand(4, 3, 32, 32)      # a dummy batch of 4 RGB 32x32 images
labels = torch.randint(0, 100, (4,))   # random class indices for the 100-way output
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
inputs = inputs.to(device)
labels = labels.to(device)

outputs = net(inputs)

loss = criterion(outputs, labels.long())
loss.backward()
optimizer.step()

print("conv21.bias = ",net.conv21.bias)
print("conv21.bias.grad = ",net.conv21.bias.grad)
print("conv21.weight = ",net.conv21.weight)
print("conv21.weight.grad = ",net.conv21.weight.grad)

Tip: while debugging, you sometimes want to print an entire tensor. You can temporarily change the print profile with torch.set_printoptions(profile="full"); changing "full" back to "default" restores the default truncated output.
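A minimal sketch of the tip above (the tensor size is chosen arbitrarily, just large enough for truncation to kick in):

```python
import torch

t = torch.zeros(2000)  # large enough that the default profile truncates it

torch.set_printoptions(profile="full")     # print every element
full_repr = str(t)                         # no "..." ellipsis in this output

torch.set_printoptions(profile="default")  # restore the truncated default
default_repr = str(t)                      # contains "..." again

print("..." in full_repr, "..." in default_repr)  # False True
```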

So how do we print the values inside a Sequential?

for i,m in enumerate(net.conv1.children()):
    if isinstance(m, nn.Conv2d):
        print("net.conv1."+str(i)+"(Conv2d).weight = ",m.weight)
        print("net.conv1."+str(i)+"(Conv2d).weight.grad = ",m.weight.grad)
    elif isinstance(m, nn.BatchNorm2d):
        print("net.conv1."+str(i)+"(BatchNorm2d).weight = ",m.weight)
        print("net.conv1."+str(i)+"(BatchNorm2d).weight.grad = ",m.weight.grad)

That is, we iterate over the children of the Sequential, use isinstance to identify the target layer type, and print its values.
Partial results (screenshot omitted).
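Since nn.Sequential also supports integer indexing, individual layers can be reached directly without a loop; a small sketch using the same layer layout as self.conv1 above:

```python
import torch.nn as nn

conv1 = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(True),
    nn.AvgPool2d(2, 2),
)

# Index into the Sequential directly: [0] is the Conv2d, [1] the BatchNorm2d.
print(conv1[0].weight.shape)  # torch.Size([64, 3, 3, 3])
print(conv1[1].weight.shape)  # torch.Size([64])
```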
