第五章 MNIST数字识别问题(二)

4.1. ckpt文件保存方法

在对模型进行加载时候,需要定义出与原来的计算图结构完全相同的计算图,然后才能进行加载,并且不需要对定义出来的计算图进行初始化操作。 
这样保存下来的模型,会在其文件夹下生成三个文件,分别是: 
* .ckpt.meta文件,保存tensorflow模型的计算图结构。 
* .ckpt文件,保存计算图下所有变量的取值。 
* checkpoint文件,保存目录下所有模型文件列表。

import tensorflow as tf
#保存计算两个变量和的模型
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2

init_op = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, "Saved_model/model.ckpt")
#加载保存了两个变量和的模型
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print sess.run(result)

INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt
[-1.6226364]
#直接加载持久化的图。因为之前没有导出v3,所以这里会报错
saver = tf.train.import_meta_graph("Saved_model/model.ckpt.meta")
v3 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))

with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print sess.run(v1)
    print sess.run(v2)
    print sess.run(v3)
INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt
[-0.81131822]
[-0.81131822]

# 变量重命名,这样可以通过字典将模型保存时的变量名和需要加载的变量联系起来
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "other-v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "other-v2")
saver = tf.train.Saver({"v1": v1, "v2": v2})

4.2.1 滑动平均类的保存

import tensorflow as tf
#使用滑动平均
v = tf.Variable(0, dtype=tf.float32, name="v")
for variables in tf.global_variables(): print variables.name

ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for variables in tf.global_variables(): print variables.name
v:0
v:0
v/ExponentialMovingAverage:0

#保存滑动平均模型
saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    sess.run(tf.assign(v, 10))
    sess.run(maintain_averages_op)
    # 保存的时候会将v:0  v/ExponentialMovingAverage:0这两个变量都存下来。
    saver.save(sess, "Saved_model/model2.ckpt")
    print sess.run([v, ema.average(v)])
10.0, 0.099999905]

#加载滑动平均模型
v = tf.Variable(0, dtype=tf.float32, name="v")

# 通过变量重命名将原来变量v的滑动平均值直接赋值给v。
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model2.ckpt")
    print sess.run(v)
INFO:tensorflow:Restoring parameters from Saved_model/model2.ckpt
0.0999999

4.2.2 variables_to_restore函数的使用样例

import tensorflow as tf
v = tf.Variable(0, dtype=tf.float32, name="v")
ema = tf.train.ExponentialMovingAverage(0.99)
print ema.variables_to_restore()

#等同于saver = tf.train.Saver(ema.variables_to_restore())
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model2.ckpt")
    print sess.run(v)
{u‘v/ExponentialMovingAverage‘: <tf.Variable ‘v:0‘ shape=() dtype=float32_ref>}

4.3. pb文件保存方法

#pb文件的保存方法
import tensorflow as tf
from tensorflow.python.framework import graph_util

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
result = v1 + v2

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    graph_def = tf.get_default_graph().as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, [‘add‘])
    with tf.gfile.GFile("Saved_model/combined_model.pb", "wb") as f:
           f.write(output_graph_def.SerializeToString())

INFO:tensorflow:Froze 2 variables.
Converted 2 variables to const ops.
------------------------------------------------------------------------
#加载pb文件
from tensorflow.python.platform import gfile
with tf.Session() as sess:
    model_filename = "Saved_model/combined_model.pb"

    with gfile.FastGFile(model_filename, ‘rb‘) as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    result = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print sess.run(result)

[array([ 3.], dtype=float32)]

张量的名称后面有:0,表示是某个计算节点的第一个输出,而计算节点本身的名称后是没有:0的。

原文地址:https://www.cnblogs.com/exciting/p/8542859.html

时间: 2024-10-01 11:27:43

第五章 MNIST数字识别问题(二)的相关文章

Thinking In Java笔记(第五章 初始化与清理(二))

