tensorflow 之模型的保存与加载(一)

怎样让通过训练的神经网络模型得以复用?

本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读.

 1 #!/usr/bin/env python3
 2 #-*- coding:utf-8 -*-
 3 ############################
 4 #File Name: saver.py
 5 #Brief:
 6 #Author: frank
 7 #Mail: [email protected]
 8 #Created Time:2018-06-22 22:12:52
 9 ############################
10
11 """
12 checkpoint                  #保存所有的模型文件列表
13 my_test_model.ckpt.data-00000-of-00001
14 my_test_model.ckpt.index
15 my_test_model.ckpt.meta     #保存计算图的结构信息,即神经网络的结构
16 """
17
18
19 import tensorflow as tf
20
21 #声明两个变量并计算它们的和
22 v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
23 v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
24 result = v1 + v2
25
26 init_op = tf.global_variables_initializer()
27
28 #声明tf.train.Saver类用于保存模型
29 saver = tf.train.Saver()
30
31 with tf.Session() as sess:
32     sess.run(init_op)
33     #将模型保存到指定路径
34     saver.save(sess,"my_test_model.ckpt")

模型的加载方法:

#!/usr/bin/env python3
#-*- coding:utf-8 -*-
############################
#File Name: restore.py
#Brief:
#Author: frank
#Mail: [email protected]
#Created Time:2018-06-22 22:34:16
############################                             

import tensorflow as tf                                  

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
print(v1)
result = v1 + v2
print(result)                                            

saver = tf.train.Saver()                                 

with tf.Session() as sess:
    saver.restore(sess, "my_test_model.ckpt")
    print(sess.run(result))                              

#运行结果:
#<tf.Variable ‘v1:0‘ shape=(1,) dtype=float32_ref>
#Tensor("add:0", shape=(1,), dtype=float32)
#[3.]

上面的过程中还是定义了 图的结构,有点重复了,那么可不可以直接从已保存的ckpt中加载图呢?

import tensorflow as tf                                                

saver = tf.train.import_meta_graph("my_test_model.ckpt.meta")          

with tf.Session() as sess:
    saver.restore(sess, "my_test_model.ckpt")
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))

上面的程序,默认保存和加载了计算图中的全部变量,但有时可能只需要保存或加载部分变量。因为并不是所有隐藏层的参数需要重新训练。

具体怎么做呢?且听下回分解

原文地址:https://www.cnblogs.com/black-mamba/p/9226705.html

时间: 2024-10-25 18:13:27

tensorflow 之模型的保存与加载(一)的相关文章

tensorflow 之模型的保存与加载(三)

前面的两篇博文 第一篇:简单的模型保存和加载,会包含所有的信息:神经网络的op,node,args等; 第二篇:选择性的进行模型参数的保存与加载. 本篇介绍,只保存和加载神经网络的计算图,即前向传播的过程. #!/usr/bin/env python3 #-*- coding:utf-8 -*- ############################ #File Name: save_restore.py #Brief: #Author: frank #Mail: [email protect

tensorflow模型的保存与加载

模型的保存与加载一般有三种模式:save/load weights(最干净.最轻量级的方式,只保存网络参数,不保存网络状态),save/load entire model(最简单粗暴的方式,把网络所有的状态都保存起来),saved_model(更通用的方式,以固定模型格式保存,该格式是各种语言通用的) 具体使用方法如下: # 保存模型 model.save_weights('./checkpoints/my_checkpoint') # 加载模型 model = keras.create_mod

(sklearn)机器学习模型的保存与加载

需求: 一直写的代码都是从加载数据,模型训练,模型预测,模型评估走出来的,但是实际业务线上咱们肯定不能每次都来训练模型,而是应该将训练好的模型保存下来 ,如果有新数据直接套用模型就行了吧?现在问题就是怎么在实际业务中保存模型,不至于每次都来训练,在预测. 解决方案: 机器学习-训练模型的保存与恢复(sklearn)python /模型持久化 /模型保存 /joblib /模型恢复在做模型训练的时候,尤其是在训练集上做交叉验证,通常想要将模型保存下来,然后放到独立的测试集上测试,下面介绍的是Pyt

[深度学习] Pytorch(三)—— 多/单GPU、CPU,训练保存、加载模型参数问题

[深度学习] Pytorch(三)-- 多/单GPU.CPU,训练保存.加载预测模型问题 上一篇实践学习中,遇到了在多/单个GPU.GPU与CPU的不同环境下训练保存.加载使用使用模型的问题,如果保存.加载的上述三类环境不同,加载时会出错.就去研究了一下,做了实验,得出以下结论: 多/单GPU训练保存模型参数.CPU加载使用模型 #保存 PATH = 'cifar_net.pth' torch.save(net.module.state_dict(), PATH) #加载 net = Net()

解析OBJ模型并将其加载到Unity3D场景中

??各位朋友,大家好,欢迎大家关注我的博客,我是秦元培,我的博客地址是http://qinyuanpei.com.今天想和大家交流的是解析obj模型并将其加载到Unity3D场景中,虽然我们知道Unity3D是可以直接导入OBJ模型的,可是有时候我们并不能保证我们目标客户知道如何使用Unity3D的这套制作流程,可能对方最终提供给我们的就是一个模型文件而已,所以这个在这里做这个尝试想想还是蛮有趣的呢,既然如此,我们就选择在所有3D模型格式中最为简单的OBJ模型来一起探讨这个问题吧! 关于OBJ模

sklearn训练模型的保存与加载

使用joblib模块保存于加载模型 在机器学习的过程中,我们会进行模型的训练,最常用的就是sklearn中的库,而对于训练好的模型,我们当然是要进行保存的,不然下次需要进行预测的时候就需要重新再进行训练.如果数据量小的话,那再重新进行训练是没有问题的,但是如果数据量大的话,再重新进行训练可能会花费很多开销,这个时候,保存好已经训练的模型就显得特别重要了.我们可以使用sklearn中的joblib模块进行保存与加载. from sklearn.externals import joblib # 保

转 tensorflow模型保存 与 加载

使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我们可能也需要用到别人训练好的模型,并在这个基础上再次训练.这时候我们需要掌握如何操作这些模型数据.看完本文,相信你一定会有收获! 1 Tensorflow模型文件 我们在checkpoint_dir目录下保存的文件结构如下: |--checkpoint_dir | |--checkpoint | |--MyModel.meta | |--MyModel.data-00000-of-00001 | |--MyModel.in

TensorFlow的模型保存与加载

import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf #tensorboard --logdir="./" def linearregression(): with tf.variable_scope("original_data"): X = tf.random_normal([100,1],mean=0.0,stddev=1.0) y_true = tf.matmul

机器学习之保存与加载.pickle模型文件

import pickle from sklearn.externals import joblib from sklearn.svm import SVC from sklearn import datasets #定义一个分类器 svm = SVC() iris = datasets.load_iris() X = iris.data y = iris.target #训练模型 svm.fit(X,y) #1.保存成Python支持的文件格式Pickle #在当前目录下可以看到svm.pic