tf.GradientTape详解

参考文献:https://blog.csdn.net/guanxs/article/details/102471843

在TensorFlow 1.x静态图时代,我们知道每个静态图都有两部分,一部分是前向图,另一部分是反向图。反向图就是用来计算梯度的,用在整个训练过程中。而TensorFlow 2.0默认是eager模式,每行代码顺序执行,没有了构建图的过程(也取消了control_dependency的用法)。但也不能每行都计算一下梯度吧?计算量太大,也没必要。因此,需要一个上下文管理器(context manager)来连接需要计算梯度的函数和变量,方便求解同时也提升效率。
       举个例子:计算y=x^2在x = 3时的导数:

x = tf.constant(3.0)
with tf.GradientTape() as g:
  g.watch(x)
  y = x * x
dy_dx = g.gradient(y, x) # y’ = 2*x = 2*3 = 6

例子中的watch函数把需要计算梯度的变量x加进来了。GradientTape默认只监控由tf.Variable创建的traiable=True属性(默认)的变量。上面例子中的x是constant,因此计算梯度需要增加g.watch(x)函数。当然,也可以设置不自动监控可训练变量,完全由自己指定,设置watch_accessed_variables=False就行了(一般用不到)。

GradientTape也可以嵌套多层用来计算高阶导数,例如:

x = tf.constant(3.0)
with tf.GradientTape() as g:
  g.watch(x)
  with tf.GradientTape() as gg:
    gg.watch(x)
    y = x * x
  dy_dx = gg.gradient(y, x)      # y’ = 2*x = 2*3 =6
d2y_dx2 = g.gradient(dy_dx, x)  # y’’ = 2

另外,默认情况下GradientTape的资源在调用gradient函数后就被释放,再次调用就无法计算了。所以如果需要多次计算梯度,需要开启persistent=True属性,例如:

x = tf.constant(3.0)
with tf.GradientTape(persistent=True) as g:
  g.watch(x)
  y = x * x
  z = y * y
dz_dx = g.gradient(z, x)  # z = y^2 = x^4, z’ = 4*x^3 = 4*3^3
dy_dx = g.gradient(y, x)  # y’ = 2*x = 2*3 = 6
del g  # 删除这个上下文tape

最后,一般在网络中使用时,不需要显式调用watch函数,使用默认设置,GradientTape会监控可训练变量,例如:

with tf.GradientTape() as tape:
    predictions = model(images)
    loss = loss_object(labels, predictions)
gradients = tape.gradient(loss, model.trainable_variables)

这样即可计算出所有可训练变量的梯度,然后进行下一步的更新。对于TensorFlow 2.0,推荐大家使用这种方式计算梯度,并且可以在eager模式下查看具体的梯度值。

根据上面的例子说一下tf.GradientTape这个类的常见的属性和函数,更多的可以去官方文档来看。

__init__(persistent=False,watch_accessed_variables=True)
作用:创建一个新的GradientTape
参数:

persistent: 布尔值,用来指定新创建的gradient tape是否是可持续性的。默认是False,意味着只能够调用一次gradient()函数。
watch_accessed_variables: 布尔值,表明这个gradien tap是不是会自动追踪任何能被训练(trainable)的变量。默认是True。要是为False的话,意味着你需要手动去指定你想追踪的那些变量。
比如在上面的例子里面,新创建的gradient tape设定persistent为True,便可以在这个上面反复调用gradient()函数。

watch(tensor)
作用:确保某个tensor被tape追踪

参数:

tensor: 一个Tensor或者一个Tensor列表
gradient(target,sources,output_gradients=None,unconnected_gradients=tf.UnconnectedGradients.NONE)
作用:根据tape上面的上下文来计算某个或者某些tensor的梯度
参数:

target: 被微分的Tensor或者Tensor列表,你可以理解为经过某个函数之后的值
sources: Tensors 或者Variables列表(当然可以只有一个值). 你可以理解为函数的某个变量
output_gradients: a list of gradients, one for each element of target. Defaults to None.
unconnected_gradients: a value which can either hold ‘none’ or ‘zero’ and alters the value which will be returned if the target and sources are unconnected. The possible values and effects are detailed in ‘UnconnectedGradients’ and it defaults to ‘none’.
返回:
一个列表表示各个变量的梯度值,和source中的变量列表一一对应,表明这个变量的梯度。

原文地址:https://www.cnblogs.com/SupremeBoy/p/12246528.html

时间: 2024-07-30 19:56:16

tf.GradientTape详解的相关文章

tensorflow 的tf.where详解

最近在用到数据筛选,观看代码中有tf.where()的用法,不是很常用,也不是很好理解.在这里记录一下 1 tf.where( 2 condition, 3 x=None, 4 y=None, 5 name=None 6 ) Return the elements, either from x or y, depending on the condition. 理解:where嘛,就是要根据条件找到你要的东西. condition:条件,是一个boolean x:数据 y:同x维度的数据. 返回

jar打包命令详解

