tf.train.Saver()-tensorflow中模型的保存及读取

作用:训练网络之后保存训练好的模型,以及在程序中读取已保存好的模型

使用步骤:

  • 实例化一个Saver对象 saver = tf.train.Saver()
  • 在训练过程中,定期调用saver.save方法,像文件夹中写入包含当前模型中所有可训练变量的checkpoint文件 saver.save(sess,FLAGG.train_dir,global_step=step)
  • 之后可以使用saver.restore()方法,重载模型的参数,继续训练或者用于测试数据 saver.restore(sess,FLAGG.train_dir)

在save之后会在相应的路径下面新增如下四个红色文件

在saver实例每次调用save方法时,都会创建三个数据文件和一个检查点(checkpoint)文件,权重等参数被以字典的形式保存到.ckpt.data中,图和元数据被保存到.ckpt.meta中,可以被tf.train.import_meta_graph加载到当前默认的图

softmaxRegression.py

 1 # _*_ coding:utf-8 _*_
 2 import os
 3 os.environ[‘TF_CPP_MIN_LOG_LEVEL‘] = ‘2‘
 4 import tensorflow as tf
 5 from tensorflow.examples.tutorials.mnist import input_data
 6
 7 #get the datase
 8 mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
 9
10 print(mnist.train.images.shape,mnist.train.labels.shape)
11
12 sess = tf.InteractiveSession()
13
14 x = tf.placeholder(tf.float32,[None,784])
15 W = tf.Variable(tf.zeros([784,10]))
16 b = tf.Variable(tf.zeros([10]))
17
18 y = tf.nn.softmax(tf.matmul(x,W)+b)
19 y_ = tf.placeholder(tf.float32,[None,10])
20 cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))
21
22 train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
23 tf.global_variables_initializer().run()
24
25 #保存模型对象saver
26 saver = tf.train.Saver()
27
28 #判断保存模型对象文件夹是否存在
29 if not os.path.exists(‘tmp/‘):
30     print(‘i am here‘)
31     os.mkdir(‘tmp/‘)
32 else:
33     print("2")
34
35
36 if os.path.exists(‘tmp/chckpoint‘):
37     saver.restore(sess,‘tmp/model.ckpt‘)
38     correct_prediction = tf.equal(tf.arg_max(y, 1), tf.argmax(y_, 1))
39     accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
40     save_path = saver.save(sess, ‘tmp/model.ckpt‘)
41     print("2")
42     print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))
43 else:
44     for i in range(1000):
45         batch_xs,batch_ys = mnist.train.next_batch(100)
46         train_step.run({x:batch_xs,y_:batch_ys})
47     correct_prediction = tf.equal(tf.arg_max(y,1),tf.argmax(y_,1))
48     accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
49     save_path = saver.save(sess,‘tmp/model.ckpt‘)
50     print(accuracy.eval({x:mnist.test.images,y_:mnist.test.labels}))

原文地址:https://www.cnblogs.com/bevishe/p/10359993.html

时间: 2024-11-09 13:45:46

tf.train.Saver()-tensorflow中模型的保存及读取的相关文章

机器学习与Tensorflow(7)——tf.train.Saver()、inception-v3的应用

1. tf.train.Saver() tf.train.Saver()是一个类,提供了变量.模型(也称图Graph)的保存和恢复模型方法. TensorFlow是通过构造Graph的方式进行深度学习,任何操作(如卷积.池化等)都需要operator,保存和恢复操作也不例外. 在tf.train.Saver()类初始化时,用于保存和恢复的save和restore operator会被加入Graph.所以,下列类初始化操作应在搭建Graph时完成. saver = tf.train.Saver()

tf.train.Saver()

1. 实例化对象 saver = tf.train.Saver(max_to_keep=1) max_to_keep: 表明保存的最大checkpoint文件数.当一个新文件创建的时候,旧文件就会被删掉.如果值为None或0, 表示保存所有的checkpoint文件.默认值5(也就是说,保存最近的5个checkpoint文件). keep_checkpoint_every_n_hour: 除了保存最近的max_to_keep_checkpoint文件,你还可能想每训练N小时保存一个checkpo

跟我学算法- tensorflow模型的保存与读取 tf.train.Saver()

save =  tf.train.Saver() 通过save. save() 实现数据的加载 通过save.restore() 实现数据的导出 第一步: 数据的载入 import tensorflow as tf #创建变量 v1 = tf.Variable(tf.random_normal([1, 2], name='v1')) v2 = tf.Variable(tf.random_normal([2, 3], name='v2')) #初始化变量 init_op = tf.global_v

TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model

TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model Checkmate is designed to be a simple drop-in solution for a very common Tensorflow use-case: keeping track of the best model checkpoints during training. The BestCheckpointSaver is a wrapper arou

图融合之加载子图:Tensorflow.contrib.slim与tf.train.Saver之坑

import tensorflow as tf import tensorflow.contrib.slim as slim import rawpy import numpy as np import tensorflow as tf import struct import glob import os from PIL import Image import time __sony__ = 0 __huawei__ = 1 __blackberry__ = 2 __stage_raw2ra

tensorflow 之模型的保存与加载(一)

怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. 1 #!/usr/bin/env python3 2 #-*- coding:utf-8 -*- 3 ############################ 4 #File Name: saver.py 5 #Brief: 6 #Author: frank 7 #Mail: [email protected] 8 #Created Time:2018-06-22 22:12:52 9 ###

tensorflow-训练检查点tf.train.Saver

#!/usr/bin/env python2 # -*- coding: utf-8 -*- """ Created on Thu Sep 6 10:16:37 2018 @author: myhaspl @email:[email protected] """ import tensorflow as tf g1=tf.Graph() with g1.as_default(): with tf.name_scope("input_Va

Tensorflow Learning1 模型的保存和恢复

CKPT->pb Demo 解析 tensor name 和 node name 的区别 Pb 的恢复 CKPT->pb tensorflow的模型保存有两种形式: 1. ckpt:可以恢复图和变量,继续做训练 2. pb : 将图序列化,变量成为固定的值,,只可以做inference:不能继续训练 Demo 1 def freeze_graph(input_checkpoint,output_graph): 2 3 ''' 4 :param input_checkpoint: 5 :para

【tf.keras】tf.keras使用tensorflow中定义的optimizer

我的 tensorflow+keras 版本: print(tf.VERSION) # '1.10.0' print(tf.keras.__version__) # '2.1.6-tf' tf.keras 没有实现 AdamW,即 Adam with Weight decay.论文<DECOUPLED WEIGHT DECAY REGULARIZATION>提出,在使用 Adam 时,weight decay 不等于 L2 regularization.具体可以参见 当前训练神经网络最快的方式