论文地址:https://arxiv.org/pdf/2103.14030.pdf
代码地址:https://github.com/microsoft/Swin-Transformer

本文一共分为三个部分,首先介绍Swin Transformer的整体架构,随后会介绍每个模块的作用,中间会穿插部分代码。本文的主要目的还是希望能够将Swin Transformer解释清楚,然后结合官方代码来理解

一、Overall Architecture

首先给出论文中的Swin Transformer架构图

左边是Swin Transformer的全局架构,它包含Patch Partition、Linear Embedding、Swin Transformer Block、Patch Merging四大部分,这四大部分我们之后会进行详细的介绍

右边是Swin Transformer Block结构图,这是两个连续的Swin Transformer Block块,一个是W-MSA,一个是SW-MSA,也就是说根据Swin的Tiny版本,图中的Swin Transformer Block块为[2, 2, 6, 2],相对应的attention为:stage1 W-MSA-->SW-MSAstage2 W-MSA-->SW-MSAstage3 W-MSA-->SW-MSA-->W-MSA-->SW-MSA-->W-MSA-->SW-MSAstage4 W-MSA-->SW-MSA

二、Swin Transformer

下面的维度等均是基于Swin-T版本

1. Patch Partition & Linear Embedding

输入为(B, 3, 224, 224)
输出为(B, 96, 56, 56) —> (B, 96, 224/4=56, 224/4=56)

这两步在论文中其实就是一步实现,我们先来看paper中的解释:

  • Patch Partition,这一步是将输入的(H, W, 3)的图片分成(4, 4)的小块,分块后的图片大小为(H/4, W/4, 48)也就是文中所给的维度
  • Linear Embedding,在Tiny版本中,将分块后的图像映射到96维

在真正实现的时候paper使用了PatchEmbed函数将这两步结合起来,实际上也就是用了一个卷积的操作,卷积核大小为(4, 4),步长为4:nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

下面图示此过程

2. Basic Layer

在官方的代码库中,将Swin Transformer Block和Patch Merging合并成了一个,叫做Basic Layer,下面我们分别介绍这两者

Swin Transformer Block

输入为(B, 3136, 96)
输出为(B, 3136, 96)
就是把上一步的(4, 96, 56, 56)后两维度合并变为(4, 96, 3136),然后后两维互换(4, 3136, 96)

Swin Transformer Block的输入输出不变,每两个连续Block为一组,即一个Window Multi-head Self-Attention和一个Shifted Window Multi-head Self-Attention

下面是paper中的Swin Transformer Block示例图

从图中我们可以看出每两个连续Block块有四小步:

1. 第一个Block

  • 首先特征图经过Layer Norm层,经过W-MSA,然后进行跳跃连接
  • 连接后的特征图再次经过Layer Norm层,经过全连接层MLP,然后进行跳跃连接

2. 第二个Block

  • 首先特征图经过Layer Norm层,经过SW-MSA,然后进行跳跃连接
  • 连接后的特征图再次经过Layer Norm层,经过全连接层MLP,然后进行跳跃连接

从上面四步可以看出Swin Transformer Block清晰的执行步骤,其中比较难理解的是W-MSASW-MSA,下面我们详细介绍二者,并介绍由二者引出的一些细节

(1)first block

包含两个主要模块,W-MSA和MLP
输入为(B, 3136, 96)
输出为(B, 3136, 96)

W-MSA

window partition

W-MSA在第一个block中,这一步没有滑动窗,输入为(B, 3136, 96),为了后面的sefl-attention操作,需要将特征图划分为一个个窗口的形式,首先经历了一个window partition操作,变为(64B, 7, 7, 96)

怎么计算的呢?输入为batch=B,3136=56*56,特征图有96个,将每个特征图56*56分为7*7的窗口,一共能分8*8=64个,乘上之前的B就是64B了,就是说将特征图分为(7, 7)的小窗,然后把所有的小窗拿出来一共有64B个,示例图如下

