apply 函数是nn.Module 中实现的, 递归地调用self.children() 去处理自己以及子模块。

该方法会将fn递归的应用于模块的每一个子模块(.children()的结果)及其自身。典型的用法是,对一个model的参数进行初始化。

from torch import nn
import torch
@torch.no_grad()  ##装饰器
def init_weights(m):
    print(m)
    if type(m) == nn.Linear:
        m.weight.data.fill_(1.0)
        m.bias.data.fill_(0)


model = nn.Sequential(
    nn.Linear(2, 2),
)
model.apply(init_weights)
print(list(model.parameters()))

Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