RNN(Recurrent Neural Network)
RNN(Recurrent Neural Network)
多层反馈RNN(Recurrent neural Network、循环神经网络)神经网络是一种节点定向连接成环的人工神经网络。这种网络的内部状态可以展示动态时序行为。不同于前馈神经网络的是,RNN可以利用它内部的记忆来处理任意时序的输入序列,这让它可以更容易处理如不分段的手写识别、语音识别等。
如下图所示是循环神经网络的典型结构,其中包含一个从隐含层到隐含层的连接(x、h分别是输入和隐含层),为了更方便我们理解其的工作原理,可将其展开成右边的结构。
RNN常用于序列数据的标注或预测,例如预测一句话的下一个可能的单词。下图是RNN的典型结构,图中x、o分别是系统的输入和输出,h是隐含神经元,y是训练集中给出的真实值,例如o是预测的下一个单词,而y是真实出现的下一个单词,L是损失估计。U、V、W分别是权重,需要通过训练得到。
我们首先对RNN的正向传播(forward propagation)方程进行描述,其中隐含元的激活函数(activation function)也可以有其他选择。
(1)
假设损失函数取为负对数似然,则有:
(2)
为了对模型参数进行估计和优化,需要基于反向传播(backward propagation)算法针对损失函数从右至左计算相对各参数的梯度,因此算法的时间和空间复杂度是O(t),并称该算法为back-propagation through time/BPTT。
现在我们推导RNN反向传播算法中的梯度。
首先从损失估计节点开始计算如下:
假设损失函数是负对数似然,则相对于输出节点的梯度计算如下:
相对于隐含节点的梯度计算如下:
(3)
相对模型全部参数的梯度计算如下:
(4)
RNN的numpy实现可以参照这篇文章,其实现了预测句中的下一个词。
核心函数主要有三个:正向传播、损失估计、反向传播,代码如下。
正向传播和损失估计分别对应了公式(1)和(2)。由于RNN反向传播中存在梯度消失或爆炸的问题,因此在反向传播算法的实现中对反向传播路径长度进行限制,并借助相邻隐含元h(t)和h(t+1)之间的递推公式(3)不断进行递推,对模型参数进行优化,代码实现如下。
典型的RNN除了可以将隐含节点进行连接,也可以将输出节点和隐含节点进行连接,如下图所示。然而由于在递推过程中上一级向下一级的输入被编码为输出向量,因此状态传播收到限制,相比隐含节点之间连接的情况,所能描述的函数空间缩小,但另一方面,由于在训练阶段隐含节点间不存在依赖,因此可以采用Teacher Forcing的方法进行并行训练。
上述RNN的输入x是一个序列,当x是一个定长向量时,可以采用下面的结构,例如可以应用于图像字幕(image captioning)。
上述RNN都仅存在单向的“因果”链,如果需要使用双向的“因果”链,还可采用双向RNN结构(Bidirectional RNN),即改变上面预测下一个词的例子为预测一句话中填空中的一个词,此时可以同时利用该填空的前一部分和后一部分单词序列同时进行估计。双向RNN结构如下。
采用RNN还可以实现序列到序列的编解码器结构(Encoder-Decoder Sequence-to-Sequence Architecture)。一个例子是实现字符串的加法,代码参见Keras提供的example:addition_rnn.py。例如输入"535+61",输出"596"。序列的编解码器结构如下。
代码中编码器为一层RNN,最后的隐含节点的输出对应图中的C,然后复制DIGIT+1次作为解码器的输入,解码器包含LAYERS层RNN,最后叠加一层全连接作为输出。Keras的代码如下。
除了单层RNN结构,常用结构还有深层循环网络和回归神经网络,其结构如下所示。
RNN中典型问题是长期依赖(Long-Term Dependencies)。RNN的优势在于能够通过先前的信息来指导当前的任务,例如预测下一个词,句子是"the clouds are in the _",容易预测出"sky",该例子中,相关信息和预测词之间的距离很小,然而,在一个更复杂的场景下,例如句子"I grew up in France…I speak fluent _",为了正确预测出下一个词所需相关信息的跨度非常大,然而RNN在计算过程中后面时间节点相对前面时间节点的感知力下降,因此会导致方法失效。为了解决这个问题,已有不少研究提出了不同策略,例如:Echo State Networks、Leaky Units and Multiple Time Scales、以及基于Gate的Long Short Term Memory /LSTM等。
相关连接
- http://www.deeplearningbook.org/contents/rnn.html
- http://www.wildml.com/2015/09/recurrent-neural-networks-tutorial-part-1-introduction-to-rnns/
- https://github.com/mqliang/rnnpython/blob/master/rnnnumpy.py
- https://github.com/fchollet/keras/blob/master/examples/addition_rnn.py