LSTM 简介
LSTM长短期记忆网络 LSTM(long short-term memory)是 RNN 的一种变体,其核心概念在于细胞状态以及“门”结构。细胞状态相当于信息传输的路径,让信息能在序列连中传递下去。你可以将其看作网络的“记忆”。理论上讲,细胞状态能够将序列处理过程中的相关信息一直传递下去。因此,即使是较早时间步长的信息也能携带到较后时间步长的细胞中来,这克服了短时记忆的影响。信息的添加和移除我
LSTM
长短期记忆网络 LSTM(long short-term memory)是 RNN 的一种变体,其核心概念在于细胞状态以及“门”结构。细胞状态相当于信息传输的路径,让信息能在序列连中传递下去。你可以将其看作网络的“记忆”。理论上讲,细胞状态能够将序列处理过程中的相关信息一直传递下去。因此,即使是较早时间步长的信息也能携带到较后时间步长的细胞中来,这克服了短时记忆的影响。信息的添加和移除我们通过“门”结构来实现,“门”结构在训练过程中会去学习该保存或遗忘哪些信息。
- 遗忘门:决定应丢弃或保留哪些信息。来自前一个隐藏状态的信息和当前输入的信息同时传递到 sigmoid 函数中去,输出值介于 0 和 1 之间,越接近 0 意味着越应该丢弃,越接近 1 意味着越应该保留
f
t
=
σ
(
W
f
⋅
[
h
t
−
1
,
x
t
]
+
b
f
)
f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)
ft=σ(Wf⋅[ht−1,xt]+bf)
- 输入门:输入门用于更新细胞状态。首先将前一层隐藏状态的信息和当前输入的信息传递到 sigmoid 函数中去。将值调整到 0~1 之间来决定要更新哪些信息。0 表示不重要,1 表示重要。其次还要将前一层隐藏状态的信息和当前输入的信息传递到 tanh 函数中去,创造一个新的侯选值向量。最后将 sigmoid 的输出值与 tanh 的输出值相乘,sigmoid 的输出值将决定 tanh 的输出值中哪些信息是重要且需要保留下来的
i
t
=
σ
(
W
i
⋅
[
h
t
−
1
,
x
t
]
+
b
i
)
C
t
~
=
tanh
(
W
C
⋅
[
h
t
−
1
,
x
t
]
+
b
C
)
i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \\ \tilde{C_t} = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)
it=σ(Wi⋅[ht−1,xt]+bi)Ct~=tanh(WC⋅[ht−1,xt]+bC)
- 细胞状态:前一层的细胞状态与遗忘向量逐点相乘。如果它乘以接近 0 的值,意味着在新的细胞状态中,这些信息是需要丢弃掉的。然后再将该值与输入门的输出值逐点相加,将神经网络发现的新信息更新到细胞状态中去。至此,就得到了更新后的细胞状态
C
t
=
f
t
⨀
C
t
−
1
+
i
t
⨀
C
t
~
注
:
⨀
为
H
a
d
a
m
a
r
d
p
r
o
d
u
c
t
,
即
对
应
点
相
乘
C_t = f_t \bigodot C_{t-1} + i_t \bigodot \tilde{C_t} \\ 注:\bigodot 为 \ Hadamard \ product,即对应点相乘
Ct=ft⨀Ct−1+it⨀Ct~注:⨀为 Hadamard product,即对应点相乘
- 输出门:输出门用来确定下一个隐藏状态的值,隐藏状态包含了先前输入的信息。首先,我们将前一个隐藏状态和当前输入传递到 sigmoid 函数中,然后将新得到的细胞状态传递给 tanh 函数。最后将 tanh 的输出与 sigmoid 的输出相乘,以确定隐藏状态应携带的信息。再将隐藏状态作为当前细胞的输出,把新的细胞状态和新的隐藏状态传递到下一个时间步长中去
o
t
=
σ
(
W
o
⋅
[
h
t
−
1
,
x
t
]
+
b
o
)
h
t
=
o
t
⨀
tanh
(
C
t
)
o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \\ h_t = o_t \bigodot \tanh(C_t)
ot=σ(Wo⋅[ht−1,xt]+bo)ht=ot⨀tanh(Ct)
LSTM 变体
Peephole LSTM
f t = σ ( W f ⋅ [ C t − 1 , h t − 1 , x t ] + b f ) i t = σ ( W i ⋅ [ C t − 1 , h t − 1 , x t ] + b i ) C t ~ = tanh ( W C ⋅ [ h t − 1 , x t ] + b C ) C t = f t ⨀ C t − 1 + ( 1 − f t ) ⨀ C t ~ o t = σ ( W o ⋅ [ C t , h t − 1 , x t ] + b o ) h t = o t ⨀ tanh ( C t ) f_t = \sigma(W_f \cdot [C_{t-1}, h_{t-1}, x_t] + b_f) \\ i_t = \sigma(W_i \cdot [C_{t-1}, h_{t-1}, x_t] + b_i) \\ \tilde{C_t} = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) \\ C_t = f_t \bigodot C_{t-1} + (1 - f_t) \bigodot \tilde{C_t} \\ o_t = \sigma(W_o \cdot [C_t, h_{t-1}, x_t] + b_o) \\ h_t = o_t \bigodot \tanh(C_t) ft=σ(Wf⋅[Ct−1,ht−1,xt]+bf)it=σ(Wi⋅[Ct−1,ht−1,xt]+bi)Ct~=tanh(WC⋅[ht−1,xt]+bC)Ct=ft⨀Ct−1+(1−ft)⨀Ct~ot=σ(Wo⋅[Ct,ht−1,xt]+bo)ht=ot⨀tanh(Ct)
GRU
z
t
=
σ
(
W
z
⋅
[
h
t
−
1
,
x
t
]
+
b
z
)
r
t
=
σ
(
W
r
⋅
[
h
t
−
1
,
x
t
]
+
b
r
)
h
t
~
=
tanh
(
W
h
⋅
[
r
t
⨀
h
t
−
1
,
x
t
]
,
b
h
)
h
t
=
(
1
−
z
t
)
⨀
h
t
−
1
+
z
t
⨀
h
t
~
z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) \\ r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) \\ \tilde{h_t} = \tanh(W_h \cdot [r_t \bigodot h_{t-1}, x_t], b_h) \\ h_t = (1 - z_t) \bigodot h_{t-1} + z_t \bigodot \tilde{h_t}
zt=σ(Wz⋅[ht−1,xt]+bz)rt=σ(Wr⋅[ht−1,xt]+br)ht~=tanh(Wh⋅[rt⨀ht−1,xt],bh)ht=(1−zt)⨀ht−1+zt⨀ht~
注:一般这里可以不用考虑偏置,原论文中也没有偏置
LSTM 简单例子
import torch
import torch.nn as nn
# 输入数据 x 的向量维数 10, 设定 LSTM 隐藏层的特征维度 20, 此 model 用 2 个 LSTM 层
rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10) # input(seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20) # h_0(num_layers * num_directions, batch, hidden_size)
c0 = torch.randn(2, 3, 20) # c_0(num_layers * num_directions, batch, hidden_size)
# output(seq_len, batch, hidden_size * num_directions)
# h_n(num_layers * num_directions, batch, hidden_size)
# c_n(num_layers * num_directions, batch, hidden_size)
output, (hn, cn) = rnn(input, (h0, c0))
# torch.Size([5, 3, 20]) torch.Size([2, 3, 20]) torch.Size([2, 3, 20])
print(output.size(), hn.size(), cn.size())
补充:RNN, LSTM & GRU、pytorch中lstm参数与案例理解、LSTM这一篇就够了、从RNN到LSTM再到GRU、LSTM论文翻译-《Understanding LSTM Networks》、Convolutional LSTM Network
更多推荐
所有评论(0)