为什么要进行window partition?在Vision Transformer中,我们将图片分成了一个个patch(也就是左边的图),在进行MSA时,任何一个patch都要与其他所有的patch都进行attention,当patch的大小固定时,计算量与图片的大小成平方增长。Swin Transformer中采用了W-MSA,也就是window的形式,不同的window包含了相同数量的patch,只对window内部进行MSA,当图片大小增大时,计算量仅仅是呈线性增加(只增加了图片多余部分的计算量,比如之前是224的图像,现在是256的图像,只多了256-224=32像素的计算部分),下面详细介绍window attention部分

window attention

将窗口分配完成后就可以执行attention操作了,首先我们将维度变为(64B, 49, 96),进行attention操作时,我们需要qkv三个变量,transformer是通过linear函数来实现的:nn.Linear(dim, dim * 3, bias=qkv_bias),通过这个函数后,维度变为(64B, 49, 288),qkv分别占三分之一,也就是说qkv分别为(64B, 49, 96),第一个阶段的head为3,维度划分为(64B, 3, 49, 32)

此时qkv的值如下所示,这就是进行attention时qkv的维度

  • q: (64B, 3, 49, 32)
  • k: (64B, 3, 49, 32)
  • v: (64B, 3, 49, 32)

接下来就是进行attention操作,熟悉transformer的同学肯定很容易理解

A t t e n t i o n ( Q , K , V ) = S o f t M a x ( Q K T d + B ) V Attention(Q,K,V) = SoftMax(\frac{QK^{T}}{\sqrt{d}}+B)V Attention(Q,K,V)=SoftMax(d QKT+B)V

注意这里加了一个偏置B,在最后会详细介绍相对位置偏置(Relative Position Bias)的原理

window reverse

所有attention步骤执行完之后就可以回到attention之前的维度(64B, 7, 7, 96),然后我们经过一个window reverse操作就可以回到window partition之前的状态了,即(B, 56, 56, 96)。window reverse就是window partition的逆过程

总结:这里总结一下W-MSA所做的事情,首先进行window partition操作,维度从(B, 3136, 96)也就是(B, 56, 56, 96)变为(64B, 7, 7, 96);随后进行attention操作,先经过一个线性层维度变为三倍来为qkv分别赋值(64B, 49, 96*3): qkv(64B, 49, 96),随后根据multi-head操作在将qkv分别分成三份,(64B, 3, 49, 32),最后进行attention操作(即上面的公式),然后通过window reverse回到最初的状态(B, 56, 56, 96),也就是(B, 3136, 96),下面图示了这一阶段的过程

MLP

输入为(4, 3136, 96)
输出为(4, 3136, 96)

再经过第二个Block之前要先经过一个MLP,其中结构为

Linear(96, 96*4)——GELU()——Linear(96*4, 96)——Dropout

最终维度并不发生变化

(2)second block

包含两个主要模块,SW-MSA和MLP
输入为(4, 3136, 96)
输出为(4, 3136, 96)

与第一个Block唯一不同的地方就是SW-MSA模块,所以这里仅讲解此模块

SW-MSA

与W-MSA不同的地方在于这个模块存在滑动,所以叫做shifted window,滑动的距离为win_size//2在这里也就是7//2=3,这里用image(4, 4) win(2, 2) shift=1来图示他的shift以及mask机制

这里先给出Github上有助于理解此机制的提问:链接

为什么要用mask机制呢,Swin Transformer与Vision Transformer相比虽然降低了计算量,但缺点是同一个window里面的patch可以交互,window与window之间无法交互,所以考虑滑动窗的方法,如上图所示,滑动过后为了保证图片的完整性,我们将上面和左边的图补齐到右边,这又带来了一个缺点:图片的右端和补齐的图片本身并不是相邻的,所以无法交互,解决办法就是mask

Swin Transformer的mask机制是说,如果相互交互的patch属于同一个区域(对应于上图的颜色),那么就可以正常交互,如果不是同一个区域(对应于上图的不同颜色),那么他们交互之后就需要加上一个很大的负值,这样通过softmax层之后本来不能交互的那个像素就变成0了,这就是mask机制

这里附上Github上讨论的一个源码,由此可以直接看到mask是如何运行的,这个代码与我上述的图是对应的

