1. Scaled Dot-Product Attention

Given two sequences $X$ and $Y$: sequence $X$ provides the query information $Q$, and sequence $Y$ provides the key and value information $K$ and $V$, where $Q\in\mathbb{R}^{x\_len\times in\_dim}$, $K\in\mathbb{R}^{y\_len\times in\_dim}$, $V\in\mathbb{R}^{y\_len\times out\_dim}$.
Scaled dot-product attention is computed as: $\mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{in\_dim}}\right)V$
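
As a quick sanity check on these shapes, the formula can be evaluated directly with plain tensor operations (a minimal sketch; the sizes below are illustrative):

import math
import torch

x_len, y_len, in_dim, out_dim = 3, 8, 2, 10
Q = torch.randn(x_len, in_dim)
K = torch.randn(y_len, in_dim)
V = torch.randn(y_len, out_dim)

# softmax(Q K^T / sqrt(in_dim)) V
weights = torch.softmax(Q @ K.T / math.sqrt(in_dim), dim=-1)  # (x_len, y_len)
out = weights @ V                                             # (x_len, out_dim)
print(out.shape)  # torch.Size([3, 10])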

2. Self-Attention

Sequence $X$ performs attention with itself: $X$ provides the query information $Q$ as well as the key and value information $K$ and $V$. In this case $x\_len = y\_len$ and $in\_dim = out\_dim$, so $Q$, $K$, and $V$ all have the same shape: $Q\in\mathbb{R}^{x\_len\times in\_dim}$, $K\in\mathbb{R}^{x\_len\times in\_dim}$, $V\in\mathbb{R}^{x\_len\times in\_dim}$.
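
A minimal sketch of this case (illustrative sizes): with $Q = K = V = X$, the attention-weight matrix is square, of size $x\_len \times x\_len$, and each output row is a weighted average of the rows of $X$ itself.

import math
import torch

x_len, in_dim = 3, 2
X = torch.randn(x_len, in_dim)

# Q = K = V = X: every position of X attends to every position of X
weights = torch.softmax(X @ X.T / math.sqrt(in_dim), dim=-1)  # (x_len, x_len)
out = weights @ X                                             # (x_len, in_dim)
print(weights.shape, out.shape)  # torch.Size([3, 3]) torch.Size([3, 2])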

3. PyTorch Implementation

import math
import torch
from torch import nn

def sequence_mask(X, valid_len, value=0):
    """Mask irrelevant entries in sequences."""
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float32,
                        device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X
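
# Illustrative check of sequence_mask with made-up values: positions at or
# beyond each valid length are overwritten with `value` (0 by default).
print(sequence_mask(torch.tensor([[1., 2., 3.], [4., 5., 6.]]),
                    torch.tensor([1, 2])))
# tensor([[1., 0., 0.],
#         [4., 5., 0.]])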

def masked_softmax(X, valid_lens):
    """Perform a softmax operation by masking elements on the last axis."""
    # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # Replace masked elements on the last axis with a very large negative
        # value, so that their softmax (exponential) outputs become 0
        X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)
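
# Illustrative check of masked_softmax with random inputs: with
# valid_lens = [2, 3], rows of the first example only put probability mass on
# their first 2 entries and rows of the second on their first 3; the masked
# positions come out as 0.
print(masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3])))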

class DotProductAttention(nn.Module):
    """Scaled dot product attention."""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # Shape of `queries`: (`batch_size`, no. of queries, `d`)
    # Shape of `keys`: (`batch_size`, no. of key-value pairs, `d`)
    # Shape of `values`: (`batch_size`, no. of key-value pairs, value dimension)
    # Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Swap the last two dimensions of `keys` to compute Q K^T
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

attention = DotProductAttention(dropout=0.5)
attention.eval()  # evaluation mode: dropout is disabled for the demo below

batch_size = 2
x_len, y_len = 3, 8
in_dim, out_dim = 2, 10

Q = torch.ones((batch_size, x_len, in_dim))
K = torch.ones((batch_size, y_len, in_dim))
V = torch.ones((batch_size, y_len, out_dim))

soft_attention_ans = attention(Q, K, V)
self_attention_ans = attention(Q, Q, Q)
print(soft_attention_ans.shape)  # torch.Size([2, 3, 10])
print(self_attention_ans.shape)  # torch.Size([2, 3, 2])
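
The calls above do not pass `valid_lens`, so nothing is masked. A short sketch of the masked case (the valid lengths below are made-up values for illustration): for batch `i`, every query only attends to the first `valid_lens[i]` key-value pairs.

valid_lens = torch.tensor([2, 6])  # illustrative: batch 0 sees 2 pairs, batch 1 sees 6

masked_ans = attention(Q, K, V, valid_lens)
print(masked_ans.shape)                    # torch.Size([2, 3, 10])
print(attention.attention_weights[0, 0])   # weights beyond position 2 are 0

If you are on PyTorch 2.0 or newer, the unmasked result can also be cross-checked against the built-in torch.nn.functional.scaled_dot_product_attention (with dropout_p=0.0).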

4. References

[1] https://zh-v2.d2l.ai/chapter_preliminaries/index.html
