TensorFlow是一个通过计算图的形式表述计算机的编程系统
TensorFlow程序一般分为两个阶段,第一个阶段需要定义计算图中所有的计算(变量)
第二个阶段为执行计算
如以下代码
import tensorflow as tf # 第一阶段定义所有的计算 a = tf.constant([1, 2], name=‘a‘) b = tf.constant([1, 2], name=b‘) result = a + b # 第二阶段,执行计算 # 创建一个会话 sess = tf.Session() #运行会话执行计算 sess.run(result) # 关闭会话 sess.close()
通过tf.get_default_graph函数可以获取当前默认的计算图,通过a.graph可以查看张量所属的计算图
如果没有特意指定a.graph等于默认的计算图,下面的代码输出为True
print(a.graph is tf.get_default_graph()) # 输出为True
除了使用默认的计算图,TensorFlow支持通过tf.Graph函数来生成新的计算图,不同计算图上的张量和运算不会共享
import tensorflow as tf g1 = tf.Graph() with g1.as default(): # 在计算图g1中定义变量‘v‘,并设置初始值为0 v = tf.get_variable(‘v‘, shap=[1], initializer=tf.zeros_initializer) g2 = tf.Graph() with g2.as default(): # 在计算图g1中定义变量‘v‘,并设置初始值为1 v = tf.get_variable(‘v‘, shap=[1], initializer=tf.oness_initializer) # 在计算图g1中读取变量‘v‘的取值 with tf.Session(graph=g1) as sess: tf.global_variables_initializer().run() with tf.variable_scope("", reuse=True): # 在计算图g1中,变量v的取值应该为0,下面应输出[0.] print(sess.run(tf.get_variable(‘v‘))) # 在计算图g2中读取变量‘v‘的取值 with tf.Session(graph=g2) as sess: # 初始化全局变量 tf.global_variables_initializer().run() with tf.variable_scope("", reuse=True): # 在计算图g2中,变量v的取值应该为1,下面应输出[1.] print(sess.run(tf.get_variable(‘v‘)))
另外计算图还可以通过tf.Graph.device函数指定运行的设备
g = tf.Graph() # 指定计算运行设的设备,指定到gpu0上 with g.device(‘/gpu:0‘): result a + b
原文地址:https://www.cnblogs.com/lyh-vip/p/10505733.html
时间: 2024-10-11 15:38:41