RNN中的梯度消失爆炸原因

RNN中的梯度消失/爆炸原因

梯度消失/梯度爆炸是深度学习中老生常谈的话题,这篇博客主要是对RNN中的梯度消失/梯度爆炸原因进行公式层面上的直观理解。

首先,上图是RNN的网络结构图,\((x_1, x_2, x_3, …, )\)是输入的序列,\(X_t\)表示时间步为\(t\)时的输入向量。假设我们总共有\(k\)个时间步,用第\(k\)个时间步的输出\(H_k\)作为输出(实际上每个时间步都有输出,这里仅考虑\(H_k\)),用\(E_k\)表示损失。

其中,\(C_{t}=\tanh \left(W_{c} C_{t-1}+W_{x} X_{t}\right)\)

从上式可以看出 \(W_x\)和\(W_c\)其实是差不多的,记\(W=[W_c, W_x]\),那么求偏导可以得到:

\(\begin{aligned} \frac{\partial E_{k}}{\partial W}=& \frac{\partial E_{k}}{\partial H_{k}} \frac{\partial H_{k}}{\partial C_{k}} \frac{\partial C_{k}}{\partial C_{k-1}} \ldots \frac{\partial C_{2}}{\partial C_{1}} \frac{\partial C_{1}}{\partial W}=\\ & \frac{\partial E_{k}}{\partial H_{k}} \frac{\partial H_{k}}{\partial C_{k}}\left(\prod_{t=2}^{k} \frac{\partial C_{t}}{\partial C_{t-1}}\right) \frac{\partial C_{1}}{\partial W} \end{aligned}\)

其中的累乘部分为:

\(\begin{aligned} \frac{\partial C_{t}}{\partial c_{t-1}}=& \tanh ^{\prime}\left(W_{c} C_{t-1}+W_{x} X_{t}\right) \cdot \frac{d}{d C_{t-1}}\left[W_{c} C_{t-1}+W_{x} X_{t}\right]=\\ & \tanh ^{\prime}\left(W_{c} C_{t-1}+W_{x} X_{t}\right) \cdot W_{c} \end{aligned}\)

将该式代入上式有:

\(\frac{\partial E_{k}}{\partial W}=\frac{\partial E_{k}}{\partial H_{k}} \frac{\partial H_{k}}{\partial C_{k}}\left(\prod_{t=2}^{k} \tanh ^{\prime}\left(W_{c} C_{t-1}+W_{x} X_{t}\right) \cdot W_{c}\right) \frac{\partial c_{1}}{\partial W}\)

观察这个式子,和上篇文章中一样,因为链式法则,出现了累乘项,因为tanh的导数 <= 1,所以,当k很大的时候,上式的值是趋向于0的。(<1的数多次相乘),也就是:

\(\Pi_{t=2}^{k} \tanh ^{\prime}\left(W_{c} C_{t-1}+w_{x} X_{t}\right) \cdot W_{c} \rightarrow 0,\) so \(\frac{\partial E_{k}}{\partial W} \rightarrow 0\)

此时,权重更新公式:

\(W \leftarrow W-\alpha \frac{\partial E_{k}}{\partial W} \approx W\)

也就是说,RNN很容易出现梯度消失现象,使得参数更新缓慢,甚至是停止更新。

原文地址:https://www.cnblogs.com/Elaine-DWL/p/11240209.html

时间: 2024-07-30 07:08:44

RNN中的梯度消失爆炸原因的相关文章

【神经网络和深度学习】笔记 - 第五章 深度神经网络学习过程中的梯度消失问题

之前的章节,我们利用一个仅包含一层隐藏层的简单神经网络就在MNIST识别问题上获得了98%左右的准确率.我们于是本能会想到用更多的隐藏层,构建更复杂的神经网络将会为我们带来更好的结果. 就如同在进行图像模式识别的时候,第一层的神经层可以学到边缘特征,第二层的可以学到更复杂的图形特征,例如三角形,长方形等,第三层又会识别更加复杂的图案.这样看来,多层的结构就会带来更强大的模型,进行更复杂的识别. 那么在这一章,就试着训练这样的神经网络来看看对结果有没有什么提升.不过我们发现,训练的过程将会出现问题

LSTM缓解梯度消失的原因

\(c_{t}=c_{t-1} \otimes \sigma\left(W_{f} \cdot\left[H_{t-1}, X_{t}\right]\right) \oplus \tanh \left(W_{c} \cdot\left[H_{t-1}, X_{t}\right]\right) \otimes \sigma\left(W_{i} \cdot\left[H_{t-1}, X_{t}\right]\right)\) 反向传播公式: \(\begin{aligned} \frac{\pa

深度学习解决局部极值和梯度消失问题方法简析(转载)

转载:http://blog.sina.com.cn/s/blog_15f0112800102wojj.html 这篇文章关于对深度CNN中BP梯度消失的问题的做了不错的解析,可以看一下: 多层感知机解决了之前无法模拟异或逻辑的缺陷,同时更多的层数也让网络更能够刻画现实世界中的复杂情形.理论上而言,参数越多的模型复杂度越高,“容量”也就越大,也就意味着它能完成更复杂的学习任务.多层感知机给我们带来的启示是,神经网络的层数直接决定了它对现实的刻画能力——利用每层更少的神经元拟合更加复杂的函数.但是

【深度学习系列】DNN中梯度消失和梯度爆炸的原因推导

DNN中梯度消失和梯度爆炸的原因推导 因为手推涉及很多公式,所以这一截图放出. 原文地址:https://www.cnblogs.com/Elaine-DWL/p/11140917.html

DL4NLP——神经网络(二)循环神经网络:BPTT算法步骤整理;梯度消失与梯度爆炸

网上有很多Simple RNN的BPTT算法推导.下面用自己的记号整理一下. 我之前有个习惯是用下标表示样本序号,这里不能再这样表示了,因为下标需要用做表示时刻. 典型的Simple RNN结构如下: 图片来源:[3] 约定一下记号: 输入序列 $\textbf x_{(1:T)} =(\textbf x_1,\textbf x_2,...,\textbf x_T)$,每个时刻的值都是一个维数是词表大小的one-hot列向量: 标记序列 $\textbf y_{(1:T)} =(\textbf

梯度消失与梯度爆炸

https://blog.csdn.net/qq_25737169/article/details/78847691 产生消失的梯度问题的原因 先看一个极简单的深度神经网络:每一层都只有一个单一的神经元.如下图: 代价函数C对偏置b1的偏导数的结果计算如下: 先看一下sigmoid 函数导数的图像: 该导数在σ′(0) = 1/4时达到最高.现在,如果我们使用标准方法来初始化网络中的权重,那么会使用一个均值为0 标准差为1 的高斯分布.因此所有的权重通常会满足|wj|<1.从而有wjσ′(zj)

梯度消失(vanishing gradient)和梯度爆炸(exploding gradient)

转自https://blog.csdn.net/guoyunfei20/article/details/78283043 神经网络中梯度不稳定的根本原因:在于前层上的梯度的计算来自于后层上梯度的乘积(链式法则).当层数很多时,就容易出现不稳定.下边3个隐含层为例: 其b1的梯度为: 加入激活函数为sigmoid,则其导数如下图: sigmoid导数σ'的最大值为1/4.同常一个权值w的取值范围为abs(w) < 1,则:|wjσ'(zj)| < 1/4,从而有: 从上式可以得出结论:前层比后层

RNN 训练时梯度爆炸的理解

梯度爆炸 比方说当前点刚好在悬崖边上, 这个时候计算这个点的斜率就会变得非常大, 我们跟新的时候是按 斜率 × 学习率 来的, 那么这时候参数的跟新就会非常非常大, loss也会非常大 应对办法就是 当斜率超过某个值比如15时, 设定斜率为15. 造成梯度爆炸的原因并不是来自激活函数 --- sigmoid , 如果把激活函数换为 ReLU 通常模型表现会更差 梯度消失 可以理解为 RNN 把 weight 变化的程度放大了 原文地址:https://www.cnblogs.com/larkii

对于梯度消失和梯度爆炸的理解

一.梯度消失.梯度爆炸产生的原因 假设存在一个网络结构如图: 其表达式为: 若要对于w1求梯度,根据链式求导法则,得到的解为: 通常,若使用的激活函数为sigmoid函数,其导数: 这样可以看到,如果我们使用标准化初始w,那么各个层次的相乘都是0-1之间的小数,而激活函数f的导数也是0-1之间的数,其连乘后,结果会变的很小,导致梯度消失.若我们初始化的w是很大的数,w大到乘以激活函数的导数都大于1,那么连乘后,可能会导致求导的结果很大,形成梯度爆炸. 当然,若对于b求偏导的话,其实也是一个道理: