Tensorflow训练和预测中的BN层的坑

  以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了。在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在《实战Google深度学习框架》第二版这本书P166里只是提了一句,没有做出解答。

  书中说训练时和测试时使用的参数is_training都为True,然后给出了一个链接供参考。本人刚开始使用时也是按照书中的做法没有改动,后来从保存后的checkpoint中加载模型做预测时出了问题:当改变需要预测数据的batchsize时预测的label也跟着变,这意味着checkpoint里面没有保存训练中BN层的参数,使用的BN层参数还是从需要预测的数据中计算而来的。这显然会出问题,当预测的batchsize越大,假如你的预测数据集和训练数据集的分布一致,结果就越接近于训练结果,但如果batchsize=1,那BN层就发挥不了作用,结果很难看。

  那如果在预测时is_traning=false呢,但BN层的参数没有从训练中保存,那使用的就是随机初始化的参数,结果不堪想象。

  所以需要在训练时把BN层的参数保存下来,然后在预测时加载,参考几位大佬的博客,有了以下训练时添加的代码:

 1 update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
 2 with tf.control_dependencies(update_ops):
 3         train_step = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)
 4
 5 # 设置保存模型
 6 var_list = tf.trainable_variables()
 7 g_list = tf.global_variables()
 8 bn_moving_vars = [g for g in g_list if ‘moving_mean‘ in g.name]
 9 bn_moving_vars += [g for g in g_list if ‘moving_variance‘ in g.name]
10 var_list += bn_moving_vars
11 saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

这样就可以在预测时从checkpoint文件加载BN层的参数并设置is_training=False。

最后要说的是,虽然这么做可以解决这个问题,但也可以利用预测数据来计算BN层的参数,不是说一定要保存训练时的参数,两种方案可以作为超参数来调节使用,看哪种方法的结果更好。

感谢几位大佬的博客解惑:

  https://blog.csdn.net/dongjbstrong/article/details/80447110?utm_source=blogxgwz0

  http://www.cnblogs.com/hrlnw/p/7227447.html

原文地址:https://www.cnblogs.com/baijing1/p/9842321.html

时间: 2024-10-21 02:36:07

Tensorflow训练和预测中的BN层的坑的相关文章

tensorflow在文本处理中的使用——Word2Vec预测

代码来源于:tensorflow机器学习实战指南(曾益强 译,2017年9月)--第七章:自然语言处理 代码地址:https://github.com/nfmcclure/tensorflow-cookbook 数据:http://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz 问题:加载和使用预训练的嵌套,并使用这些单词嵌套进行情感分析,通过训练线性逻辑回归模型来预测电影的好坏 步骤如下: 必要包

tensorflow 2.0 学习 (十一)卷积神经网络 (一)MNIST数据集训练与预测 LeNet-5网络

网络结构如下: 代码如下: 1 # encoding: utf-8 2 3 import tensorflow as tf 4 from tensorflow import keras 5 from tensorflow.keras import layers, Sequential, losses, optimizers, datasets 6 import matplotlib.pyplot as plt 7 8 Epoch = 30 9 path = r'G:\2019\python\mn

【卷积神经网络】对BN层的解释

前言 Batch Normalization是由google提出的一种训练优化方法.参考论文:Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift 个人觉得BN层的作用是加快网络学习速率,论文中提及其它的优点都是这个优点的副产品. 网上对BN解释详细的不多,大多从原理上解释,没有说出实际使用的过程,这里从what, why, how三个角度去解释BN. What is

读论文《BP改进算法在哮喘症状-证型分类预测中的应用》

总结: 一.研究内容 本文研究了CAL-BP(基于隐层的竞争学习与学习率的自适应的改进BP算法)在症状证型分类预测中的应用. 二.算法思想 1.隐层计算完各节点的误差后,对有最大误差的节点的权值进行正常修正,  而对其它单元的权值都向相反方向修正,用 δ表示隐层节点的权值修正量, 则修正量的调整公式具体为 2.每次算法迭代完以后,计算误差函数的值并与前一次的值进行比较,如果误差函数的值增大,     则代表过调了学习率,应在下一次迭代时以一定比率下调学习率 ],若误差函数的i+1值减小,    

tensorflow在文本处理中的使用——skip-gram模型

代码来源于:tensorflow机器学习实战指南(曾益强 译,2017年9月)--第七章:自然语言处理 代码地址:https://github.com/nfmcclure/tensorflow-cookbook 数据来源:http://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz 理解互相关联的单词:king - man + woman = queen 如果已知man和woman语义相关联,那我们可

tensorflow在文本处理中的使用——CBOW词嵌入模型

代码来源于:tensorflow机器学习实战指南(曾益强 译,2017年9月)--第七章:自然语言处理 代码地址:https://github.com/nfmcclure/tensorflow-cookbook 数据:http://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz CBOW概念图: 步骤如下: 必要包 声明模型参数 读取数据集 创建单词字典,转换句子列表为单词索引列表 生成批量数据 构建

TensorFlow训练MNIST数据集(3) —— 卷积神经网络

前面两篇随笔实现的单层神经网络 和多层神经网络, 在MNIST测试集上的正确率分别约为90%和96%.在换用多层神经网络后,正确率已有很大的提升.这次将采用卷积神经网络继续进行测试. 1.模型基本结构 如下图所示,本次采用的模型共有8层(包含dropout层).其中卷积层和池化层各有两层. 在整个模型中,输入层负责数据输入:卷积层负责提取图片的特征:池化层采用最大池化的方式,突出主要特征,并减少参数维度:全连接层再将个特征组合起来:dropout层可以减少每次训练的计算量,并可以一定程度上避免过

Tensorflow训练识别自定义图片

很多正在入门或刚入门TensorFlow机器学习的同学希望能够通过自己指定图片源对模型进行训练,然后识别和分类自己指定的图片.但是,在TensorFlow官方入门教程中,并无明确给出如何把自定义数据输入训练模型的方法.现在,我们就参考官方入门课程<Deep MNIST for Experts>一节的内容(传送门:https://www.tensorflow.org/get_started/mnist/pros),介绍如何将自定义图片输入到TensorFlow的训练模型. 在<Deep M

Spark技术在京东智能供应链预测的应用——按照业务进行划分,然后利用scikit learn进行单机训练并预测

3.3 Spark在预测核心层的应用 我们使用Spark SQL和Spark RDD相结合的方式来编写程序,对于一般的数据处理,我们使用Spark的方式与其他无异,但是对于模型训练.预测这些需要调用算法接口的逻辑就需要考虑一下并行化的问题了.我们平均一个训练任务在一天处理的数据量大约在500G左右,虽然数据规模不是特别的庞大,但是Python算法包提供的算法都是单进程执行.我们计算过,如果使用一台机器训练全部品类数据需要一个星期的时间,这是无法接收的,所以我们需要借助Spark这种分布式并行计算