LSTM是RNN的改进版本,主要的改进就是引入了长期记忆模块,对LSTM来说,每一个时刻,它需要考虑三件事,第一是长期记忆中应该忘记什么,第二是长期记忆中应该增加什么新的内容,第三是结合长期信息、上一时刻的信息和当前时刻的输入计算输出。
首先要知道,对模型来说,并没有所谓的记忆,一切都只是数字,模型只是为了让数字变得有意义。现在我们有一段很长的数字,比如是一个1*10000的矩阵,那么我们如何决定在这个矩阵中,什么信息才应该保留(数值保持不变),什么信息应该丢弃(变为0)。
LSTM的做法是计算一个同样尺寸的1*10000的矩阵,数值只为[0,1],按位相乘决定什么信息需要被遗忘,以及遗忘多少:
$$f_t = \sigma (W_f [h_{t-1}, x_t]+b_f)$$
上述公式的含义,就是分别对当前时刻和上一时刻的输入做线性变换,并相加得到一个1*10000的矩阵,并通过sigmoid函数把结果缩放到[0,1],最后把这个结果和长期记忆矩阵按位相乘,就遗忘了相对应的信息了。
接下来我们可以用类似的方式,计算当前的信息中有什么值得被长期记忆。首先也是用类似的方法计算信息保留程度:
$$i_t = \sigma (W_i [h_{t-1}, x_t]+b_i)$$
式子是一样的,只是参数不同,然后我们基于上一时刻和当前时刻的输入计算输出,这个输出可以理解成在当前时刻,我们根据过去和现在的信息总结提取的信息:
$$\tilde C_t = tanh(W_c [h_{t-1}, x_t] + b_c)$$
综合上述计算,就可以得到新的长期记忆应该包含什么内容:
$$C_t = f_t * C_{t-1} + i_t * \tilde C_t$$
最后,我们就可以基于新的长期记忆的内容,分析该时刻模型的输出是什么了,但是,考虑到并不是长期记忆中所有内容都和本次输出有关,所以我们会再次用老方法分析一次,到底长期记忆输出的结果中,有多少是相关的:
$$o_t = \sigma (W_o [h_{t-1}, x_t]+b_o)$$
$$h_t = o_t * tanh(C_t)$$
可以看到,LSTM最重要的部分就是通过上一时刻的输出以及当前时刻的输出,结合sigmoid函数分析应该记住什么,应该忘记什么,从而实现长期记忆,这个结构也被称为遗忘门。