Preface

This article introduces the ResNet network architecture and implements its two residual block types in PyTorch. A full ResNet implementation can be found at the end of the article.



I. The ResNet variants

ResNet comes in 18-, 34-, 50-, 101-, and 152-layer variants. Each one is built by stacking one of two block types: the BasicBlock (used in ResNet-18/34) or the BottleneckBlock (used in ResNet-50/101/152).
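The layer counts above follow directly from the per-stage block configurations. As a quick sanity check, this sketch (block configurations taken from the original ResNet paper; the counting convention is the stem convolution, plus the weighted convolutions inside every block, plus the final fully connected layer) reproduces each variant's name:

```python
# (block type, convs per block, blocks per stage) for each standard variant
configs = {
    18:  ("BasicBlock",      2, [2, 2, 2, 2]),
    34:  ("BasicBlock",      2, [3, 4, 6, 3]),
    50:  ("BottleneckBlock", 3, [3, 4, 6, 3]),
    101: ("BottleneckBlock", 3, [3, 4, 23, 3]),
    152: ("BottleneckBlock", 3, [3, 8, 36, 3]),
}

def depth(convs_per_block, blocks_per_stage):
    # 1 stem conv + convolutions inside the residual blocks + 1 fc layer
    return 1 + convs_per_block * sum(blocks_per_stage) + 1

for name, (_block, c, stages) in configs.items():
    assert depth(c, stages) == name
```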

II. ResNet-18 and ResNet-50

1. Network structure

Taking ResNet-18 and ResNet-50 as examples, let's walk through the network structure in detail, starting with the figure.

I no longer know the original source of this figure; if anyone does, please let me know.
The figure makes things quite clear. The dashed lines mark shortcuts: when the input and output feature maps differ in spatial size or channel count, the input must be downsampled and have its channels expanded before it can be added to the output; the two tensors can only be summed once their shapes match. Each arc in the figure encloses one ResNet block. The blocks with two convolutional layers are BasicBlocks; the blocks with three are BottleneckBlocks (their kernel sizes are 1×1, 3×3, 1×1, small at both ends and large in the middle, hence the name "bottleneck").
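The dashed-shortcut rule can be verified in a few lines. This is my own sketch (the shapes below are the standard ones from the transition into ResNet's second stage, but any mismatched pair would do): when the main path halves the spatial size and doubles the channels, a 1×1 convolution with the same stride projects the input to a matching shape so the addition is legal.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 56, 56)                             # input feature map
main = nn.Conv2d(64, 128, 3, stride=2, padding=1)          # main path: halves size, doubles channels
out = main(x)                                              # -> (1, 128, 28, 28)

# The identity shortcut no longer matches: x is (1, 64, 56, 56).
# A 1x1 conv with stride 2 (the "dashed" shortcut) fixes both mismatches.
proj = nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False)
assert proj(x).shape == out.shape                          # now they can be added
y = out + proj(x)
```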

2. Implementing the two block types

The code is as follows (example):

import torch.nn as nn
import torch.nn.functional as F  # F.relu is used in the forward passes below

# ResNetBlock: base class that BasicBlock and BottleneckBlock inherit from
class ResNetBlock(nn.Module):
    def __init__(self, in_channels, block_channels, out_channels):
        super(ResNetBlock, self).__init__()
        self.in_channels = in_channels
        self.block_channels = block_channels
        self.out_channels = out_channels

        self.out_strides = None

# BasicBlock: two 3x3 convolutions
class ResNetBasicBlock(ResNetBlock):

    def __init__(self, in_channels, block_channels, out_channels, stride=1):

        super().__init__(in_channels, block_channels, out_channels)

        # First 3x3 conv; stride > 1 performs the stage's downsampling
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, block_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(block_channels),
            nn.ReLU()
        )

        # Second 3x3 conv; no ReLU here, it is applied after the addition
        self.conv2 = nn.Sequential(
            nn.Conv2d(block_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

        self.shortcut = None

        # Dashed shortcut: a 1x1 conv matches both channels and spatial size
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=stride, padding=0, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        if self.shortcut is not None:
            shortcut = self.shortcut(x)
        else:
            shortcut = x
        out += shortcut
        out = F.relu(out)
        return out


class ResNetBottleneckBlock(ResNetBlock):

    def __init__(self, in_channels, block_channels, out_channels, stride=1):
        super().__init__(in_channels, block_channels, out_channels)

        # 1x1 conv reduces the channel count
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, block_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(block_channels),
            nn.ReLU()
        )

        # 3x3 conv; stride > 1 performs the stage's downsampling
        self.conv2 = nn.Sequential(
            nn.Conv2d(block_channels, block_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(block_channels),
            nn.ReLU()
        )

        # 1x1 conv expands the channels again; no ReLU before the addition
        self.conv3 = nn.Sequential(
            nn.Conv2d(block_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels)
        )

        self.shortcut = None

        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.shortcut is not None:
            shortcut = self.shortcut(x)
        else:
            shortcut = x

        out += shortcut
        out = F.relu(out)

        return out
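The residual computation inside both blocks can also be exercised on its own. This is a self-contained sketch (function and layer names are my own, with the usual ResNet hyperparameters filled in: 3×3 main-path convolutions with padding 1, and a 1×1 projection shortcut when the stride or channel count changes), showing a full forward pass through one BasicBlock-shaped residual step:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def residual_step(x, conv1, conv2, proj=None):
    # Main path: conv -> relu -> conv, then add the (possibly projected)
    # shortcut, then the final relu -- mirroring the blocks above.
    out = conv2(F.relu(conv1(x)))
    shortcut = proj(x) if proj is not None else x
    return F.relu(out + shortcut)

conv1 = nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False)  # downsamples
conv2 = nn.Conv2d(128, 128, 3, stride=1, padding=1, bias=False)
proj = nn.Conv2d(64, 128, 1, stride=2, bias=False)              # dashed shortcut

y = residual_step(torch.randn(2, 64, 32, 32), conv1, conv2, proj)
assert y.shape == (2, 128, 16, 16)
```

Because the final ReLU comes after the addition, the output is non-negative; the shortcut keeps gradients flowing to the input even when the main path saturates.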

I have written a full ResNet implementation.
