Scaled dot-product Attention vs. Self-Attention
1. Scaled dot-product Attention

There are two sequences $X$ and $Y$: sequence $X$ provides the query information $Q$, while sequence $Y$ provides the key and value information $K$ and $V$:

$Q\in R^{x\_len\times in\_dim}$
$K\in R^{y\_len\times in\_dim}$
$V\in R^{y\_len\times out\_dim}$
The scaled dot-product attention is computed as

$softmax(\frac{QK^T}{\sqrt{in\_dim}})V$

which has shape $x\_len\times out\_dim$: each of the $x\_len$ queries receives a weighted combination of the $y\_len$ value rows.
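As a quick sanity check of the formula, it can be evaluated directly with a few tensor operations. The sketch below is not part of the original post; the sizes (x_len=3, y_len=8, in_dim=2, out_dim=10) are chosen only for illustration and match the demo in section 3.

import math
import torch

x_len, y_len, in_dim, out_dim = 3, 8, 2, 10   # illustrative sizes
Q = torch.randn(x_len, in_dim)                # queries from sequence X
K = torch.randn(y_len, in_dim)                # keys from sequence Y
V = torch.randn(y_len, out_dim)               # values from sequence Y

scores = Q @ K.T / math.sqrt(in_dim)          # (x_len, y_len) similarity scores
weights = torch.softmax(scores, dim=-1)       # each row sums to 1
output = weights @ V                          # (x_len, out_dim)
print(output.shape)                           # torch.Size([3, 10])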
2. Self-Attention

Sequence $X$ performs attention with itself: $X$ simultaneously provides the query information $Q$ and the key/value information $K$, $V$. In this case $x\_len=y\_len$ and $in\_dim=out\_dim$, so $Q$, $K$ and $V$ all have the same shape:

$Q\in R^{x\_len\times in\_dim}$
$K\in R^{x\_len\times in\_dim}$
$V\in R^{x\_len\times in\_dim}$
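Concretely, self-attention is the same computation with the query, key, and value roles all filled by $X$. A minimal sketch (again illustrative, using the same sizes as above):

import math
import torch

x_len, in_dim = 3, 2                      # illustrative sizes
X = torch.randn(x_len, in_dim)            # a single sequence plays all three roles
scores = X @ X.T / math.sqrt(in_dim)      # (x_len, x_len): every position attends to every position
weights = torch.softmax(scores, dim=-1)
output = weights @ X                      # same shape as X: (x_len, in_dim)
print(output.shape)                       # torch.Size([3, 2])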
3. PyTorch implementation
import math
import torch
from torch import nn

def sequence_mask(X, valid_len, value=0):
    """Mask irrelevant entries in sequences."""
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float32,
                        device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X

def masked_softmax(X, valid_lens):
    """Perform the softmax operation by masking elements on the last axis."""
    # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value so that their softmax (exponential) output becomes 0
        X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)
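For intuition, here is a small illustrative call to masked_softmax (not from the original post): with valid_lens = [2, 3], the first batch item may only attend to its first 2 keys and the second to its first 3, so the remaining weights come out as (effectively) zero.

# Illustrative usage of masked_softmax on a batch of random scores.
scores = torch.rand(2, 2, 4)                         # (batch_size=2, queries=2, keys=4)
print(masked_softmax(scores, torch.tensor([2, 3])))  # columns beyond valid_lens are zeroed out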
class DotProductAttention(nn.Module):
    """Scaled dot-product attention."""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # Shape of `queries`: (`batch_size`, no. of queries, `d`)
    # Shape of `keys`: (`batch_size`, no. of key-value pairs, `d`)
    # Shape of `values`: (`batch_size`, no. of key-value pairs, value dimension)
    # Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Transpose the last two dimensions of `keys` for the batched matrix product
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
# Demo: attention between X (length 3) and Y (length 8) vs. self-attention on X
attention = DotProductAttention(dropout=0.5)
batch_size = 2
x_len, y_len = 3, 8
in_dim, out_dim = 2, 10
Q = torch.ones((batch_size, x_len, in_dim))
K = torch.ones((batch_size, y_len, in_dim))
V = torch.ones((batch_size, y_len, out_dim))
soft_attention_ans = attention(Q, K, V)   # Q from X; K, V from Y
self_attention_ans = attention(Q, Q, Q)   # Q, K, V all come from X
print(soft_attention_ans.shape)
print(self_attention_ans.shape)
Output:
torch.Size([2, 3, 10])
torch.Size([2, 3, 2])
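The demo passes valid_lens=None, so no masking takes place. The snippet below (an illustrative addition, not in the original post) shows how a valid_lens tensor limits each query to the first few key-value pairs of Y.

# Illustrative: restrict batch 0 to its first 2 key-value pairs and batch 1 to its first 5.
valid_lens = torch.tensor([2, 5])
masked_ans = attention(Q, K, V, valid_lens)
print(masked_ans.shape)                   # torch.Size([2, 3, 10])
print(attention.attention_weights[0, 0])  # only the first 2 of the 8 weights are non-zero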