LSTM

RNN的主要问题是在BPTT过程中会出现梯度消失或者爆炸,事实上,一般来说我们会更关注梯度消失的问题,其实梯度消失不仅仅是计算机精度造成的,我们也可以从这个角度来理解,从BPTT的推导我们知道,距离比较远的序列数据在当前时刻的BPTT计算时对应的参数指数较大,当参数小于1,就会导致比较旧的数据对于梯度更新的贡献十分少,这就相当于"忘记"了这些数据。所以,虽然我们也可以通过截断的BPTT避免梯度爆炸或者梯度消失,但是无法长期记忆这个问题没有得到解决,而长短期记忆网络(LSTM)正是针对这个问题而提出的。

首先以下是一个经典的LSTM结构图:
ポジショニングマップ
ポジショニングマップ
我们一步步从零开始推导出这个结构。LSTM最初的想法就是,既然原来的RNN没有办法长期记忆,那么就想办法在模型中加一个模块用于长期记忆:
ポジショニングマップ
为什么这个模块可以进行长期记忆,主要是因为它没有神经元,也就是不包含模型的参数,不会涉及到模型的反向传播,因而避免了之前提到的梯度消失等问题。这个模块主要就是一些数据(向量或者矩阵),或者形象一点说,就是模型的记忆内容、模型的信息。

在长期的记忆输入到模型之后,模型首先要做的就是分析有没有什么需要忘记,又有没有什么新的信息需要记住,首先,模型根据当前时刻的输入数据以及上一时刻的隐层输出,判断在长期记忆中是否有需要遗忘的信息:
ポジショニングマップ
当前数据以及上一时刻的隐层输出被输入到sigmoid函数,得到一个[0,1]之间的结果,最后再与长期记忆中的信息进行按位的乘法操作。这里的sigmoid函数就是整个模型的重点了,我们怎么理解它和记忆或遗忘之间的联系呢。首先,sigmoid的输出为0,就意味着我们需要完全舍弃之前的信息,1是代表着我们要完全保留之前的信息。所以,一个经过训练的模型,可以根据当前的数据输入以及历史数据,输入到sigmoid之后得到一个合理的数值,然后再与长期记忆进行运算,决定应该丢弃什么信息。

决定了长期记忆中应该遗忘什么信息之后,第二步就是分析当前时刻有没有值得加入的新信息,这个信息就是我们原来的RNN的输出:
ポジショニングマップ
ポジショニングマップ
注意,这里除了正常地以一个正常的RNN结构计算输出之外,同时再次利用一个sigmoid函数,分析当前的RNN输出,有多少部分是值得被长期记忆的,也就是说,并不是所有的信息都值得被长期记忆,这很重要,也很合理。得到了需要被长期记忆的重要信息之后,再通过按位加法添加到长期记忆中。

此时,在长期记忆中,就包含了所有在当前时刻有意义的历史信息,以及新追加的有意义的信息,这时候,我们就可以把这部分数据输入到激活函数tanh中,计算输出:
ポジショニングマップ
注意,在图中的结构中,当我们计算了激活函数的输出后,还需要与sigmoid的输出做一次按位乘法,也就是说,我们最终的输出还需要进行一次筛选,或许是认为不是所有的输出信息都和本次训练相关。

事实上,LSTM还有很多种不同的变体,但是我觉得LSTM的核心有两点,一个是整个模型包含长期记忆和短期记忆两条分支,第二点是使用sigmoid函数作为遗忘门,判断什么信息需要保留,什么信息需要舍弃,这才是基于原来的RNN实现的重大改进,而具体不同的结构,更多是因为不同的分析思路,实际上,人们经过对比发现,不同的变体在性能上不会出现十分明显的区别,只是针对某些具体的问题会有更优的结构。