第五章 初始化与清理(二) 5.5 清理:终结处理和垃圾回收 清理的工作常常被忽略,Java有垃圾回收器负责回收无用对象占据的内存资源.但也有特殊情况:假定对象(并非使用new)获得了一块"特殊"的内存区域,由于垃圾回收器只知道释放那些由new分配的内存,所以不知道如何释放特殊内存.Java允许在类中定义一个名为finalize()的方法,工作原理"假定"是这样的:一旦垃圾回收器准备好释放对象占用的存储空间,首先调用其finalize()方法,并且在下一次垃圾回收动

《Python核心编程》第五章:数字

本章大纲 介绍Python支持的多种数字类型,包括:整型.长整型.布尔型.双精度浮点型.十进制浮点型和复数.介绍和数字相关的运算符和函数. 知识点 5.1 布尔型 从Python2.3开始支持bool,取值范围:True.False 5.2 标准整型 在32位机器上,标准整数类型的取值范围:-2的31次方 ~ 2的31次方-1 - Python标准整数类型等价于C语言的(有符号)长整型. - 八进制整数以数字 "0" 开头,十六进制整数以 "0x" 或 "

TensorFlow深度学习实战---MNIST数字识别问题

1.滑动平均模型: 用途:用于控制变量的更新幅度,使得模型在训练初期参数更新较快,在接近最优值处参数更新较慢,幅度较小 方式:主要通过不断更新衰减率来控制变量的更新幅度. 衰减率计算公式 : decay = min{init_decay , (1 + num_update) / (10 + num_update)} 其中 init_decay 为设置的初始衰减率 ,num_update 为模型参数更新次数,由此可见,随着 num_update 更新次数的增加,(1 + num_update) /

【ALearning】第五章 Android相关组件介绍(二)Service

Service是Android中的四大组件之一,所以在Android开发过程中起到非常重要的作用.下面我们来看一下官方对Service的定义. A Service is an application component thatcan perform long-running operations in the background and does not provide auser interface. Another application component can start a se

Python自学:第五章 对数字列表执行简单的统计计算

>>>digits = [1,2,3,4,5,6,7,8,9,0] >>>mid(digits) 0 >>>max(digits) 9 >>>sum(digits) 45 原文地址:https://www.cnblogs.com/zhouxiin/p/10851204.html

机器学习框架Tensorflow数字识别MNIST

SoftMax回归  http://ufldl.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92 我们的训练集由  个已标记的样本构成: ,其中输入特征.(我们对符号的约定如下:特征向量  的维度为 ,其中  对应截距项 .) 由于 logistic 回归是针对二分类问题的,因此类标记 .假设函数(hypothesis function) 如下: 我们将训练模型参数 ,使其能够最小化代价函数 : 在 softmax回归中,我们解决的是多分

数据结构期末复习第五章数组和广义表

数据结构期末复习第五章 数组和广义表 二维数组A[m][n]按行优先 寻址计算方法,每个数组元素占据d 个地址单元.     设数组的基址为LOC(a11) :LOC(aij)=LOC(a11)+((i-1)*n+j-1)*d     设数组的基址为LOC(a00) :LOC(aij)=LOC(a00)+( i*n+j )*d    二维数组A[m][n]按列优先 寻址计算方法,每个数组元素占据d 个地址单元.     设数组的基址为LOC(a11) :LOC(aij)=LOC(a11)+((j

实战Google深度学习框架-C5-MNIST数字识别问题

5.1 MNIST数据处理 MNIST是NIST数据集的一个子集,包含60000张图片作为训练数据,10000张作为测试数据,其中每张图片代表0~9中的一个数字,图片大小为28*28(可以用一个28*28矩阵表示) 为了清楚表示,用下图14*14矩阵表示了,其实应该是28*28矩阵 TF提供了一个类来处理MNIST数据: 准备工作:桌面新建MNIST数字识别->cd MNIST数字识别->shift + 右键->在此处新建命令窗口->jupyter notebook->新建g

tensorflow 基础学习五:MNIST手写数字识别

MNIST数据集介绍: from tensorflow.examples.tutorials.mnist import input_data # 载入MNIST数据集,如果指定地址下没有已经下载好的数据,tensorflow会自动下载数据 mnist=input_data.read_data_sets('.',one_hot=True) # 打印 Training data size:55000. print("Training data size: {}".format(mnist.