两种从 TensorFlow 的 checkpoint生成 frozenpb 的方法

1. 从 ckpt-.data,ckpt-.index 和 .meta 生成 frozenpb

import os
import tensorflow as tf
from tensorflow.python.framework import graph_util

def freeze_graph(input_checkpoint,output_graph):
    '''
    :param input_checkpoint:
    :param output_graph: PB模型保存路径
    :return:
    '''
    # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
    output_node_names = "outputs"
    saver = tf.train.import_meta_graph(os.path.join(os.path.split(input_checkpoint)[0], 'graph.meta'), clear_devices=True)

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint) #恢复图并得到数据
        output_graph_def = graph_util.convert_variables_to_constants(
            # 模型持久化,将变量值固定
            sess=sess,
            input_graph_def=sess.graph_def,# 等于:sess.graph_def
            output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开

        with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
            f.write(output_graph_def.SerializeToString()) #序列化输出
        print("%d ops in the final graph." % len(output_graph_def.node))
        #得到当前图有几个操作节点

if __name__ == "__main__":
    # 输入ckpt模型路径
    input_checkpoint='ckpt_path/ckpt-10000'
    # 输出pb模型的路径
    out_pb_path="some_path/frozen_model.pb"
    # 调用freeze_graph将ckpt转为pb
    freeze_graph(input_checkpoint,out_pb_path)

2. 从网络代码和 ckpt-.data 文件生成 frozenpb

import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph

import network  # 导入网络结构

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # 设置GPU
model_path = "ckpt_path/ckpt-10000"

def main():
    tf.reset_default_graph()
    input_node = tf.placeholder(
        tf.float32, shape=(None,112, 96, 3)
    )
    input_node = tf.identity(input_node,name="inputs") # 设置输入节点的名字,这里可以自定义名称
    flow = network(input_node)
    flow = tf.identity(flow, name="outs") # 设置输出类型以及输出的接口名字,为了之后的调用pb的时候使用
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, model_path)
        # 保存图
        tf.train.write_graph(sess.graph_def, "logdir/", "graph.pb")
        # 把图和参数结构一起
        freeze_graph.freeze_graph(
            "logdir/graph.pb", # 上面保存的图结构 graph.pb
            "",
            False,
            model_path,
            "outs",
            "save/restore_all", # 默认恢复所有
            "save/Const:0", # 默认常量
            "some_path/frozen.pb", # 保存frozen.pb
            False,
            "",
        )
    print("done")

if __name__ == "__main__":
    main()

3. 打印 网络中节点的名字

import tensorflow as tf

if __name__ == "__main__":
    checkpoint_path = '../model_fintune/ckpt-1400'
    reader = tf.train.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()  

    for key in var_to_shape_map:
        print("tensor name: ", key)
        # print(reader.get_tensor(key))

或者通过

import tensorflow as tf

def printTensors(pb_file):

    # read pb into graph_def
    with tf.gfile.GFile(pb_file, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # import graph_def
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)

    # print operations
    for op in graph.get_operations():
        print(op.name)

printTensors("path-to-my-pbfile.pb")

4. 两种方法对比

如果是自己的代码训练的模型,有网络结构,有 ckpt 文件,最好是使用第二种方法,使用起来很灵活,可以进行各种自定义,比如修改输入输出的节点名字,网络有多个路径的时候可以自定义输出路径。第一种方法,应该也能达到第二种方法的效果,因为它们本来就是等价的,可能会有些麻烦。第一种方法的好处就是快,不要去翻那些杂糅在一起的网络结构。

原文地址:https://www.cnblogs.com/willwell/p/12196101.html

时间: 2024-08-28 19:41:54

两种从 TensorFlow 的 checkpoint生成 frozenpb 的方法的相关文章

JAVA 中两种判断输入的是否是数字的方法__正则化_

