Pytorch apply() 函数
apply 函数是nn.Module 中实现的, 递归地调用self.children() 去处理自己以及子模块我们知道pytorch的任何网络net,都是torch.nn.Module的子类,都算是module, 也就是模块。pytorch中的model.apply(fn)会递归地将函数fn应用到父模块的每个子模块submodule,也包括model这个父模块自身。经常用于初始化init_wei
·
apply 函数是nn.Module 中实现的, 递归地调用self.children() 去处理自己以及子模块
我们知道pytorch的任何网络net,都是torch.nn.Module的子类,都算是module, 也就是模块。
pytorch中的model.apply(fn)会递归地将函数fn应用到父模块的每个子模块submodule,也包括model这个父模块自身。经常用于初始化init_weights的操作
from torch import nn def init_weights(m): print(m) if type(m) == nn.Linear: m.weight.data.fill_(1.0) m.bias.data.fill_(0) model = nn.Sequential( nn.Linear(2, 2), nn.Linear(2, 2) ) model.apply(init_weights)
更多推荐
已为社区贡献25条内容
所有评论(0)