import torch
import torch.nn as nn


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


window_size = 2
shift_size = 1
H, W = 4, 4
img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))

cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
attn_mask = attn_mask.unsqueeze(1).unsqueeze(0)
print(attn_mask)

"""
tensor([[[[[   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.]]],


         [[[   0., -100.,    0., -100.],
           [-100.,    0., -100.,    0.],
           [   0., -100.,    0., -100.],
           [-100.,    0., -100.,    0.]]],


         [[[   0.,    0., -100., -100.],
           [   0.,    0., -100., -100.],
           [-100., -100.,    0.,    0.],
           [-100., -100.,    0.,    0.]]],


         [[[   0., -100., -100., -100.],
           [-100.,    0., -100., -100.],
           [-100., -100.,    0., -100.],
           [-100., -100., -100.,    0.]]]]])
"""

Patch Merging

在每个Stage结束的阶段都有一个Patch Merging的过程,这个过程会让输入进行降维,同时通道变为原来的二倍,用一个图来清晰的展示此过程,图示如下

上面说到过Swin的作用是使得patch交互的区域变大,另一种使其变大的方法就是这里提到的Patch Merging,在每个阶段结束之后,将特征图的维度减半,channel加倍,在保持patch和window不变的情况下相当于变相提高了patch和window的感受野,使其效果更好

到这里Swin Transformer的一个stage就已经讲完了,其余的Stage和上面讲述的完全一致,为了再次强化Swin Transformer的整个流程,下面是整个流程展示,其中加粗部分为我们已经走过的流程(这里依然是Swin-Tiny版本)

input-->patch partition-->linear embedding
stage1 W-MSA-->MLP-->SW-MSA-->MLP
stage2 W-MSA-->MLP-->SW-MSA-->MLP
stage3 W-MSA-->MLP-->SW-MSA-->MLP *3
stage4 W-MSA-->MLP-->SW-MSA-->MLP-->tail process

三、Supplement

Relative Position Bias

到这里整个Swin Transformer就已经讲完了,还记得attention中加了一个bias B吗,这里对其进行讲解,依旧取win=2,如下所示

这里的相对位置偏置这样理解,在窗口中任意选定一个坐标,遵循左+右-上+下-的原则,可以发现当我们将左上角的值为(0, 0)时,他右边的位置为(0, -1)减了1,下面的位置为(-1, 0)也减了1,同理将其他位置设为(0, 0)时,结果分别如图所示

然后我们将其展开,执行:行列分别加M-1=2-1=1,行标乘2M-1=3,最终可以得到下图,然后需要注意的是最大值为8,也就是说一共有9个索引,为什么有四个像素,按理来说为4*4=16个位置,只有9个索引呢?这是因为是相对位置编码位置有重复,又因为win=2,所以行和列的索引均为[-1, 1],一共3*3=9种组合,即九个相对位置索引,因此相对位置索引表一共有9个数字,如下图所示

其中上面是索引表(9个数),下面是索引后的结果

为了更清晰的认识相对位置偏置,这里给出一个简单的example

# relative_position_bias_table (1, 9)
relative_position_bias_table = torch.tensor([10, 20, 30, 40, 50, 60, 70, 80, 90])

# relative_position_index (4, 4)
window_size = [2, 2]
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww

# index (4, 4)
table = relative_position_bias_table[relative_position_index.view(-1)].view(window_size[0]*window_size[1], window_size[0]*window_size[1], -1)
table = table.permute(2, 0, 1).contiguous().unsqueeze(0)

print("relative_position_index\n", relative_position_index)
print(table)

'''
relative_position_index
 tensor([[4, 3, 1, 0],
        [5, 4, 2, 1],
        [7, 6, 4, 3],
        [8, 7, 5, 4]])
tensor([[[[50, 40, 20, 10],
          [60, 50, 30, 20],
          [80, 70, 50, 40],
          [90, 80, 60, 50]]]])
'''

到这里Swin Transformer就讲完啦,但是因为写的比较仓促有一些地方讲的不够细致,还有关于FLOPs运算的细节没有讲到,后面有时间会再补充~

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