最近在看大佬们写的代码时,看到使用add_module函数。所以就了解了一番,在这里做个介绍。add_module就好像是list的用法,看代码之后就懂了(代码来自文献【3】)

from torch import nn
from torchsummary import summary

class Net_test(nn.Module):
    def __init__(self):
        super(Net_test,self).__init__()
        self.conv_1 = nn.Conv2d(3,6,3)
        self.add_module('conv_2', nn.Conv2d(6,12,3))
        self.conv_3 = nn.Conv2d(12,24,3)
        
    def forward(self,x):
        x = self.conv_1(x)
        x = self.conv_2(x)
        x = self.conv_3(x)
        return x
    
model = Net_test()
print(model)
model.to('cuda')
summary(model,(3,128,128))

“”“
Net_test(
  (conv_1): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv_2): Conv2d(6, 12, kernel_size=(3, 3), stride=(1, 1))
  (conv_3): Conv2d(12, 24, kernel_size=(3, 3), stride=(1, 1))
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1          [-1, 6, 126, 126]             168
            Conv2d-2         [-1, 12, 124, 124]             660
            Conv2d-3         [-1, 24, 122, 122]           2,616
================================================================
Total params: 3,444
Trainable params: 3,444
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.19
Forward/backward pass size (MB): 4.86
Params size (MB): 0.01
Estimated Total Size (MB): 5.06
“””

参考文献

[1][Pytorch进阶技巧(一)] 使用add_module替换部分模型
[2]PyTorch-网络的创建,预训练模型的加载
[3]pytorch 使用 add_module 添加模块

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