tensorflow中moving average的正确用法

一般在保存模型参数的时候,都会保存一份moving average,是取了不同迭代次数模型的移动平均,移动平均后的模型往往在性能上会比最后一次迭代保存的模型要好一些。

tensorflow-models项目中tutorials下cifar中相关的代码写的有点问题,在这写下我自己的做法:

1.构建训练模型时,添加如下代码

1 variable_averages = tf.train.ExponentialMovingAverage(0.999, global_step)
2 variables_averages_op = variable_averages.apply(tf.trainable_variables())
3 ave_vars = [variable_averages.average(var) for var in tf.trainable_variables()]
4 train_op = tf.group(train_op, variables_averages_op)

第1行创建了一个指数移动平均类 variable_averages

第2行将variable_averages作用于当前模型中所有可训练的变量上,得到 variables_averages_op操作符

第3行获得所有可训练变量对应的移动平均变量列表集合,后续用于保存模型

第4行在原有的训练操作符基础上,再添加variables_averages_op操作符,后续session执行run的时候,除了训练时前向后向,梯度更新,还会对相应的变量做移动平均

2.开始训练前,创建saver时,使用如下代码

1 save_vars = tf.trainable_variables() + ave_vars2 saver = tf.train.Saver(var_list=save_vars, max_to_keep=5)

第1行获取所有需要保存的变量列表,这个时候 ave_vars就派上用场了。

第2行创建saver,指定var_list为所有可训练变量及其对应的移动平均变量。

另外需要注意的是,如果你的模型中有bn或者类似层,包含有统计参数(均值、方差等),这些不属于可训练参数,还需要额外添加进save_vars中,可以参考我的这篇博客

3.在做inference的时候,利用如下代码从checkpoint中恢复出移动平均模型

1 variable_averages = tf.train.ExponentialMovingAverage(0.999)
2 variables_to_restore = variable_averages.variables_to_restore()
3 saver = tf.train.Saver(variables_to_restore)
4 saver.restore(sess, model_path)

这几行很简单,就不做解释了。

实际上,在inference的时候,刚刚的做法除了可以从checkpoint文件中恢复出移动平均参数,还可以恢复出对应迭代的模型参数,可以用来对比两种方式,哪种效果更好,这时只需要将上面代码的第3行改为saver = tf.train.Saver(tf.trainable_variables())即可(和保存时相同,如果有bn,也需要额外考虑)。在我的测试中,使用移动平均参数效果更佳。

时间: 2024-11-05 22:03:04

tensorflow中moving average的正确用法的相关文章

【转】改善C#程序的建议2:C#中dynamic的正确用法 空间

dynamic是FrameWork4.0的新特性.dynamic的出现让C#具有了弱语言类型的特性.编译器在编译的时候不再对类型进行检查,编译期默认dynamic对象支持你想要的任何特性.比如,即使你对GetDynamicObject方法返回的对象一无所知,你也可以像如下那样进行代码的调用,编译器不会报错: dynamic dynamicObject = GetDynamicObject();Console.WriteLine(dynamicObject.Name);Console.WriteL

【转】C#中dynamic的正确用法

原文:http://www.cnblogs.com/qiuweiguo/archive/2011/08/03/2125982.html dynamic是FrameWork4.0的新特性.dynamic的出现让C#具有了弱语言类型的特性.编译器在编译的时候不再对类型进行检查,编译期默认dynamic对象支持你想要的任何特性.比如,即使你对GetDynamicObject方法返回的对象一无所知,你也可以像如下那样进行代码的调用,编译器不会报错: dynamic dynamicObject = Get

[LeetCode] Moving Average from Data Stream 从数据流中移动平均值

Given a stream of integers and a window size, calculate the moving average of all integers in the sliding window. For example,MovingAverage m = new MovingAverage(3);m.next(1) = 1m.next(10) = (1 + 10) / 2m.next(3) = (1 + 10 + 3) / 3m.next(5) = (10 + 3

C#中dynamic的正确用法 以及 typeof(DynamicSample).GetMethod("Add");

dynamic是FrameWork4.0的新特性.dynamic的出现让C#具有了弱语言类型的特性.编译器在编译的时候不再对类型进行检查,编译期默认dynamic对象支持你想要的任何特性.比如,即使你对GetDynamicObject方法返回的对象一无所知,你也可以像如下那样进行代码的调用,编译器不会报错: dynamic dynamicObject = GetDynamicObject(); Console.WriteLine(dynamicObject.Name); Console.Writ

Spring MVC中Session的正确用法<转>

Spring MVC是个非常优秀的框架,其优秀之处继承自Spring本身依赖注入(Dependency Injection)的强大的模块化和可配置性,其设计处处透露着易用性.可复用性与易集成性.优良的设计模式遍及各处,使得其框架虽然学习曲线陡峭,但 一旦掌握则欲罢不能.初学者并不需要过多了解框架的实现原理,随便搜一下如何使用“基于注解的controller”就能很快上手,而一些书籍诸如 “spring in action”也给上手提供了非常优良的选择. 网上的帖子多如牛毛,中文的快速上手,英文的

C#中dynamic的正确用法

dynamic是FrameWork4.0的新特性.dynamic的出现让C#具有了弱语言类型的特性.编译器在编译的时候不再对类型进行检查,编译期默认dynamic对象支持你想要的任何特性.比如,即使你对GetDynamicObject方法返回的对象一无所知,你也可以像如下那样进行代码的调用,编译器不会报错: dynamic dynamicObject = GetDynamicObject();Console.WriteLine(dynamicObject.Name);Console.WriteL

转载~kxcfzyk:Linux C语言多线程库Pthread中条件变量的的正确用法逐步详解

Linux C语言多线程库Pthread中条件变量的的正确用法逐步详解 多线程c语言linuxsemaphore条件变量 (本文的读者定位是了解Pthread常用多线程API和Pthread互斥锁,但是对条件变量完全不知道或者不完全了解的人群.如果您对这些都没什么概念,可能需要先了解一些基础知识) 关于条件变量典型的实际应用,可以参考非常精简的Linux线程池实现(一)——使用互斥锁和条件变量,但如果对条件变量不熟悉最好先看完本文. Pthread库的条件变量机制的主要API有三个: int p

改善C#程序的建议2:C#中dynamic的正确用法

原文:改善C#程序的建议2:C#中dynamic的正确用法 dynamic是FrameWork4.0的新特性.dynamic的出现让C#具有了弱语言类型的特性.编译器在编译的时候不再对类型进行检查,编译期默认dynamic对象支持你想要的任何特性.比如,即使你对GetDynamicObject方法返回的对象一无所知,你也可以像如下那样进行代码的调用,编译器不会报错: dynamic dynamicObject = GetDynamicObject(); Console.WriteLine(dyn

LeetCode 346. Moving Average from Data Stream (数据流动中的移动平均值)

Given a stream of integers and a window size, calculate the moving average of all integers in the sliding window. For example, MovingAverage m = new MovingAverage(3); m.next(1) = 1 m.next(10) = (1 + 10) / 2 m.next(3) = (1 + 10 + 3) / 3 m.next(5) = (1