save = tf.train.Saver()
通过save. save() 实现数据的加载
通过save.restore() 实现数据的导出
第一步: 数据的载入
import tensorflow as tf #创建变量 v1 = tf.Variable(tf.random_normal([1, 2], name=‘v1‘)) v2 = tf.Variable(tf.random_normal([2, 3], name=‘v2‘)) #初始化变量 init_op = tf.global_variables_initializer() #构建训练模型的保存 saver = tf.train.Saver() with tf.Session() as sess: sess.run(init_op) print(‘V1:‘, sess.run(v1)) print(‘V2:‘, sess.run(v2)) # saver.save(保存内容, 保存路径) saver_path = saver.save(sess, ‘save/model.ckpt‘) print(‘Model saved in file:‘, saver_path)
第二步: 数据的导出
import tensorflow as tf# v1,v2的设定,主要是看看输出的v1是哪个v1 v1 = tf.Variable(tf.random_normal([1, 2]), name=‘v1‘) v2 = tf.Variable(tf.random_normal([2, 3]), name=‘v2‘) # 构建保存模型 saver = tf.train.Saver() with tf.Session() as sess: # 重新加载模型(重新赋予名字, 加载的路径) saver.restore(sess, ‘save/model.ckpt‘) print(‘V1:‘, sess.run(v1)) print(‘V2:‘, sess.run(v2)) print(‘Model restored‘)
原文地址:https://www.cnblogs.com/my-love-is-python/p/9570286.html
时间: 2024-10-04 05:49:15