主要了解下freeze_graph的用法
以及了解下freeze_graph_test的一些相关知识(据说具有很好的学习价值)
freeze_graph.py源码链接:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py
freeze_graph_test.py源码链接:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph_test.py
tf模型的基本介绍
Tensorflow所有的文档格式都是基于Protocol Buffer,即protobuf
在文本文档中定义数据结构,protobuf工具生成C、Python和其他语言的类,这些类可以友好的加载、保存和方位数据
Tensorflow里的计算基础是Graph对象
它可以存储网络节点,每一个节点代表一个操作,并作为输入和输出相互链接在一起
GraphDef 类是ProtoBuf根据
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto
以此为基础定义创建的对象。Protobuf工具会解析此文本文档,并生成用户加载、存储和操控图定义的代码。
将文档加载到 grapf_def 变量中,就可以访问其中的数据
可以使用下面的代码来遍历这些节点,基本上重要的部分都是存储在节点中了。
12 |
graph_def = graph_pb2.GraphDef()for node in graph_def.node |
每一个节点都是一个在node_def.proto定义的NodeDef对象,这些节点是Tensorflow图的基本构建块,每一个构建块都定义了一个操作以及其输入连接。NodeDef的成员如下所示:
- name 节点的唯一标识符,该标识符不会被途中的任何其他节点使用
- op 定义了要运行的操作,比如Add、MatMul、Conv2D
- input 字符列表表,每个字符串都是另一个节点的名称,比如两个输入[“input_1:0”,”input_2:0”]
- device 定义了在分布式环境中运行的位置
- attr 包含某个节点的所有属性的键值对存储区
以上成员都可以通过 node.name node.op等来访问
因为tf在训练期间权重通常不会存储在文档格式内,而是保存在单独的检查点中,并且图中的Variable操作可在初始化操作时加载最新的值。
在部署到生产环境时,使用单独的文档往往不是很方便,因此我们需要一个脚本 freeze_graph.py
将这些检查点、文档冻结到一个文档中。
具体操作就是加载GraphDef,从最新的检查点文档中提取所有变量的值,然后将每个Variable操作替换为Const(其中包含存储在其属性中的权重的数值数据)。然后,它会剥离所有未用于前向推断的无关节点,并将生成的GraphDef保存到输出文档中。
freeze_graph.py
先了解下参数:
- input_graph 模型文档,二进制pb或者文本meta,用input_binary来区分
- input_saver 需要加载的Tensorflow saver文档
- input_checkpoint 检查点文档,用于模型恢复变量值
- checkpoint_version 变量文档的格式 (saver_pb2.SaverDef.V1 or saver_pb2.SaverDef.V2)
- output_graph 冻结完成后的写入路径
- input_binary 输入文档是否是二进制 True Or False
- output_node_names 输出节点的名字,多个节点用逗号分隔
- restore_op_name 已废弃
- filename_tensor_name 已废弃
- clear_devices 默认是True,是否清楚训练节点的设备
- initializer_nodes 需要初始化的节点
- variable_names_whitelist 指定需要恢复的变量
- variable_names_blacklist 指定不用恢复的变量
- input_meta_graph 需要加载的MetaGraphDef
- input_saved_model_dir SavedModel文档和变量的路径
- saved_model_tags 加载MetaGraphDef中的tag组,逗号分隔(MetaGraphDef中可以用tags来区分不同的计算图)
首先解析checkpoint版本:
1234567 |
if flags.checkpoint_version == 1: checkpoint_version = saver_pb2.SaverDef.V1elif flags.checkpoint_version == 2: checkpoint_version = saver_pb2.SaverDef.V2else: raise ValueError("Invalid checkpoint version (must be ‘1‘ or ‘2‘): %d" % flags.checkpoint_version) |
两种checkpoint的保存方法如下:
v1 | v2 |
---|---|
model.ckpt-0001 | model.ckpt-0001.index |
model.ckpt-0001.meta | model.ckpt-0001.meta |
model.ckpt-0001.data-00000-of-00001 |
然后解析输入的graphDef:
continue…
参考:
https://www.tensorflow.org/guide/extend/model_files#freezing
https://blog.csdn.net/czq7511/article/details/72452985
原文链接 大专栏 https://www.dazhuanlan.com/2019/08/24/5d612ad619bfd/
原文地址:https://www.cnblogs.com/chinatrump/p/11415213.html