目的

我们通常在构建网络时,会使用一些比较成熟的网络构建backbone,比如ResNet、MobieNet等等。但有些时候并不需要使用整个backbone,而只需要其中某些层的输出,但自己构建一边backbone又很麻烦。

本文主要介绍这种方法就可以很方便地从一个已经搭建好的网络中方便地提取到某些层的输出。

IntermediateLayerGetter方法

参考自torchvision的实现,代码与注释如下:

class IntermediateLayerGetter(nn.ModuleDict):
    """ get the output of certain layers """
    def __init__(self, model, return_layers):
    	# 判断传入的return_layers是否存在于model中
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")
		
        orig_return_layers = return_layers
        return_layers = {k: v for k, v in return_layers.items()}	# 构造dict
        layers = OrderedDict()
        # 将要从model中获取信息的最后一层之前的模块全部复制下来
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super(IntermediateLayerGetter, self).__init__(layers) # 将所需的网络层通过继承的方式保存下来
        self.return_layers = orig_return_layers

    def forward(self, x):
        out = OrderedDict()
        # 将所需的值以k,v的形式保存到out中
        for name, module in self.named_children():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out

使用

使用起来非常方便,首先确定好你要返回的信息在网络中的那个module,然后构造字典,k为backbone中的module名,v为返回out中的k值。示例如下:

import torchvision
    
model = torchvision.models.resnet18()
return_layers = {'layer1':'feature_1', 'layer2':'feature_2'}
backbone = IntermediateLayerGetter(model, return_layers)

backbone.eval()
x = torch.randn(1,3,224,224)
out = backbone(x)
print(out['feature_1'].shape, out['feature_2'].shape)

输出:

torch.Size([1, 64, 56, 56]) torch.Size([1, 128, 28, 28])
Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