tensorflow之freeze_gragh

主要了解下freeze_graph的用法

以及了解下freeze_graph_test的一些相关知识(据说具有很好的学习价值)

freeze_graph.py源码链接:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py

freeze_graph_test.py源码链接:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph_test.py

tf模型的基本介绍

Tensorflow所有的文档格式都是基于Protocol Buffer,即protobuf

在文本文档中定义数据结构,protobuf工具生成C、Python和其他语言的类,这些类可以友好的加载、保存和方位数据

Tensorflow里的计算基础是Graph对象

它可以存储网络节点,每一个节点代表一个操作,并作为输入和输出相互链接在一起

GraphDef 类是ProtoBuf根据

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto

以此为基础定义创建的对象。Protobuf工具会解析此文本文档,并生成用户加载、存储和操控图定义的代码。

将文档加载到 grapf_def 变量中,就可以访问其中的数据

可以使用下面的代码来遍历这些节点,基本上重要的部分都是存储在节点中了。

12
graph_def = graph_pb2.GraphDef()for node in graph_def.node

每一个节点都是一个在node_def.proto定义的NodeDef对象,这些节点是Tensorflow图的基本构建块,每一个构建块都定义了一个操作以及其输入连接。NodeDef的成员如下所示:

  • name 节点的唯一标识符,该标识符不会被途中的任何其他节点使用
  • op 定义了要运行的操作,比如Add、MatMul、Conv2D
  • input 字符列表表,每个字符串都是另一个节点的名称,比如两个输入[“input_1:0”,”input_2:0”]
  • device 定义了在分布式环境中运行的位置
  • attr 包含某个节点的所有属性的键值对存储区

以上成员都可以通过 node.name node.op等来访问

因为tf在训练期间权重通常不会存储在文档格式内,而是保存在单独的检查点中,并且图中的Variable操作可在初始化操作时加载最新的值。

在部署到生产环境时,使用单独的文档往往不是很方便,因此我们需要一个脚本 freeze_graph.py

将这些检查点、文档冻结到一个文档中。

具体操作就是加载GraphDef,从最新的检查点文档中提取所有变量的值,然后将每个Variable操作替换为Const(其中包含存储在其属性中的权重的数值数据)。然后,它会剥离所有未用于前向推断的无关节点,并将生成的GraphDef保存到输出文档中。

freeze_graph.py

先了解下参数:

  • input_graph 模型文档,二进制pb或者文本meta,用input_binary来区分
  • input_saver 需要加载的Tensorflow saver文档
  • input_checkpoint 检查点文档,用于模型恢复变量值
  • checkpoint_version 变量文档的格式 (saver_pb2.SaverDef.V1 or saver_pb2.SaverDef.V2)
  • output_graph 冻结完成后的写入路径
  • input_binary 输入文档是否是二进制 True Or False
  • output_node_names 输出节点的名字,多个节点用逗号分隔
  • restore_op_name 已废弃
  • filename_tensor_name 已废弃
  • clear_devices 默认是True,是否清楚训练节点的设备
  • initializer_nodes 需要初始化的节点
  • variable_names_whitelist 指定需要恢复的变量
  • variable_names_blacklist 指定不用恢复的变量
  • input_meta_graph 需要加载的MetaGraphDef
  • input_saved_model_dir SavedModel文档和变量的路径
  • saved_model_tags 加载MetaGraphDef中的tag组,逗号分隔(MetaGraphDef中可以用tags来区分不同的计算图)

首先解析checkpoint版本:

1234567
if flags.checkpoint_version == 1:	checkpoint_version = saver_pb2.SaverDef.V1elif flags.checkpoint_version == 2:  checkpoint_version = saver_pb2.SaverDef.V2else:  raise ValueError("Invalid checkpoint version (must be ‘1‘ or ‘2‘): %d" %                     flags.checkpoint_version)

两种checkpoint的保存方法如下:

v1 v2
model.ckpt-0001 model.ckpt-0001.index
model.ckpt-0001.meta model.ckpt-0001.meta
model.ckpt-0001.data-00000-of-00001

然后解析输入的graphDef:

continue…

参考:

https://www.tensorflow.org/guide/extend/model_files#freezing

https://blog.csdn.net/czq7511/article/details/72452985

原文链接 大专栏  https://www.dazhuanlan.com/2019/08/24/5d612ad619bfd/

原文地址:https://www.cnblogs.com/chinatrump/p/11415213.html

时间: 2024-08-30 10:48:13

tensorflow之freeze_gragh的相关文章

