TensorFlow Saver的使用方法

我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试。Tensorflow针对这一需求提供了Saver类。

  • Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。
  • 只要提供一个计数器,当计数器触发时,Saver类可以自动的生成checkpoint文件。这让我们可以在训练过程中保存多个中间结果。例如,我们可以保存每一步训练的结果。
  • 为了避免填满整个磁盘,Saver可以自动的管理Checkpoints文件。例如,我们可以指定保存最近的N个Checkpoints文件。

示例代码:

import tensorflow as tf
import numpy as np
from six.moves import xrange

x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 2

w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b

loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

#isTrain = True
isTrain = False
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ‘test/‘

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if isTrain:
        for i in xrange(train_steps):
            sess.run(train, feed_dict={x: x_data})
            if (i + 1) % checkpoint_steps == 0:
                saver.save(sess, checkpoint_dir + ‘model.ckpt‘, global_step=i + 1)
    else:
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            pass
        print(sess.run(w))
        print(sess.run(b))

        y_result = sess.run(y_predict, feed_dict={x: np.reshape(4, (1, 1))})
        print(y_result)

2.1 训练阶段

使用Saver.save()方法保存模型:

  1. sess:表示当前会话,当前会话记录了当前的变量值
  2. checkpoint_dir + ‘model.ckpt‘:表示存储的文件名
  3. global_step:表示当前是第几步

训练完成后,当前目录底下会多出5个文件。

打开名为“checkpoint”的文件,可以看到保存记录,和最新的模型存储位置。

2.2测试阶段

测试阶段使用saver.restore()方法恢复变量:

  1. sess:表示当前会话,之前保存的结果将被加载入这个会话
  2. ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,它会去查看checkpoint文件,看看最新的是谁,叫做什么。

运行结果如下图所示,加载了之前训练的参数w和b的结果

时间: 2024-08-29 15:52:16

TensorFlow Saver的使用方法的相关文章

调用tensorflow中的concat方法时Expected int32, got list containing Tensors of type '_Message' instead.

grid = tf.concat(0, [x_t_flat, y_t_flat, ones])#报错语句 grid = tf.concat( [x_t_flat, y_t_flat, ones],0) #楼主改后的代码 将数字放在后面,如果有三个参数 decoder_inputs = tf.concat([go_inputs, decoder_inputs_tmp], 1,name="dec_in") 调用tensorflow中的concat方法时Expected int32, got

TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model

TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model Checkmate is designed to be a simple drop-in solution for a very common Tensorflow use-case: keeping track of the best model checkpoints during training. The BestCheckpointSaver is a wrapper arou

TensorFlow 常用函数与方法

摘要:本文主要对tf的一些常用概念与方法进行描述. tf函数 TensorFlow 将图形定义转换成分布式执行的操作, 以充分利用可用的计算资源(如 CPU 或 GPU.一般你不需要显式指定使用 CPU 还是 GPU, TensorFlow 能自动检测.如果检测到 GPU, TensorFlow 会尽可能地利用找到的第一个 GPU 来执行操作. 并行计算能让代价大的算法计算加速执行,TensorFlow也在实现上对复杂操作进行了有效的改进.大部分核相关的操作都是设备相关的实现,比如GPU.下面是

TensorFlow之Varibale 使用方法

------------------------------------------- 转载请注明: 来自博客园 xiuyuxuanchen 地址:http://www.cnblogs.com/greentomlee/ ------------------------------------------- Varibale 使用方法 实例: 实例讲解: 首先: #!/usr/bin/env python 这句话是指定python的运行环境,这种指定方式有两种,一种是指定python的路径---#

吴裕雄 python 神经网络——TensorFlow 数据集基本使用方法

import tempfile import tensorflow as tf input_data = [1, 2, 3, 5, 8] dataset = tf.data.Dataset.from_tensor_slices(input_data) # 定义迭代器. iterator = dataset.make_one_shot_iterator() # get_next() 返回代表一个输入数据的张量. x = iterator.get_next() y = x * x with tf.S

TensorFlow——共享变量的使用方法

1.共享变量用途 在构建模型时,需要使用tf.Variable来创建一个变量(也可以理解成节点).当两个模型一起训练时,一个模型需要使用其他模型创建的变量,比如,对抗网络中的生成器和判别器.如果使用tf.Variable,将会生成一个新的变量,而我们需要使用原来的那个变量.这时就是通过引入get_Variable方法,实现共享变量来解决这个问题.这种方法可以使用多套网络模型来训练一套权重. 2.使用get_Variable获取变量 get_Variable一般会配合Variable_scope一

TensorFlow的数据读取方法

Tensorflow一共提供了3种读取数据的方法:第一种方法个人感觉比较麻烦:在TensorFlow程序运行的每一步, 让Python代码来供给数据,比如说用PIL和numpy处理数据,然后输入给神经网络.第二种方法:从文件读取数据,在TensorFlow图的起始, 让一个输入管线从文件中读取数据:string_input_producer()和slice_input_producer(). 他们两者区别可以简单理解为:string_input_producer每次取出一个文件名.slice_i

windows安装tensorflow简单直接的方法(win10+pycharm+tensorflow-gpu1.7+cuda9.1+cudnn7.1)

安装tensorflow-gpu环境需要:python环境,tensorflow-gpu包,cuda,cudnn 一,安装python,pip3直接到官网下载就好了,下载并安装你喜欢的版本 https://www.python.org/ 提示:安装最后一步时记得勾选添加环境变量 在cmd输入pip3测试pip3能否使用,不能使用的话,手动打开python安装的路径,找到pip3文件,将路径加入环境变量 二,安装tensorflow-gpu 使用pip3安装即可:pip3 install tens

tensorflow冻结层的方法

其实常说的fine tune就是冻结网络前面的层,然后训练最后一层.那么在tensorflow里如何实现finetune功能呢?或者说是如何实现冻结部分层,只训练某几个层呢?可以通过只选择优化特定层的参数来实现该功能. 示例代码如下: #定义优化算子 optimizer = tf.train.AdamOptimizer(1e-3) #选择待优化的参数 output_vars = tf.get_collection(tf.GraphKyes.TRAINABLE_VARIABLES, scope='