Swin Transformer详解
Swin Transformer详解一、整体架构二、拆解Swin Transformer1. Patch Partition & Linear Embedding2. Swin Transformer Block(1)第一个BlockW-MSA模块window partitionwindow attentionMLP(2)第二个BlockSW-MSA模块3. Patch Merging三、
Swin Transformer详解
论文地址:https://arxiv.org/pdf/2103.14030.pdf
代码地址:https://github.com/microsoft/Swin-Transformer
本文一共分为三个部分,首先介绍Swin Transformer的整体架构,随后会介绍每个模块的作用,中间会穿插部分代码。本文的主要目的还是希望能够将Swin Transformer解释清楚,然后结合官方代码来理解
一、Overall Architecture
首先给出论文中的Swin Transformer架构图
左边是Swin Transformer的全局架构,它包含Patch Partition、Linear Embedding、Swin Transformer Block、Patch Merging四大部分,这四大部分我们之后会进行详细的介绍
右边是Swin Transformer Block结构图,这是两个连续的Swin Transformer Block块,一个是W-MSA,一个是SW-MSA,也就是说根据Swin的Tiny版本,图中的Swin Transformer Block块为[2, 2, 6, 2],相对应的attention为:stage1 W-MSA-->SW-MSA
– stage2 W-MSA-->SW-MSA
– stage3 W-MSA-->SW-MSA-->W-MSA-->SW-MSA-->W-MSA-->SW-MSA
– stage4 W-MSA-->SW-MSA
二、Swin Transformer
下面的维度等均是基于Swin-T
版本
1. Patch Partition & Linear Embedding
输入为(B, 3, 224, 224)
输出为(B, 96, 56, 56) —> (B, 96, 224/4=56, 224/4=56)
这两步在论文中其实就是一步实现,我们先来看paper中的解释:
- Patch Partition,这一步是将输入的(H, W, 3)的图片分成(4, 4)的小块,分块后的图片大小为(H/4, W/4, 48)也就是文中所给的维度
- Linear Embedding,在Tiny版本中,将分块后的图像映射到96维
在真正实现的时候paper使用了PatchEmbed函数将这两步结合起来,实际上也就是用了一个卷积的操作,卷积核大小为(4, 4),步长为4:nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
下面图示此过程
2. Basic Layer
在官方的代码库中,将Swin Transformer Block和Patch Merging合并成了一个,叫做Basic Layer,下面我们分别介绍这两者
Swin Transformer Block
输入为(B, 3136, 96)
输出为(B, 3136, 96)
就是把上一步的(4, 96, 56, 56)后两维度合并变为(4, 96, 3136),然后后两维互换(4, 3136, 96)
Swin Transformer Block的输入输出不变,每两个连续Block为一组,即一个Window Multi-head Self-Attention和一个Shifted Window Multi-head Self-Attention
下面是paper中的Swin Transformer Block
示例图
从图中我们可以看出每两个连续Block块有四小步:
1. 第一个Block
- 首先特征图经过
Layer Norm
层,经过W-MSA
,然后进行跳跃连接 - 连接后的特征图再次经过
Layer Norm
层,经过全连接层MLP
,然后进行跳跃连接
2. 第二个Block
- 首先特征图经过
Layer Norm
层,经过SW-MSA
,然后进行跳跃连接 - 连接后的特征图再次经过
Layer Norm
层,经过全连接层MLP
,然后进行跳跃连接
从上面四步可以看出Swin Transformer Block
清晰的执行步骤,其中比较难理解的是W-MSA
和SW-MSA
,下面我们详细介绍二者,并介绍由二者引出的一些细节
(1)first block
包含两个主要模块,W-MSA和MLP
输入为(B, 3136, 96)
输出为(B, 3136, 96)
W-MSA
window partition
W-MSA在第一个block中,这一步没有滑动窗,输入为(B, 3136, 96),为了后面的sefl-attention操作,需要将特征图划分为一个个窗口的形式,首先经历了一个window partition操作,变为(64B, 7, 7, 96)
怎么计算的呢?输入为batch=B,3136=56*56,特征图有96个,将每个特征图56*56分为7*7的窗口,一共能分8*8=64个,乘上之前的B就是64B了,就是说将特征图分为(7, 7)的小窗,然后把所有的小窗拿出来一共有64B个,示例图如下
为什么要进行window partition?在Vision Transformer中,我们将图片分成了一个个patch(也就是左边的图),在进行MSA时,任何一个patch都要与其他所有的patch都进行attention,当patch的大小固定时,计算量与图片的大小成平方增长。Swin Transformer中采用了W-MSA,也就是window的形式,不同的window包含了相同数量的patch,只对window内部进行MSA,当图片大小增大时,计算量仅仅是呈线性增加(只增加了图片多余部分的计算量,比如之前是224的图像,现在是256的图像,只多了256-224=32像素的计算部分),下面详细介绍window attention部分
window attention
将窗口分配完成后就可以执行attention操作了,首先我们将维度变为(64B, 49, 96),进行attention操作时,我们需要qkv三个变量,transformer是通过linear函数来实现的:nn.Linear(dim, dim * 3, bias=qkv_bias)
,通过这个函数后,维度变为(64B, 49, 288),qkv分别占三分之一,也就是说qkv分别为(64B, 49, 96),第一个阶段的head为3,维度划分为(64B, 3, 49, 32)
此时qkv的值如下所示,这就是进行attention时qkv的维度
- q: (64B, 3, 49, 32)
- k: (64B, 3, 49, 32)
- v: (64B, 3, 49, 32)
接下来就是进行attention操作,熟悉transformer的同学肯定很容易理解
A t t e n t i o n ( Q , K , V ) = S o f t M a x ( Q K T d + B ) V Attention(Q,K,V) = SoftMax(\frac{QK^{T}}{\sqrt{d}}+B)V Attention(Q,K,V)=SoftMax(dQKT+B)V
注意这里加了一个偏置B,在最后会详细介绍相对位置偏置(Relative Position Bias)的原理
window reverse
所有attention步骤执行完之后就可以回到attention之前的维度(64B, 7, 7, 96),然后我们经过一个window reverse操作就可以回到window partition之前的状态了,即(B, 56, 56, 96)。window reverse就是window partition的逆过程
总结:这里总结一下W-MSA所做的事情,首先进行window partition操作,维度从(B, 3136, 96)也就是(B, 56, 56, 96)变为(64B, 7, 7, 96);随后进行attention操作,先经过一个线性层维度变为三倍来为qkv分别赋值(64B, 49, 96*3): qkv(64B, 49, 96),随后根据multi-head操作在将qkv分别分成三份,(64B, 3, 49, 32),最后进行attention操作(即上面的公式),然后通过window reverse回到最初的状态(B, 56, 56, 96),也就是(B, 3136, 96),下面图示了这一阶段的过程
MLP
输入为(4, 3136, 96)
输出为(4, 3136, 96)
再经过第二个Block之前要先经过一个MLP,其中结构为
Linear(96, 96*4)
——GELU()
——Linear(96*4, 96)
——Dropout
最终维度并不发生变化
(2)second block
包含两个主要模块,SW-MSA和MLP
输入为(4, 3136, 96)
输出为(4, 3136, 96)
与第一个Block唯一不同的地方就是SW-MSA模块,所以这里仅讲解此模块
SW-MSA
与W-MSA不同的地方在于这个模块存在滑动,所以叫做shifted window,滑动的距离为win_size//2
在这里也就是7//2=3
,这里用image(4, 4)
win(2, 2)
shift=1
来图示他的shift以及mask机制
这里先给出Github上有助于理解此机制的提问:链接
为什么要用mask机制呢,Swin Transformer与Vision Transformer相比虽然降低了计算量,但缺点是同一个window里面的patch可以交互,window与window之间无法交互,所以考虑滑动窗的方法,如上图所示,滑动过后为了保证图片的完整性,我们将上面和左边的图补齐到右边,这又带来了一个缺点:图片的右端和补齐的图片本身并不是相邻的,所以无法交互,解决办法就是mask
Swin Transformer的mask机制是说,如果相互交互的patch属于同一个区域(对应于上图的颜色),那么就可以正常交互,如果不是同一个区域(对应于上图的不同颜色),那么他们交互之后就需要加上一个很大的负值,这样通过softmax层之后本来不能交互的那个像素就变成0了,这就是mask机制
这里附上Github上讨论的一个源码,由此可以直接看到mask是如何运行的,这个代码与我上述的图是对应的
import torch
import torch.nn as nn
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
window_size = 2
shift_size = 1
H, W = 4, 4
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -window_size),
slice(-window_size, -shift_size),
slice(-shift_size, None))
w_slices = (slice(0, -window_size),
slice(-window_size, -shift_size),
slice(-shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
attn_mask = attn_mask.unsqueeze(1).unsqueeze(0)
print(attn_mask)
"""
tensor([[[[[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.]]],
[[[ 0., -100., 0., -100.],
[-100., 0., -100., 0.],
[ 0., -100., 0., -100.],
[-100., 0., -100., 0.]]],
[[[ 0., 0., -100., -100.],
[ 0., 0., -100., -100.],
[-100., -100., 0., 0.],
[-100., -100., 0., 0.]]],
[[[ 0., -100., -100., -100.],
[-100., 0., -100., -100.],
[-100., -100., 0., -100.],
[-100., -100., -100., 0.]]]]])
"""
Patch Merging
在每个Stage结束的阶段都有一个Patch Merging
的过程,这个过程会让输入进行降维,同时通道变为原来的二倍,用一个图来清晰的展示此过程,图示如下
上面说到过Swin的作用是使得patch交互的区域变大,另一种使其变大的方法就是这里提到的Patch Merging,在每个阶段结束之后,将特征图的维度减半,channel加倍,在保持patch和window不变的情况下相当于变相提高了patch和window的感受野,使其效果更好
到这里Swin Transformer的一个stage就已经讲完了,其余的Stage和上面讲述的完全一致,为了再次强化Swin Transformer的整个流程,下面是整个流程展示,其中加粗部分为我们已经走过的流程(这里依然是Swin-Tiny版本)
input-->patch partition-->linear embedding
stage1 W-MSA-->MLP-->SW-MSA-->MLP
stage2 W-MSA-->MLP-->SW-MSA-->MLP
stage3 W-MSA-->MLP-->SW-MSA-->MLP
*3
stage4 W-MSA-->MLP-->SW-MSA-->MLP-->tail process
三、Supplement
Relative Position Bias
到这里整个Swin Transformer
就已经讲完了,还记得attention中加了一个bias B
吗,这里对其进行讲解,依旧取win=2
,如下所示
这里的相对位置偏置这样理解,在窗口中任意选定一个坐标,遵循左+右-上+下-
的原则,可以发现当我们将左上角的值为(0, 0)
时,他右边的位置为(0, -1)
减了1,下面的位置为(-1, 0)
也减了1,同理将其他位置设为(0, 0)
时,结果分别如图所示
然后我们将其展开,执行:行列分别加M-1=2-1=1
,行标乘2M-1=3
,最终可以得到下图,然后需要注意的是最大值为8,也就是说一共有9个索引,为什么有四个像素,按理来说为4*4=16
个位置,只有9个索引呢?这是因为是相对位置编码位置有重复,又因为win=2
,所以行和列的索引均为[-1, 1]
,一共3*3=9
种组合,即九个相对位置索引,因此相对位置索引表一共有9个数字,如下图所示
其中上面是索引表(9个数),下面是索引后的结果
为了更清晰的认识相对位置偏置,这里给出一个简单的example
# relative_position_bias_table (1, 9)
relative_position_bias_table = torch.tensor([10, 20, 30, 40, 50, 60, 70, 80, 90])
# relative_position_index (4, 4)
window_size = [2, 2]
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
# index (4, 4)
table = relative_position_bias_table[relative_position_index.view(-1)].view(window_size[0]*window_size[1], window_size[0]*window_size[1], -1)
table = table.permute(2, 0, 1).contiguous().unsqueeze(0)
print("relative_position_index\n", relative_position_index)
print(table)
'''
relative_position_index
tensor([[4, 3, 1, 0],
[5, 4, 2, 1],
[7, 6, 4, 3],
[8, 7, 5, 4]])
tensor([[[[50, 40, 20, 10],
[60, 50, 30, 20],
[80, 70, 50, 40],
[90, 80, 60, 50]]]])
'''
到这里Swin Transformer
就讲完啦,但是因为写的比较仓促有一些地方讲的不够细致,还有关于FLOPs运算的细节没有讲到,后面有时间会再补充~
更多推荐
所有评论(0)