Tensorflow模型的 暂存 恢复 微调 保存 加载

  • 暂存模型(*.index为参数名称,*.meta为模型图,*.data*为参数)
tf.reset_default_graph()

weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")

saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())

print(sess.run([weights]))
saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))

sess.close()
  • 暂存模型(同一模型多次保存可以不保存模型图节省时间)
tf.reset_default_graph()

weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")

saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())

print(sess.run([weights]))
saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))
time.sleep(5)
saver.save(sess, "%s/%s1" % (MODEL_DIR, MODEL1_NAME), write_meta_graph=False)
time.sleep(5)
saver.save(sess, "%s/%s1" % (MODEL_DIR, MODEL2_NAME), write_meta_graph=False)

sess.close()
  • 恢复模型(手动生成网络则不需要*.meta文件)
tf.reset_default_graph()

weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")

saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))

print(sess.run([weights]))

sess.close()
  • 恢复模型(从*.meta文件生成网络)
tf.reset_default_graph()

saver=tf.train.import_meta_graph("%s/%s.meta" % (MODEL_DIR, MODEL_NAME))
sess = tf.Session()
saver.restore(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))

print(sess.run([tf.get_default_graph().get_tensor_by_name("weights:0")]))

sess.close()
  • 恢复模型(可以在一个文件夹下保存多次模型,checkpoint文件会自动记录所有模型名称和最后一次记录模型名称)
tf.reset_default_graph()

weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")

saver = tf.train.Saver()
sess = tf.Session()
ckpt = tf.train.get_checkpoint_state(MODEL_DIR)
saver.restore(sess, ckpt.model_checkpoint_path)

print(sess.run([weights]))

sess.close()
  • 微调模型(恢复之前训练模型的部分参数,加上新参数,继续训练)
def get_variables_available_in_checkpoint(variables, checkpoint_path, include_global_step=True):
    ckpt_reader = tf.train.NewCheckpointReader(checkpoint_path)
    ckpt_vars_to_shape_map = ckpt_reader.get_variable_to_shape_map()
    if not include_global_step:
        ckpt_vars_to_shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None)
    vars_in_ckpt = {}
    for variable_name, variable in sorted(variables.items()):
        if variable_name in ckpt_vars_to_shape_map:
            if ckpt_vars_to_shape_map[variable_name] == variable.shape.as_list():
                vars_in_ckpt[variable_name] = variable
    return vars_in_ckpt

tf.reset_default_graph()

weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")
other_weights = tf.Variable(tf.zeros([10, 10]))

variables_to_init = tf.global_variables()
variables_to_init_dict = {var.op.name: var for var in variables_to_init}
available_var_map = get_variables_available_in_checkpoint(variables_to_init_dict,
    "%s/%s" % (MODEL_DIR, MODEL_NAME), include_global_step=False)
tf.train.init_from_checkpoint("%s/%s" % (MODEL_DIR, MODEL_NAME), available_var_map)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

print(sess.run([weights]))

sess.close()
  • 保存模型(二进制模型)
from tensorflow.python.framework.graph_util import convert_variables_to_constants

tf.reset_default_graph()

saver=tf.train.import_meta_graph("%s/%s.meta" % (MODEL_DIR, MODEL_NAME))
sess = tf.Session()
saver.restore(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))

graph_out = convert_variables_to_constants(sess, sess.graph_def, output_node_names=[‘weights‘])
with tf.gfile.GFile("%s/%s" % (MODEL_DIR, PB_MODEL_NAME), "wb") as output:
    output.write(graph_out.SerializeToString())

sess.close()
  • 加载模型(二进制模型)
tf.reset_default_graph()

sess = tf.Session()
with tf.gfile.FastGFile("%s/%s" % (MODEL_DIR, PB_MODEL_NAME),‘rb‘) as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def,name=‘‘)
sess.run(tf.global_variables_initializer())

print(sess.run([tf.get_default_graph().get_tensor_by_name("weights:0")]))

sess.close()

参考文献:

https://blog.csdn.net/loveliuzz/article/details/81661875

https://www.cnblogs.com/bbird/p/9951943.html

https://blog.csdn.net/gzj_1101/article/details/80299610

原文地址:https://www.cnblogs.com/jhc888007/p/11620821.html

时间: 2024-08-24 22:16:25

Tensorflow模型的 暂存 恢复 微调 保存 加载的相关文章

torch保存加载模型

保存模型 torch.save(my_model.state_dict(), "params.pkl") 加载模型 先初始化model网络结构 model.load_state_dict(torch.load("params.pkl")) 原文地址:https://www.cnblogs.com/rise0111/p/11621640.html

jvm内存模型,java类从编译到加载到执行的过程,jvm内存分配过程

一.jvm内存模型 JVM 内存模型主要分为堆.程序计数器.方法区.虚拟机栈和本地方法栈 1.堆 1.1.堆是 JVM 内存中最大的一块内存空间. 1.2.该内存被所有线程共享,几乎所有对象和数组都被分配到了堆内存中. 1.3.堆被划分为新生代和老年代,新生代又被进一步划分为 Eden 和 Survivor 区,最后 Survivor 由 From Survivor 和 To Survivor 组成. 2.程序计数器(Program Counter Register) 程序计数器是一块很小的内存

Docker实用技巧(一):镜像的备份/保存/加载/删除

首先需要理解,这里的镜像是指image,而container是容器,是image的一个启动. 镜像备份: 备份使用commit命令,相当于是将正在运行的container保存为一个image 使用方法如下: 实例: 最后跟的那个backup就相当于之后image的repository,当然这里也可以 backup:test,此时test就是tag. 运行image命令,查看已经有此image: 镜像保存: save命令用于保存image,如果想把备份好的image发送给别人,就需要保存,dock

jquery easyui datagrid 保存/加载自定义配置每列属性

直接附上源代码 <!DOCTYPE html> <html> <head> <meta charset="UTF-8"> <title>Format DataGrid Columns - jQuery EasyUI Demo</title> <link rel="stylesheet" type="text/css" href="css/material/ea

优化Flash中的3D模型加载

来自:Kid's Zone 最近在做一个公司的Flash3D页游项目,遇到了这个问题,前前后后断断续续也优化了一段时间,觉得还是有必要记录一下一些优化的心得. Flash中加载资源一个最大的问题在于难以使用另外的线程加载资源.诚然Flash有Worker线程,但Worker存在以下几个问题: 1. 使用Worker要求客户的FlashPlayer播放器版本不能过低. 2. 不同Worker之间传递数据手段非常少,缺乏共享内存.使用ByteArray共享数据的话需要先把数据序列化成AMF格式,无论

sklearn训练模型的保存与加载

使用joblib模块保存于加载模型 在机器学习的过程中,我们会进行模型的训练,最常用的就是sklearn中的库,而对于训练好的模型,我们当然是要进行保存的,不然下次需要进行预测的时候就需要重新再进行训练.如果数据量小的话,那再重新进行训练是没有问题的,但是如果数据量大的话,再重新进行训练可能会花费很多开销,这个时候,保存好已经训练的模型就显得特别重要了.我们可以使用sklearn中的joblib模块进行保存与加载. from sklearn.externals import joblib # 保

双亲委派模型,类的加载机制,搞定大厂高频面试题

看过这篇文章,大厂面试你「双亲委派模型」,硬气的说一句,你怕啥? 读该文章姿势 打开手头的 IDE,按照文章内容及思路进行代码跟踪与思考 手头没有 IDE,先收藏,回头看 (万一哪次面试问了呢) 需要查看和拷贝代码,点击文章末尾出「阅读原文」 文章内容相对较长,所以添加了目录,如果你希望对 Java 的类加载过程有个更深入的了解,同时增加自己的面试技能点,请耐心读完...... 双亲委派模型 在介绍这个Java技术点之前,先试着思考以下几个问题: 为什么我们不能定义同名的 String 的 ja

tensorflow模型的保存与恢复

模型保存后产生四个文件,分别是: |--models| |--checkpoint| |--.meta| |--.data| |--.index .meta保存的是图的结构 checkpoint文件是个文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表. .data和.index保存的是变量值. tensorflow常用的模型保存方法: best_str = '' if best_loss is None or valid_loss < best_los

转 tensorflow模型保存 与 加载

使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我们可能也需要用到别人训练好的模型,并在这个基础上再次训练.这时候我们需要掌握如何操作这些模型数据.看完本文,相信你一定会有收获! 1 Tensorflow模型文件 我们在checkpoint_dir目录下保存的文件结构如下: |--checkpoint_dir | |--checkpoint | |--MyModel.meta | |--MyModel.data-00000-of-00001 | |--MyModel.in