JAVA 中两种判断输入的是否是数字的方法 package t0806; import java.io.*; import java.util.regex.*; public class zhengzehua_test { /** * @param args */ public static void main(String[] args) { // TODO Auto-generated method stub try { System.out.println("请输入第一个数字:"

两种解决IE6不支持固定定位的方法

有两种让IE6支持position:fixed1.用CSS执行表达式 *{margin:0;padding:0;} * html,* html body{ background-image:url(about:blank); background-attachment:fixed; } * html .fixed{ position:absolute; bottom:auto; top:expression(eval(document.documentElement.scrollTop+ doc

Hadoop 两种环境下的checkpoint机制

伪分布式环境: HA环境checkpoint机制 配置了HA的HDFS中,有active和standby namenode两个namenode节点.他们的内存中保存了一样的集群元数据信息,因为standby namenode已经将集群状态存储在内存中了,所以创建检查点checkpoint的过程只需要从内存中生成新的fsimage. 详细过程如下: (standby namenode=SbNN, activenamenode=ANN) 1. SBNN查看是否满足创建检查点的条件: (1) 距离上次

【Django】Django—Form两种解决表单数据无法动态刷新的方法

一.无法动态更新数据的实例 1. 如下,数据库中创建了班级表和教师表,两张表的对应关系为"多对多" 1 from django.db import models 2 3 4 class Classes(models.Model): 5 title = models.CharField(max_length=32) 6 7 8 class Teacher(models.Model): 9 name = models.CharField(max_length=32) 10 t2c = mo

两种unix网络编程线程池的设计方法

unp27章节中的27.12中,我们的子线程是通过操作共享任务缓冲区,得到task的,也就是通过线程间共享的clifd[]数组,这个数组其实就是我们的任务数组,得到其中的connfd资源. 我们对这个任务数组的操作,需要互斥量+条件变量达到同步的目的..每个线程是无规律的从clifd得到任务,然后执行的.任务和线程之间没有对应关系.线程完成本次任务之后,如果任务数组中任然有任务,则再次运行下一个任务. 而另外的一个线程池模型中,pthread_create (&temp[i].tid, NULL

php 中两种获得数据库中 数据条数的方法

一种是传统的利用mysql_num_rows()来计算 $sql="select * from news"; $res=mysql_query($sql); $number=mysql_num_rows($sql); 还有一种是利用mysql_result() $sql="select count(*) from news"; $res=mysql_query($sql); $number=mysql_result($res,0,0); mysql_result()

一道题采用两种设计模式:对比策略模式和模版方法

摘要 <C++ Primer>习题14.38和14.39分别采用策略模式.模版方法解决问题. 问题 <C++ Primer 5th>习题 14.38 : 编写一个类令其检查某个给定的 string 对象的长度是否与一个阀值相等.使用该对象编写程序,统计并报告输入的文件中长度为 1 的单词有多少个.长度为 2 的单词有多少个........长度为 10 的单词有多少个. <C++ Primer 5th>习题 14.39 : 修改上一题的程序令其报告长度在 1 至 9 之间

swift两种获取相册资源PHAsset的路径的方法(绝对路径)

方法中使用到的phasset就是我们取到的PHAsset对象 方法一: 1 let options = PHVideoRequestOptions() 2 3 options.version = PHVideoRequestOptionsVersion.current 4 5 options.deliveryMode = PHVideoRequestOptionsDeliveryMode.automatic 6 7 let manager =PHImageManager.default() 8

使用Python生成源文件的两种方法

利用Python的字符串处理模块,开发者能够编写脚本用来生成那些格式同样的C.C++.JAVA源程序.头文件和測试文件,从而避免大量的反复工作. 本文概述两种利用Python string类生成java源码的方法. 1.String Template Template是一个好东西,能够将字符串的格式固定下来,反复利用.Template也能够让开发者能够分别考虑字符串的格式和其内容了.无形中减轻了开发者的压力. Template属于string中的一个类,有两个重要的方法:substitute和s