文章目录

多头注意力

给定一个Query(查询)和一系列的Key-Value对一起映射出一个输出。包括下面三个关键性步骤:

  • 将Query与Key进行相似性度量
  • 将求得的相似性度量进行缩放标准化
  • 将权重与value进行加权

在实践中,当给定相同的查询、键和值的集合时, 我们希望模型可以基于相同的注意力机制学习到不同的行为, 然后将不同的行为作为知识组合起来, 捕获序列内各种范围的依赖关系 (例如,短距离依赖和长距离依赖关系)。 因此,允许注意力机制组合使用查询、键和值的不同 子空间表示(representation subspaces)可能是有益的。

为此,与其只使用单独一个注意力汇聚, 我们可以用独立学习得到的h组不同的线性投影(linear projections)来变换查询、键和值。 然后,这h组变换后的查询、键和值将并行地送到注意力汇聚中。 最后,将这h个注意力汇聚的输出拼接在一起, 并且通过另一个可以学习的线性投影进行变换, 以产生最终输出。 这种设计被称为多头注意力(multihead attention)。对于h个注意力汇聚输出,每一个注意力汇聚都被称作为一个头(head)。下图展示了使用全连接层来实现可学习的线性变换的多头注意力。

在这里插入图片描述

在这里插入图片描述
上图为多头注意力:多个头连接然后线性变换

多头注意力机制则是单头注意力机制的进化版,把每次attention运算分组(头)进行,能够从多个维度提炼特征信息。具体原理可以参看相关的科普文章,下面是Pytorch实现。

import torch.nn as nn
class MHSA(nn.Module):
    def __init__(self, num_heads, dim):
        super().__init__()
        # Q, K, V 转换矩阵,这里假设输入和输出的特征维度相同
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.num_heads = num_heads
	
    def forward(self, x):
        B, N, C = x.shape
        # 生成转换矩阵并分多头
        q = self.q(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        k = self.k(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        v = self.k(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        
        # 点积得到attention score
        attn = q @ k.transpose(2, 3) * (x.shape[-1] ** -0.5)
        attn = attn.softmax(dim=-1)
        
        # 乘上attention score并输出
        v = (attn @ v).permute(0, 2, 1, 3).reshape(B, N, C)
        return v

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