pytorch apply函数
pytorch apply函数
·
apply 函数是nn.Module 中实现的, 递归地调用self.children() 去处理自己以及子模块。
该方法会将fn递归的应用于模块的每一个子模块(.children()的结果)及其自身。典型的用法是,对一个model的参数进行初始化。
from torch import nn
import torch
@torch.no_grad() ##装饰器
def init_weights(m):
print(m)
if type(m) == nn.Linear:
m.weight.data.fill_(1.0)
m.bias.data.fill_(0)
model = nn.Sequential(
nn.Linear(2, 2),
)
model.apply(init_weights)
print(list(model.parameters()))
更多推荐
已为社区贡献5条内容
所有评论(0)