目的

我们通常在构建网络时,会使用一些比较成熟的网络构建backbone,比如ResNet、MobieNet等等。但有些时候并不需要使用整个backbone,而只需要其中某些层的输出,但自己构建一边backbone又很麻烦。

本文主要介绍这种方法就可以很方便地从一个已经搭建好的网络中方便地提取到某些层的输出。

IntermediateLayerGetter方法

参考自torchvision的实现,代码与注释如下:

class IntermediateLayerGetter(nn.ModuleDict):
    """ get the output of certain layers """
    def __init__(self, model, return_layers):
    	# 判断传入的return_layers是否存在于model中
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")
		
        orig_return_layers = return_layers
        return_layers = {k: v for k, v in return_layers.items()}	# 构造dict
        layers = OrderedDict()
        # 将要从model中获取信息的最后一层之前的模块全部复制下来
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super(IntermediateLayerGetter, self).__init__(layers) # 将所需的网络层通过继承的方式保存下来
        self.return_layers = orig_return_layers

    def forward(self, x):
        out = OrderedDict()
        # 将所需的值以k,v的形式保存到out中
        for name, module in self.named_children():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out

使用

使用起来非常方便,首先确定好你要返回的信息在网络中的那个module,然后构造字典,k为backbone中的module名,v为返回out中的k值。示例如下:

import torchvision
    
model = torchvision.models.resnet18()
return_layers = {'layer1':'feature_1', 'layer2':'feature_2'}
backbone = IntermediateLayerGetter(model, return_layers)

backbone.eval()
x = torch.randn(1,3,224,224)
out = backbone(x)
print(out['feature_1'].shape, out['feature_2'].shape)

输出:

torch.Size([1, 64, 56, 56]) torch.Size([1, 128, 28, 28])
Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