:如何把 java 程序编译成 .exe 文件.通常回答只有两种,一种是说,制作一个可执行的 JAR 文件包,就可以像.chm 文档一样双击运行了:而另一种回答,则是使用 JET 来进行编译.但是 JET 是要用钱买的,而且,据说 JET 也不是能把所有的 Java 程序都编译成执行文件,性能也要打些折扣.所以,使用制作可执行 JAR 文件包的方法就是最佳选择了,何况它还能保持 Java 的跨平台特性.先来看看什么是 JAR 文件包: 1. JAR 文件包 JAR 文件就是 Java Archi

[转] - JAR文件包及jar命令详解 ( MANIFEST.MF的用法 )

常常在网上看到有人询问:如何把 java 程序编译成 .exe 文件.通常回答只有两种,一种是制作一个可执行的 JAR 文件包,然后就可以像. chm 文档一样双击运行了:而另一种是使用 JET 来进行 编译.但是 JET 是要用钱买的,而且据说 JET 也不是能把所有的 Java 程序 都编译成执行文件,性能也要打些折扣.所以,使用制作可执行 JAR 文件包的方法就是最佳选择了,何况它还能保持 Java 的跨平台特性. 下面就来看看什么是 JAR 文件包吧: 1. JAR 文件包 JAR 文件

tar命令详解

tar命令详解 -c: 建立压缩档案 -x:解压 -t:查看内容 -r:向压缩归档文件末尾追加文件 -u:更新原压缩包中的文件 这五个是独立的命令,压缩解压都要用到其中一个,可以和别的命令连用但只能用其中一个. 下面的参数是根据需要在压缩或解压档案时可选的. -z:有gzip属性的 -j:有bz2属性的 -Z:有compress属性的 -v:显示所有过程 -O:将文件解开到标准输出 参数-f是必须的 -f: 使用档案名字,切记,这个参数是最后一个参数,后面只能接档案名. # tar -cf al

tar 解压缩命令详解(转)

tar 解压缩命令详解 -c: 建立压缩档案 -x:解压-t:查看内容-r:向压缩归档文件末尾追加文件-u:更新原压缩包中的文件 这五个是独立的命令,压缩解压都要用到其中一个,可以和别的命令连用但只能用其中一个.下面的参数是根据需要在压缩或解压档案时可选的. -z:有gzip属性的-j:有bz2属性的-Z:有compress属性的-v:显示所有过程-O:将文件解开到标准输出 下面的参数-f是必须的 -f: 使用档案名字,切记,这个参数是最后一个参数,后面只能接档案名. # tar -cf all

tar 指令详解

tar 解压缩命令 tar -c: 建立压缩档案 -x:解压 -t:查看内容 -r:向压缩归档文件末尾追加文件 -u:更新原压缩包中的文件 这五个是独立的命令,压缩解压都要用到其中一个,可以和别的命令连用但只能用其中一个.下面的参数是根据需要在压缩或解压档案时可选的. -z:有gzip属性的 -j:有bz2属性的 -Z:有compress属性的 -v:显示所有过程 -O:将文件解开到标准输出 下面的参数-f是必须的 -f: 使用档案名字,切记,这个参数是最后一个参数,后面只能接档案名. # ta

海量数据处理算法总结【超详解】

1. Bloom Filter [Bloom Filter]Bloom Filter(BF)是一种空间效率很高的随机数据结构,它利用位数组很简洁地表示一个集合,并能判断一个元素是否属于这个集合.它是一个判断元素是否存在集合的快速的概率算法.Bloom Filter有可能会出现错误判断,但不会漏掉判断.也就是Bloom Filter判断元素不再集合,那肯定不在.如果判断元素存在集合中,有一定的概率判断错误.因此,Bloom Filter不适合那些“零错误”的应用场合. 而在能容忍低错误率的应用场合

Linux下的压缩解压缩命令详解

zip -r myfile.zip ./*将当前目录下的所有文件和文件夹全部压缩成myfile.zip文件,-r表示递归压缩子目录下所有文件. 2.unzipunzip -o -d /home/sunny myfile.zip把myfile.zip文件解压到 /home/sunny/-o:不提示的情况下覆盖文件:-d:-d /home/sunny 指明将文件解压缩到/home/sunny目录下: 3.其他zip -d myfile.zip smart.txt删除压缩文件中smart.txt文件z

16位汇编第五讲各种指令详解第一讲

汇编指令详解 8080指令详解 1.8086系统下,Inter指令系统共有117条指令(看似很多,分一下类) 1.数据传送类指令(专门传送数据的) 2.算术运算类指令(加减乘除的运算的) 3.位操作类指令(或  异货 与 -.) 4.串操作类指令 (内存拷贝,内存连续地址拷贝的操作) 5.控制转移类指令(跳转,比如C语言的Goto) 6.处理机控制类指令(计算机的待机 ,重启 等等,让CPU待机睡眠的指令) 学习指令的注意事项 1.指令的功能,也就是这个指令可以实现什么操作.通常的话,指令就是指