apply 函数是nn.Module 中实现的, 递归地调用self.children() 去处理自己以及子模块

我们知道pytorch的任何网络net,都是torch.nn.Module的子类,都算是module, 也就是模块。

pytorch中的model.apply(fn)会递归地将函数fn应用到父模块的每个子模块submodule,也包括model这个父模块自身。经常用于初始化init_weights的操作

from torch import nn

def init_weights(m):
    print(m)
    if type(m) == nn.Linear:
        m.weight.data.fill_(1.0)
        m.bias.data.fill_(0)

model = nn.Sequential(
            nn.Linear(2, 2), 
            nn.Linear(2, 2)
        )
model.apply(init_weights)

Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