torch.nn.module之self._modules,module()以及children()的区别与联系
一、区别1、children() → Iterator[torch.nn.modules.module.Module],是nn.Module类的成员函数,返回当前children模块的迭代器; modules() → Iterator[torch.nn.modules.module.Module],是nn.Module类的成员函数,返回所有模块的迭代器; self.modules是nn.Mod
一、区别
1、children() → Iterator[torch.nn.modules.module.Module],是nn.Module类的成员函数,返回当前children模块的迭代器,作用:只获取模型‘儿子’,不再往深处获取’孙子;
modules() → Iterator[torch.nn.modules.module.Module],是nn.Module类的成员函数,返回所有模块的迭代器,作用:将整个模型的所有构成(包括包装层、单独的层、自定义层等)由浅入深依次遍历出来, 直到最深处的单层,只不过modules()返回的每一个元素是直接返回的层对象本身,而named_modules()返回的每一个元素是一个元组,第一个元素是名称,第二个元素才是层对象本身。主要作用是需要获取模型的每个层的对象时使用,比如模型初始化,模型加载参数等;
self.modules是nn.Module类的成员变量,字典类型,作用:用来保存组成网络的children模块。
二、联系
children()和modules()函数都是在self.modules成员变量上操作
def children(self) -> Iterator['Module']:
for name, module in self.named_children():
yield module
def named_children(self) -> Iterator[Tuple[str, 'Module']]:
memo = set()
#以下开始遍历self._modules
for name, module in self._modules.items():
#如果self._modules中有重复的value,剔除
if module is not None and module not in memo:
memo.add(module)
yield name, module
从上面的代码可以看到children()返回的是self._modules中非重复的layer或者多个layer组成的小模块
def modules(self) -> Iterator['Module']:
for name, module in self.named_modules():
yield module
def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = ''):
if memo is None:
memo = set()
if self not in memo:
memo.add(self)
yield prefix, self
for name, module in self._modules.items():
if module is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + name
for m in module.named_modules(memo, submodule_prefix):
yield m
从上面的代码可以看出,modules将整个模型的所有结构(包括children模型,以及children模型中的子部件),从外到内,由浅入深一次遍历出来。
三、举例说明
l = torch.nn.Linear(2, 2)
net = torch.nn.Sequential(l, l)
print("nn.modules")
for idx, m in enumerate(net.modules()):
print(idx, '->', m)
print("nn.children()")
for idx, m in enumerate(net.children()):
print(idx, '->', m)
print("nn._modules")
print (net._modules)
for idx in net._modules.keys():
print(idx, '->', net._modules[idx])
输出结果:
nn.modules()的结果是:
0 -> Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)
nn.children()的结果是:
0 -> Linear(in_features=2, out_features=2, bias=True)
nn._modules的结果是:
OrderedDict([('0', Linear(in_features=2, out_features=2, bias=True)), ('1', Linear(in_features=2, out_features=2, bias=True))])
0 -> Linear(in_features=2, out_features=2, bias=True)
1 -> Linear(in_features=2, out_features=2, bias=True)
更多推荐
所有评论(0)