tf.train.Saver()

1. 实例化对象

saver = tf.train.Saver(max_to_keep=1)

max_to_keep: 表明保存的最大checkpoint文件数。当一个新文件创建的时候,旧文件就会被删掉。如果值为None或0, 表示保存所有的checkpoint文件。默认值5(也就是说,保存最近的5个checkpoint文件)。

keep_checkpoint_every_n_hour: 除了保存最近的max_to_keep_checkpoint文件,你还可能想每训练N小时保存一个checkpoint文件。这将是非常有用的,如果你想分析一个模型在很长的一段训练时间内是怎么改变的。例如,设置keep_checkpoint_every_n_hour=2确保每训练2个小时保存一个checkpoint文件。

2. 保存训练过程中或者训练好的, 模型图及权重参数

2.1 创建完saver对象后,就可以保存训练好的模型了

  

saver.save(sess=sess, save_path=model_save_path, global_step=step)

第一个参数sess=sess, 会话名字;

第二个参数save_path=model_save_path, 设定权重参数保存到的路径和文件名;

第三个参数global_step=step, 将训练的次数作为后缀加入到模型名字中。

2.2 一次saver.save()后可以在文件夹中看到新增的四个文件

实际上没调用一次保存操作会创建后3个数据文件并创建一个检查点(checkpoint)文件。

  • 简单理解就是权重等参数被保存到.ckpt.data文件中,以字典的形式;
  • ckpt-index, 应该是内部需要的某种索引来正确映射前两个文件;
  • 图和元数据被保存到.ckpt.meta文件中,可以使用tf.train.import_meta_graph加载

3. 重载模型的图及权重参数

重载模型的参数,继续训练或用于测试数据

saver.restore(sess=sess, save_path = model_save_path)
  • 第一个参数sess=sess, 会话名字
  • 第二个参数save_path=model_save_path, 权重参数的保存路径和文件名

原连接:https://blog.csdn.net/liuxiaodong400/article/details/83421164

原文地址:https://www.cnblogs.com/elitphil/p/12048395.html

时间: 2024-07-31 17:56:43

tf.train.Saver()的相关文章

跟我学算法- tensorflow模型的保存与读取 tf.train.Saver()

save =  tf.train.Saver() 通过save. save() 实现数据的加载 通过save.restore() 实现数据的导出 第一步: 数据的载入 import tensorflow as tf #创建变量 v1 = tf.Variable(tf.random_normal([1, 2], name='v1')) v2 = tf.Variable(tf.random_normal([2, 3], name='v2')) #初始化变量 init_op = tf.global_v

图融合之加载子图:Tensorflow.contrib.slim与tf.train.Saver之坑

import tensorflow as tf import tensorflow.contrib.slim as slim import rawpy import numpy as np import tensorflow as tf import struct import glob import os from PIL import Image import time __sony__ = 0 __huawei__ = 1 __blackberry__ = 2 __stage_raw2ra

tensorflow-训练检查点tf.train.Saver

#!/usr/bin/env python2 # -*- coding: utf-8 -*- """ Created on Thu Sep 6 10:16:37 2018 @author: myhaspl @email:[email protected] """ import tensorflow as tf g1=tf.Graph() with g1.as_default(): with tf.name_scope("input_Va

机器学习与Tensorflow(7)——tf.train.Saver()、inception-v3的应用

1. tf.train.Saver() tf.train.Saver()是一个类,提供了变量.模型(也称图Graph)的保存和恢复模型方法. TensorFlow是通过构造Graph的方式进行深度学习,任何操作(如卷积.池化等)都需要operator,保存和恢复操作也不例外. 在tf.train.Saver()类初始化时,用于保存和恢复的save和restore operator会被加入Graph.所以,下列类初始化操作应在搭建Graph时完成. saver = tf.train.Saver()

TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model

TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model Checkmate is designed to be a simple drop-in solution for a very common Tensorflow use-case: keeping track of the best model checkpoints during training. The BestCheckpointSaver is a wrapper arou

tf.train.Saver()-tensorflow中模型的保存及读取

作用:训练网络之后保存训练好的模型,以及在程序中读取已保存好的模型 使用步骤: 实例化一个Saver对象 saver = tf.train.Saver() 在训练过程中,定期调用saver.save方法,像文件夹中写入包含当前模型中所有可训练变量的checkpoint文件 saver.save(sess,FLAGG.train_dir,global_step=step) 之后可以使用saver.restore()方法,重载模型的参数,继续训练或者用于测试数据 saver.restore(sess

TF:利用TF的train.Saver载入曾经训练好的variables(W、b)以供预测新的数据

import tensorflow as tf import numpy as np W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights") b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases") saver = tf.train.Saver() with tf.

TF:利用TF的train.Saver将训练好的variables(W、b)保存到指定的index、meda文件

import tensorflow as tf import numpy as np W = tf.Variable([[2,1,8],[1,2,5]], dtype=tf.float32, name='weights') b = tf.Variable([[1,2,5]], dtype=tf.float32, name='biases') init= tf.global_variables_initializer() saver = tf.train.Saver() with tf.Sessi

tensorflow API _ 3 (tf.train.polynomial_decay)

学习率的三种调整方式:固定的,指数的,多项式的 def _configure_learning_rate(num_samples_per_epoch, global_step): """Configures the learning rate. Args: num_samples_per_epoch: The number of samples in each epoch of training. global_step: The global_step tensor. Re