在Win10 Anaconda中安装Tensorflow

有需要的朋友可以参考一下 1.安装Anaconda 下载:https://www.continuum.io/downloads,我用的是Python 3.5 下载完以后,安装. 安装完以后,打开Anaconda Prompt,输入清华的仓库镜像,更新包更快: conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ conda config --set show_channel_url

Tensorflow 梯度下降实例

# coding: utf-8 # #### 假设我们要最小化函数 $y=x^2$, 选择初始点 $x_0=5$ # #### 1. 学习率为1的时候,x在5和-5之间震荡. # In[1]: import tensorflow as tf TRAINING_STEPS = 10 LEARNING_RATE = 1 x = tf.Variable(tf.constant(5, dtype=tf.float32), name="x") y = tf.square(x) train_op

Ubuntu16.04安装tensorflow+安装opencv+安装openslide+安装搜狗输入法

Ubuntu16.04在cuda以及cudnn安装好之后,安装tensorflow,tensorflow以及opencv可以到网上下载对应的安装包并且直接在安装包所在的路径下直接通过pip与conda进行安装,如下图所示: 前提是要下载好安装包.安装好tensorflow之后还需要进行在~/.bashrc文件中添加系统路径,如下图所示 Openslide是医学图像一个重要的库,这里给出三条命令进行安装 sudo apt-get install openslide-tools sudo apt-g

【tensorflow:Google】三、tensorflow入门

[一]计算图模型 节点是计算,边是数据流, a = tf.constant( [1., 2.] )定义的是节点,节点有属性 a.graph 取得默认计算图 g1 = tf.get_default_graph() 初始化计算图 g1 = tf.Graph() 设置default图 g1.as_default() 定义变量: tf.get_variable('v') 读取变量也是上述函数 对图指定设备 g.device('/gpu:0') 可以定义集合来管理计算图中的资源, 加入集合 tf.add_

TensorFlow之tf.unstack学习循环神经网络中用到!

unstack( value, num=None, axis=0, name='unstack' ) tf.unstack() 将给定的R维张量拆分成R-1维张量 将value根据axis分解成num个张量,返回的值是list类型,如果没有指定num则根据axis推断出! DEMO: import tensorflow as tf a = tf.constant([3,2,4,5,6]) b = tf.constant([1,6,7,8,0]) c = tf.stack([a,b],axis=0

TensorFlow conv2d实现卷积

tf.nn.conv2d是TensorFlow里面实现卷积的函数,参考文档对它的介绍并不是很详细,实际上这是搭建卷积神经网络比较核心的一个方法,非常重要 tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None) 除去name参数用以指定该操作的name,与方法有关的一共五个参数: 第一个参数input:指需要做卷积的输入图像,它要求是一个Tensor,具有[batch, in_height, i

Tensorflow一些常用基本概念与函数(四)

摘要:本系列主要对tf的一些常用概念与方法进行描述.本文主要针对tensorflow的模型训练Training与测试Testing等相关函数进行讲解.为'Tensorflow一些常用基本概念与函数'系列之四. 1.序言 本文所讲的内容主要为以下列表中相关函数.函数training()通过梯度下降法为最小化损失函数增加了相关的优化操作,在训练过程中,先实例化一个优化函数,比如 tf.train.GradientDescentOptimizer,并基于一定的学习率进行梯度优化训练: optimize

Tensorflow一些常用基本概念与函数(三)

摘要:本系列主要对tf的一些常用概念与方法进行描述.本文主要针对tensorflow的数据IO.图的运行等相关函数进行讲解.为'Tensorflow一些常用基本概念与函数'系列之三. 1.序言 本文所讲的内容主要为以下相关函数: 操作组 操作 Data IO (Python functions) TFRecordWrite,rtf_record_iterator Running Graphs Session management,Error classes 2.tf函数 2.1 数据IO {Da

TensorFlow【机器学习】:如何正确的掌握Google深度学习框架TensorFlow(第二代分布式机器学习系统)?

本文标签:   机器学习 TensorFlow Google深度学习框架 分布式机器学习 唐源 VGG REST   服务器 自 2015 年底开源到如今更快.更灵活.更方便的 1.0 版本正式发布,由 Google 推出的第二代分布式机器学习系统 TensorFlow一直在为我们带来惊喜,一方面是技术层面持续的迭代演进,从分布式版本.服务框架 TensorFlow Serving.上层封装 TF.Learn 到 Windows 支持.JIT 编译器 XLA.动态计算图框架 Fold 等,以及