TensorFlow training checkpoints: tf.train.Saver


#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Sep  6 10:16:37 2018
@author: myhaspl
@email:[email protected]
"""

import tensorflow as tf

g1=tf.Graph()

with g1.as_default():
    with tf.name_scope("input_Variable"):
        my_var=tf.Variable(1,dtype=tf.float32)
    with tf.name_scope("global_step"):
        my_step=tf.Variable(0,dtype=tf.int32)
    with tf.name_scope("update"):
        varop=tf.assign(my_var,tf.multiply(tf.log(tf.add(my_var,1)),1))
        stepop=tf.assign_add(my_step,1)
        addop=tf.group([varop,stepop])
    with tf.name_scope("summaries"):
        tf.summary.scalar('myvar',my_var)
    with tf.name_scope("global_ops"):
        init=tf.global_variables_initializer()
        merged_summaries=tf.summary.merge_all()

with tf.Session(graph=g1) as sess:
    writer=tf.summary.FileWriter('sum_vars',sess.graph)
    sess.run(init)
    # step 0: record the initial value
    step,var,summary=sess.run([my_step,my_var,merged_summaries])
    writer.add_summary(summary,global_step=step)
    print step,var
    saver=tf.train.Saver()
    # steps 1-49
    for i in xrange(1,50):
        sess.run(addop)
        step,var,summary=sess.run([my_step,my_var,merged_summaries])
        writer.add_summary(summary,global_step=step)
        print step,var
        if i%5==0:
            saver.save(sess,'./myvar-model/myvar-model',global_step=i)
    saver.save(sess,'./myvar-model/myvar-model',global_step=49)

    writer.flush()
    writer.close()

38 0.0512373
39 0.04996785
40 0.048759546
41 0.04760808
42 0.04650955
43 0.045460388
44 0.04445735
45 0.04349747
46 0.042578023
47 0.041696515
48 0.040850647
49 0.04003831
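The printed sequence can be verified without TensorFlow: the update op simply iterates my_var ← log(my_var + 1) from the initial value 1.0. A minimal plain-Python check (a standalone sketch, independent of the graph above; print is written as a function so it runs under both Python 2 and 3):

```python
import math

x = 1.0  # same initial value as my_var
for step in range(1, 50):
    x = math.log(x + 1)

# the iterate decays toward the fixed point 0, roughly like 2/step;
# after 49 updates x agrees with the printed float32 value 0.04003831
print(step, x)
```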

Save the variables of the dataflow graph to binary checkpoint files.
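With the default V2 checkpoint format of TF 1.x, each saver.save(..., global_step=i) call writes one .data/.index/.meta trio plus a single checkpoint state file, and with the default max_to_keep=5 only the five most recent steps are retained. For the myvar-model prefix used above, the directory ends up looking roughly like this (a sketch of the expected layout, not captured from an actual run):

```text
myvar-model/
    checkpoint                          # text file listing the latest and retained checkpoints
    myvar-model-30.data-00000-of-00001  # variable values
    myvar-model-30.index                # variable name -> data offsets
    myvar-model-30.meta                 # serialized MetaGraphDef
    myvar-model-35.data-00000-of-00001
    ...
    myvar-model-49.data-00000-of-00001
    myvar-model-49.index
    myvar-model-49.meta
```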

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Sep  6 10:16:37 2018
@author: myhaspl
@email:[email protected]
"""

import tensorflow as tf
import os
g1=tf.Graph()

with g1.as_default():
    with tf.name_scope("input_Variable"):
        my_var=tf.Variable(1,dtype=tf.float32)
    with tf.name_scope("global_step"):
        my_step=tf.Variable(0,dtype=tf.int32,trainable=False)
    with tf.name_scope("update"):
        varop=tf.assign(my_var,tf.multiply(tf.log(tf.add(my_var,1)),1))
        stepop=tf.assign_add(my_step,1)
        addop=tf.group([varop,stepop])
    with tf.name_scope("summaries"):
        tf.summary.scalar('myvar',my_var)
    with tf.name_scope("global_ops"):
        init=tf.global_variables_initializer()
        merged_summaries=tf.summary.merge_all()

with tf.Session(graph=g1) as sess:
    writer=tf.summary.FileWriter('sum_vars',sess.graph)
    sess.run(init)

    saver=tf.train.Saver()

    # If a checkpoint was saved earlier, restore the model and resume from it
    init_step=0
    ckpt=tf.train.get_checkpoint_state(os.getcwd()+'/myvar-model')
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess,ckpt.model_checkpoint_path)
        init_step=int(ckpt.model_checkpoint_path.rsplit('-',1)[1])
        print "Restoring checkpoint file..."
    for i in xrange(init_step,100):
        step,var,summary=sess.run([my_step,my_var,merged_summaries])
        writer.add_summary(summary,global_step=step)
        print step,var,init_step
        if i%5==0 and i<=50:
            print "Saving checkpoint file"
            saver.save(sess,'./myvar-model/myvar-model',global_step=i)
        sess.run(addop)

    writer.flush()
    writer.close()

On the first run of the code above, checkpoint files are written; from the second run onward, the saved checkpoint is read back and the loop resumes from step=50.
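The resume step comes from the checkpoint filename itself: tf.train.Saver appends -&lt;global_step&gt; to the path prefix, and the code recovers it with rsplit. The parsing logic can be checked standalone, without TensorFlow (the helper name and example paths here are illustrative only):

```python
def step_from_checkpoint_path(path):
    """Recover the global_step suffix that Saver appends to the save prefix."""
    return int(path.rsplit('-', 1)[1])

# rsplit from the right is important: the prefix itself contains '-'
print(step_from_checkpoint_path('/home/user/myvar-model/myvar-model-50'))  # 50
print(step_from_checkpoint_path('./myvar-model/myvar-model-5'))            # 5
```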

On the second run:

Restoring checkpoint file...
50 0.03925755 50
Saving checkpoint file
51 0.038506564 50
52 0.037783686 50
53 0.03708737 50
54 0.036416177 50
55 0.035768777 50
56 0.03514393 50
...
...
...
93 0.021334965 50
94 0.02111056 50
95 0.02089082 50
96 0.0206756 50
97 0.020464761 50
98 0.020258171 50
99 0.020055704 50

Original article: http://blog.51cto.com/13959448/2326699

