tensorflow 将训练模型保存为pd文件

前言

保存 模型有2种方法:

方法

1.使用TensorFlow模型保存函数

   save = tf.train.Saver()
   ......
   saver.save(sess,"checkpoint/model.ckpt",global_step=step)*

得到3个结果

model.ckpt-129220.data-00000-of-00001#保存了模型的所有变量的值。
model.ckpt-129220.index
model.ckpt-129220.meta  # 保存了graph结构,包括GraphDef, SaverDef等。存在时,可以不在文件中定义模型,也可以运行

再将这3个文件保存为.pd文件


import tensorflow as tf
import deeplab_model

def export_graph(model, checkpoint_dir, model_name):
    ...
    model: the defined model
    checkpoint_dir: the dir of three files
    model_name: the name of .pb
    ...
    graph = tf.Graph()
    with graph.as_default():
        ### 输入占位符
        input_img = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_image')
        labels = tf.zeros([1, 512, 512,1])
        labels = tf.to_int32(tf.image.convert_image_dtype(labels, dtype=tf.uint8))
        ### 需要输出的Tensor
        output = model.deeplabv3_plus_model_fn(
                    input_img,
                    labels,
                    tf.estimator.ModeKeys.EVAL,
                    params={
                        'output_stride': 16,
                        'batch_size': 1,  # Batch size must be 1 because the images' size may differ
                        'base_architecture': 'resnet_v2_50',
                        'pre_trained_model': None,
                        'batch_norm_decay': None,
                        'num_classes': 2,
                        'freeze_batch_norm': True
                    }).predictions['classes']
        ### 给输出的tensor命名
        output = tf.identity(output, name='output_label')
        restore_saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        ### 初始化变量
        sess.run(tf.global_variables_initializer())
        ### load the model
        restore_saver.restore(sess, checkpoint_dir)

        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), [output.op.name])
        ### 将图写成.pb文件
        tf.train.write_graph(output_graph_def, 'pretrained', model_name, as_text=False)

### 调用函数,生成.pd文件
export_graph(deeplab_model, 'model/model.ckpt-133958', 'model.pd')

### 读取

import tensorflow as tf
import os

def inference():
    with tf.gfile.FastGFile('pretrained/model.pd', 'rb') as model_file:
        graph = tf.Graph()
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_file.read())
        [output_image] = tf.import_graph_def(graph_def,
                          input_map={'input_image': images},
                          return_elements=['output_label:0'],
                          name='output')
        sess = tf.Session()
        label = sess.run(output_image)
        return label
labels = inference()

2.直接保存

import tensorflow as tf
from tensorflow.python.framework import graph_util
var1 = tf.Variable(1.0, dtype=tf.float32, name='v1')
var2 = tf.Variable(2.0, dtype=tf.float32, name='v2')
var3 = tf.Variable(2.0, dtype=tf.float32, name='v3')
x = tf.placeholder(dtype=tf.float32, shape=None, name='x')
x2 = tf.placeholder(dtype=tf.float32, shape=None, name='x2')
addop = tf.add(x, x2, name='add')
addop2 = tf.add(var1, var2, name='add2')
addop3 = tf.add(var3, var2, name='add3')
initop = tf.global_variables_initializer()
model_path = './Test/model.pb'
with tf.Session() as sess:
    sess.run(initop)
    print(sess.run(addop, feed_dict={x: 12, x2: 23}))
    output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['add', 'add2', 'add3'])
    # 将计算图写入到模型文件中
    model_f = tf.gfile.FastGFile(model_path, mode="wb")
    model_f.write(output_graph_def.SerializeToString())

####读取代码:
import tensorflow as tf
with tf.Session() as sess:
    model_f = tf.gfile.FastGFile("./Test/model.pb", mode='rb')
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(model_f.read())
    c = tf.import_graph_def(graph_def, return_elements=["add2:0"])
    c2 = tf.import_graph_def(graph_def, return_elements=["add3:0"])
    x, x2, c3 = tf.import_graph_def(graph_def, return_elements=["x:0", "x2:0", "add:0"])

    print(sess.run(c))
    print(sess.run(c2))
    print(sess.run(c3, feed_dict={x: 23, x2: 2}))

原文地址:https://www.cnblogs.com/schips/p/12148020.html

时间: 2024-08-13 08:44:34

