TensorFlow的梯度裁剪

在较深的网络,如多层CNN或者非常长的RNN,由于求导的链式法则,有可能会出现梯度消失(Gradient Vanishing)或梯度爆炸(Gradient Exploding )的问题。

原理

问题:为什么梯度爆炸会造成训练时不稳定而且不收敛?

梯度爆炸,其实就是偏导数很大的意思。回想我们使用梯度下降方法更新参数:

损失函数的值沿着梯度的方向呈下降趋势,然而,如果梯度(偏导数)很大话,就会出现函数值跳来跳去,收敛不到最值的情况,如图:

当然出现这种情况,其中一种解决方法是,将学习率αα设小一点,如0.0001。

这里介绍梯度裁剪(Gradient Clipping)的方法,对梯度进行裁剪,论文提出对梯度的L2范数进行裁剪,也就是所有参数偏导数的平方和再开方。

TensorFlow代码

方法一:

optimizer = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.5)
grads = optimizer.compute_gradients(loss)
for i, (g, v) in enumerate(grads):
    if g is not None:
        grads[i] = (tf.clip_by_norm(g, 5), v)  # 阈值这里设为5
train_op = optimizer.apply_gradients(grads)

其中

optimizer.compute_gradients()返回的是正常计算的梯度,是一个包含(gradient, variable)的列表。

tf.clip_by_norm(t, clip_norm)返回裁剪过的梯度,维度跟t一样。

不过这里需要注意的是,这里范数的计算不是根据全局的梯度,而是一部分的。

方法二:

optimizer = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.5)
grads, variables = zip(*optimizer.compute_gradients(loss))
grads, global_norm = tf.clip_by_global_norm(grads, 5)
train_op = optimizer.apply_gradients(zip(grads, variables))

这里是计算全局范数,这才是标准的。不过缺点就是会慢一点,因为需要全部梯度计算完之后才能进行裁剪。

总结

当你训练模型出现Loss值出现跳动,一直不收敛时,除了设小学习率之外,梯度裁剪也是一个好方法。

然而这也说明,如果你的模型稳定而且会收敛,但是效果不佳时,那这就跟学习率和梯度爆炸没啥关系了。因此,学习率的设定和梯度裁剪的阈值并不能提高模型的准确率。

原文地址:https://www.cnblogs.com/zongfa/p/9737698.html

时间: 2024-08-30 15:53:03

TensorFlow的梯度裁剪的相关文章

tensorflow 梯度裁剪

gvs = optimizer.compute_gradients(loss) # 计算出梯度和变量值 capped_gvs = [(tf.clip_by_value(grad, -5e+10, 5e+10), var) for grad, var in gvs] # 梯度裁剪 train_op = optimizer.apply_gradients(capped_gvs, global_step=global_step) # 梯度下降 原文地址:https://www.cnblogs.com/

如何用TensorFlow图像处理函数裁剪图像?

当给定大量不同质量的训练数据时,CNN往往能够很好地工作. –图像能够通过可视化的方式,传达复杂场景所蕴含的某种目标主题. –在Stanford Dogs数据集中,重要的是图像能够以可视化的方式,突出图片中狗的重要性. –一幅狗位于画面中心的图像,会被认为比狗作为背景的图像更有价值. 并非所有数据集都拥有最有价值的图像.下面所示的两幅图像,按照假设,该数据集本应突出不同的狗的品种 左图突出的是一条典型的墨西哥无毛犬的重要属性,而右图是两个参加聚会的人,在逗一条墨西哥无毛犬.右图中充斥了大量的无关

tensorflow实现Minist手写体识别

import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #下载MINIST数据集mnist = input_data.read_data_sets('MNIST_data', one_hot=True) #表示输入任意数量的MNIST图像,每一张图展平成784维的向量#placeholder是占位符,在训练时指定x = tf.placeholder(tf.float32, [None,

Tensorflow快速入门2--实现手写数字识别

Tensorflow快速入门2–实现手写数字识别 环境: 虚拟机ubuntun16.0.4 Tensorflow(仅使用cpu版) Tensorflow安装见: http://blog.csdn.net/yhhyhhyhhyhh/article/details/54429034 或者: http://www.tensorfly.cn/tfdoc/get_started/os_setup.html 本文将利用Tensorflow以softmax回归和卷积神经网络两种模型简单测试MNIST数据集,快

TensorFlow教程03:针对机器学习初学者的MNIST实验——回归的实现、训练和模型评估

实现回归模型 为了用python实现高效的数值计算,我们通常会使用函数库,比如NumPy,会把类似矩阵乘法这样的复杂运算使用其他外部语言实现.不幸的是,从外部计算切换回Python的每一个操作,仍然是一个很大的开销.如果你用GPU来进行外部计算,这样的开销会更大.用分布式的计算方式,也会花费更多的资源用来传输数据. TensorFlow也把复杂的计算放在python之外完成,但是为了避免前面说的那些开销,它做了进一步完善.Tensorflow不单独地运行单一的复杂计算,而是让我们可以先用图描述一

Tensorflow快餐教程(1) - 30行代码搞定手写识别

去年买了几本讲tensorflow的书,结果今年看的时候发现有些样例代码所用的API已经过时了.看来自己维护一个保持更新的Tensorflow的教程还是有意义的.这是写这一系列的初心. 快餐教程系列希望能够尽可能降低门槛,少讲,讲透. 为了让大家在一开始就看到一个美好的场景,而不是停留在漫长的基础知识积累上,参考网上的一些教程,我们直接一开始就直接展示用tensorflow实现MNIST手写识别的例子.然后基础知识我们再慢慢讲. Tensorflow安装速成教程 由于Python是跨平台的语言,

学习笔记CB014:TensorFlow seq2seq模型步步进阶

神经网络.<Make Your Own Neural Network>,用非常通俗易懂描述讲解人工神经网络原理用代码实现,试验效果非常好. 循环神经网络和LSTM.Christopher Olah http://colah.github.io/posts/2015-08-Understanding-LSTMs/ . seq2seq模型基于循环神经网络序列到序列模型,语言翻译.自动问答等序列到序列场景,都可用seq2seq模型,用seq2seq实现聊天机器人的原理 http://suriyade

什么是梯度爆炸?怎么解决?

梯度的衰减是有连续乘法导致的,如果在连续乘法中出现一个非常大的值,最后计算出的梯度就会很大,就想当优化到断崖处是,会获得一个很大的梯度值,如果以这个梯度值进行更新,那么这次迭代的步长就很大,可能会一下子飞出了合理的区域. 解决的方法是: 梯度裁剪: 把沿梯度下降方向的步长限制在一个范围之内,计算出来的梯度的步长的范数大于这个阈值的话,就以这个范数为基准做归一化,使这个新的的梯度的范数等于这个阈值就行了. 梯度检查: 梯度计算很不稳定,使用梯度检查来检查梯度计算是否出了错误. 通过解析的梯度值与计

使用tensorflow操作MNIST数据

本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有Hello World,机器学习入门有MNIST.在此节,我将训练一个机器学习模型用于预测图片里面的数字. MNIST 是一个非常有名的手写体数字识别数据集,在很多资料中,这个数据集都会被用做深度学习的入门样例.而Tensorflow的封装让MNIST数据集变得更加方便.MNIST是NIST数据集的一个子集,它包含了60000张图片作为训练数据,10000张图片作为测试数据.在MNIST数据集中的