ResNet详解
文章目录前言一、ResNet的几种结构二、ResNet-18ResNet-501. 网络结构2. 两种Block的实现前言本文介绍ResNet网络结构,以及用pytorch实现两种残差结构。文章末尾有ResNet的实现提示:以下是本篇文章正文内容,下面案例可供参考一、ResNet的几种结构ResNet分为18层的,34层的,50层的,101层的,152层的。每种都是由两种block结构堆叠而成,一
·
前言
本文介绍ResNet网络结构,以及用pytorch实现两种残差结构。文章末尾有ResNet的实现
提示:以下是本篇文章正文内容,下面案例可供参考
一、ResNet的几种结构
ResNet分为18层的,34层的,50层的,101层的,152层的。每种都是由两种block结构堆叠而成,一种是叫做BasicBlock,一种叫做BottleneckBlock。
二、ResNet-18 ResNet-50
1. 网络结构
我们以ResNet-18和ResNet-50为例,详细介绍网络结构,直接上图。
原图来源不清楚了,有谁知道的麻烦告诉我一下。
这个图讲的很清楚了。虚线代表shortcut,也就是特征图尺寸和维度不一致时想将输入与输出相加的话就需要进行下采样和升维。尺寸一致了才能相加。图中每一个弧线包住的部分就是ResNet中的block。两个卷积层的是BasicBlock,三个卷积层的是BottleneckBlock(看卷积核大小是 11 33 1*1, 两头小中间大,因此取名瓶颈块)。
2. 两种Block的实现
代码如下(示例):
import torch.nn as nn
# ResNetBlock 基类, BasicBlock和BottleneckBlock继承此类
class ResNetBlock(nn.Module):
def __init__(self, in_channels, block_channels, out_channels):
super(ResNetBlock, self).__init__()
self.in_channels = in_channels
self.block_channels = block_channels
self.out_channels = out_channels
self.out_strides = None
#
class ResNetBasicBlock(ResNetBlock):
def __init__(self, in_channels, block_channels, out_channels):
super().__init__(in_channels, block_channels, out_channels)
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels, block_channels, kernel_size=***, stride=***, padding=***, bias=False),
nn.BatchNorm2d(block_channels),
nn.ReLU()
)
self.conv2 = nn.Sequential(
nn.Conv2d(block_channels, out_channels, kernel_size=***, stride=***, padding=***, bias=False),
nn.BatchNorm2d(out_channels)
)
self.shortcut = None
if in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels,
kernel_size=***, stride=***, padding=***, bias=False),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
out = self.conv1(x)
if self.shortcut is not None:
shortcut = self.shortcut(x)
else:
shortcut = x
out += shortcut
out = F.relu(out)
return out
class ResNetBottleneckBlock(ResNetBlock):
def __init__(self, in_channels, block_channels, out_channels):
super().__init__(in_channels, block_channels, out_channels)
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels, block_channels, kernel_size=***, stride=***, padding=***, bias=False),
nn.BatchNorm2d(block_channels),
nn.ReLU()
)
self.conv2 = nn.Sequential(
nn.Conv2d(block_channels, block_channels, kernel_size=***, stride=conv2_params['s'], padding=***, bias=False),
nn.BatchNorm2d(block_channels),
nn.ReLU()
)
self.conv3 = nn.Sequential(
nn.Conv2d(block_channels, out_channels, kernel_size=***, stride=***, padding=***, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
self.shortcut = None
if in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=***, stride=***, padding=***, bias=False),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
out = self.conv1(x)
out = self.conv2(out)
out = self.conv3(out)
if self.shortcut is not None:
shortcut = self.shortcut(x)
else:
shortcut = x
out += shortcut
out = F.relu(out)
return out
更多推荐
已为社区贡献1条内容
所有评论(0)