tensorflow 将训练模型保存为pd文件的相关文章

C#代码实现把网页文件保存为mht文件

MHT叫“web单一文件”.顾名思义,就是把网页中包含得图片,CSS文件以及HTML文件全部放到一个MHT文件里面.而且浏览器可以直接读取得. 由于项目需要,需实现把指定的网页文件保存为mht文件.于是到网上搜索了相关的资料.找到了一份代码.测试后通过. 现将实现过程记录如下: Step 1:项目引用文件: 安装目录/System32/cdosys.dll(c:/windows/System32/cdosys.dll),这样,将增加两个命名空间:ADODB, CDO. Step 2:放一个按钮b

将整个html内容保存到指定文件

package parser; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.FileWriter; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.net.MalformedURLException; import java.net.

log4j实现每个线程保存一个日志文件

log4j.properties: ### direct log messages to stdout ### log4j.appender.stdout=org.apache.log4j.ConsoleAppender log4j.appender.stdout.Target=System.out log4j.appender.stdout.layout=org.apache.log4j.PatternLayout log4j.appender.stdout.layout.Conversion

android如何保存读取读取文件文件保存到SDcard

android如何保存读取读取文件文件保存到SDcard 19. 三 / android基础 / 没有评论 本文来源于www.ifyao.com禁止转载!www.ifyao.com 上图为保存文件的方法体. 上图为如何调用方法体保存数据. 上面的截图介绍了,文件保存的基本内容. 路径也可以更改. 将内容保存到文件介绍完毕. 本文来源于www.ifyao.com禁止转载!www.ifyao.com 读取文件方法体,将方法返回值传给控件即可. 保存文件的四种操作模式 将文件保存到手机的SDcard路

Atitit.软件开发概念说明--io系统区--特殊文件名称保存最佳实践文件名称编码...filenameEncode

不个网页title保存成个个文件的时候儿有无效字符的问题... 通常两个处理方式::: replace 成个空格或者使用转义(推荐)... windows的文件名称无效字符.../\:* <>\"| 斜杠,反斜杠,冒号,星号,问号,左右的 尖括号,双引号,树杠...而且..."." 一个点和 ".."双点分别用来表示"当前目录"和"父目录", 因此它们也不能作为文件名 Linux的基本上只有反斜杠...

5.27 按步就搬 Editor如何保存和打开文件

http://www.benisoft.net/day10/index.html Eclipse通过文件后缀名来决定该文件该用哪个Editor打开,在实现org.eclipse.ui.editors 扩展点时,指定extensions为iti,这样,Eclipse碰到以.iti为文件后缀的文件,就会调用ItineraryEditor打开. Eclipse首先调用ItineraryEditor.init(...)方法.这个方法的实现一般都会调用基类的init(...)方法 来保存site.site

[ATL/WTL]_[中级]_[保存CBitmap到文件-保存屏幕内容到文件]

场景: 1. 在做图片处理时,比方放大后或加特效后须要保存CBitmap(HBITMAP)到文件. 2.截取屏幕内容到文件时. 3.不须要增加第3方库时. 说明: 这段代码部分来自网上.第一次学atl/wtl.gdi不是非常熟悉.以后转换为wtl版本号吧. 当然wtl项目直接用也没问题. 如今想想wxWidgets的wxImage类对这类操作方便多了.仅仅须要调用一个SaveFile方法. 保存HBITMAP到文件: static bool SaveBitmapToFile(CBitmap& b

用python+selenium抓取知乎今日最热和本月最热的前三个问题及每个问题的首个回答并保存至html文件

抓取知乎今日最热和本月最热的前三个问题及每个问题的首个回答,保存至html文件,该html文件的文件名应该是20160228_zhihu_today_hot.html,也就是日期+zhihu_today_hot.html 代码如下: from selenium import webdriver from time import sleep import time class ZhiHu():    def __init__(self):       self.dr = webdriver.Chr

java io流 运行错误时,保存异常到文件里面

java io流 运行错误时,保存异常到文件里面 下面这个实例,运行后,输入数字,为正确,如果输入字符串,则报错,保存错误信息 //运行错误时,保存异常到文件里面 //下面这个实例,运行后,输入数字,为正确,如果输入字符串,则报错,保存错误信息 import java.io.*; import java.util.*; public class Index{ public static void main(String[] args) throws Exception{ try{ //创建文件