Pytorch:提取网络中某些层的输出
目的我们通常在构建网络时,会使用一些比较成熟的网络构建backbone,比如ResNet、MobieNet等等。但有些时候并不需要使用整个backbone,而只需要其中某些层的输出,但自己构建一边backbone又很麻烦。本文主要介绍这种方法就可以很方便地从一个已经搭建好的网络中方便地提取到某些层的输出。IntermediateLayerGetter方法参考自torchvision的实现,代码与注
·
目的
我们通常在构建网络时,会使用一些比较成熟的网络构建backbone,比如ResNet、MobieNet等等。但有些时候并不需要使用整个backbone,而只需要其中某些层的输出,但自己构建一边backbone又很麻烦。
本文主要介绍这种方法就可以很方便地从一个已经搭建好的网络中方便地提取到某些层的输出。
IntermediateLayerGetter方法
参考自torchvision的实现,代码与注释如下:
class IntermediateLayerGetter(nn.ModuleDict):
""" get the output of certain layers """
def __init__(self, model, return_layers):
# 判断传入的return_layers是否存在于model中
if not set(return_layers).issubset([name for name, _ in model.named_children()]):
raise ValueError("return_layers are not present in model")
orig_return_layers = return_layers
return_layers = {k: v for k, v in return_layers.items()} # 构造dict
layers = OrderedDict()
# 将要从model中获取信息的最后一层之前的模块全部复制下来
for name, module in model.named_children():
layers[name] = module
if name in return_layers:
del return_layers[name]
if not return_layers:
break
super(IntermediateLayerGetter, self).__init__(layers) # 将所需的网络层通过继承的方式保存下来
self.return_layers = orig_return_layers
def forward(self, x):
out = OrderedDict()
# 将所需的值以k,v的形式保存到out中
for name, module in self.named_children():
x = module(x)
if name in self.return_layers:
out_name = self.return_layers[name]
out[out_name] = x
return out
使用
使用起来非常方便,首先确定好你要返回的信息在网络中的那个module,然后构造字典,k为backbone中的module名,v为返回out中的k值。示例如下:
import torchvision
model = torchvision.models.resnet18()
return_layers = {'layer1':'feature_1', 'layer2':'feature_2'}
backbone = IntermediateLayerGetter(model, return_layers)
backbone.eval()
x = torch.randn(1,3,224,224)
out = backbone(x)
print(out['feature_1'].shape, out['feature_2'].shape)
输出:
torch.Size([1, 64, 56, 56]) torch.Size([1, 128, 28, 28])
更多推荐
已为社区贡献5条内容
所有评论(0)