Tensorflow Learning1 模型的保存和恢复

CKPT->pb

Demo

解析

tensor name 和 node name 的区别

Pb 的恢复

CKPT->pb

tensorflow的模型保存有两种形式:

1. ckpt:可以恢复图和变量,继续做训练

2. pb : 将图序列化,变量成为固定的值,,只可以做inference;不能继续训练

Demo

  1 def freeze_graph(input_checkpoint,output_graph):
  2
  3     ‘‘‘
  4     :param input_checkpoint:
  5     :param output_graph: PB模型保存路径
  6     :return
  7       void
  8     ‘‘‘
  9
 10     # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 11     # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 12
 13     # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 14     output_node_names = "InceptionV3/Logits/SpatialSqueeze" # 如果是多个输出节点,使用 ‘,’号隔开
 15
 16     ############################     Step1: 从ckpt中恢复图:     #############################################
 17     saver = tf.train.import_meta_graph(input_checkpoint + ‘.meta‘, clear_devices=True)
 18     graph = tf.get_default_graph() # 获得默认的图, 可以省略
 19     input_graph_def = graph.as_graph_def()  # 返回一个序列化的图代表当前的图,可以省略
 20
 21     with tf.Session() as sess: # 会使用默认的图 作为当前的图
 22         saver.restore(sess, input_checkpoint) #恢复图并得到数据
 23
 24         ########################     Step2: 创建持久化对象,指定sess,图、以及输出的序列化节点信息    ##############
 25         output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
 26             sess=sess,
 27             input_graph_def=input_graph_def,# 等于:sess.graph_def
 28             output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 29         #########################    Step3: 模型持久化   #######################################################
 30         with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
 31             f.write(output_graph_def.SerializeToString()) #序列化输出
 32         print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 33         # for op in graph.get_operations():
 34
 35         #     print(op.name, op.values())
 36
 37
 38 ########################### 调用方式 ################################
 39 # 输入ckpt模型路径
 40 input_checkpoint=‘models/model.ckpt-10000‘
 41 # 输出pb模型的路径
 42 out_pb_path="models/pb/frozen_model.pb"
 43 # 调用freeze_graph将ckpt转为pb
 44 freeze_graph(input_checkpoint,out_pb_path)

解析

函数freeze_graph中,最重要的就是要确定“指定输出的节点名称”,这个节点名称必须是原模型中存在的节点,对于freeze操作,我们需要定义输出结点的名字。

freeze的时候就只把输出该结点所需要的子图都固化下来,其他无关的就舍弃掉。因为我们freeze模型的目的是接下来做预测。所以,output_node_names一般是网络模型最后一层输出的节点名称,或者说就是我们预测的目标。

在保存pb的时候,通过convert_variables_to_constants函数来指定需要固化的节点名称;

tensor name 和 node name 的区别

node name 是 图 的节点,里面包含了很多操作和tensor

tensor 是 node 里面的一个组成部分;

以input 为例,“input:0”是张量的名称,而"input"表示的是节点的名称

PS:注意张量的名称,即为:节点名称+“:”+“id号”,如"input:0"

原文地址:https://www.cnblogs.com/greentomlee/p/11494383.html

时间: 2024-07-30 14:15:58

Tensorflow Learning1 模型的保存和恢复的相关文章

tensorflow 1.0 学习:模型的保存与恢复(Saver)

将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.Saver() 在创建这个Saver对象的时候,有一个参数我们经常会用到,就是 max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型.如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置

AI - TensorFlow - 示例05:保存和恢复模型

保存和恢复模型(Save and restore models) 官网示例:https://www.tensorflow.org/tutorials/keras/save_and_restore_models 在训练期间保存检查点 在训练期间或训练结束时自动保存检查点.权重存储在检查点格式的文件集合中,这些文件仅包含经过训练的权重(采用二进制格式).可以使用经过训练的模型,而无需重新训练该模型,或从上次暂停的地方继续训练,以防训练过程中断 检查点回调用法:创建检查点回调,训练模型并将ModelC

tensorflow模型的保存与恢复

模型保存后产生四个文件,分别是: |--models| |--checkpoint| |--.meta| |--.data| |--.index .meta保存的是图的结构 checkpoint文件是个文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表. .data和.index保存的是变量值. tensorflow常用的模型保存方法: best_str = '' if best_loss is None or valid_loss < best_los

tensorflow 之模型的保存与加载(一)

怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. 1 #!/usr/bin/env python3 2 #-*- coding:utf-8 -*- 3 ############################ 4 #File Name: saver.py 5 #Brief: 6 #Author: frank 7 #Mail: [email protected] 8 #Created Time:2018-06-22 22:12:52 9 ###

tensorflow 之模型的保存与加载(三)

前面的两篇博文 第一篇:简单的模型保存和加载,会包含所有的信息:神经网络的op,node,args等; 第二篇:选择性的进行模型参数的保存与加载. 本篇介绍,只保存和加载神经网络的计算图,即前向传播的过程. #!/usr/bin/env python3 #-*- coding:utf-8 -*- ############################ #File Name: save_restore.py #Brief: #Author: frank #Mail: [email protect

tf.train.Saver()-tensorflow中模型的保存及读取

作用:训练网络之后保存训练好的模型,以及在程序中读取已保存好的模型 使用步骤: 实例化一个Saver对象 saver = tf.train.Saver() 在训练过程中,定期调用saver.save方法,像文件夹中写入包含当前模型中所有可训练变量的checkpoint文件 saver.save(sess,FLAGG.train_dir,global_step=step) 之后可以使用saver.restore()方法,重载模型的参数,继续训练或者用于测试数据 saver.restore(sess

tensorflow模型的保存与加载

模型的保存与加载一般有三种模式:save/load weights(最干净.最轻量级的方式,只保存网络参数,不保存网络状态),save/load entire model(最简单粗暴的方式,把网络所有的状态都保存起来),saved_model(更通用的方式,以固定模型格式保存,该格式是各种语言通用的) 具体使用方法如下: # 保存模型 model.save_weights('./checkpoints/my_checkpoint') # 加载模型 model = keras.create_mod

(sklearn)机器学习模型的保存与加载

需求: 一直写的代码都是从加载数据,模型训练,模型预测,模型评估走出来的,但是实际业务线上咱们肯定不能每次都来训练模型,而是应该将训练好的模型保存下来 ,如果有新数据直接套用模型就行了吧?现在问题就是怎么在实际业务中保存模型,不至于每次都来训练,在预测. 解决方案: 机器学习-训练模型的保存与恢复(sklearn)python /模型持久化 /模型保存 /joblib /模型恢复在做模型训练的时候,尤其是在训练集上做交叉验证,通常想要将模型保存下来,然后放到独立的测试集上测试,下面介绍的是Pyt

tensorflow机器学习模型的跨平台上线

在用PMML实现机器学习模型的跨平台上线中,我们讨论了使用PMML文件来实现跨平台模型上线的方法,这个方法当然也适用于tensorflow生成的模型,但是由于tensorflow模型往往较大,使用无法优化的PMML文件大多数时候很笨拙,因此本文我们专门讨论下tensorflow机器学习模型的跨平台上线的方法. 1. tensorflow模型的跨平台上线的备选方案 tensorflow模型的跨平台上线的备选方案一般有三种:即PMML方式,tensorflow serving方式,以及跨语言API方