Tensorflow学习笔记(对MNIST经典例程的)的代码注释与理解

1 #coding:utf-8
  2 # 日期 2017年9月4日 环境 Python 3.5  TensorFlow 1.3 win10开发环境。
  3 import tensorflow as tf
  4 from tensorflow.examples.tutorials.mnist import input_data
  5 import os
  6
  7
  8 # 基础的学习率
  9 LEARNING_RATE_BASE = 0.8
 10
 11 # 学习率的衰减率
 12 LEARNING_RATE_DECAY = 0.99
 13
 14 # 描述模型复杂度的正则化项在损失函数中的系数
 15 REGULARIZATION_RATE = 0.0001
 16
 17 # 训练轮数
 18 TRAINING_STEPS = 30000
 19
 20 # 滑动平均衰减率
 21 MOVING_AVERAGE_DECAY = 0.99
 22
 23 # 模型持久化保存路径
 24 MODEL_SAVE_PATH = "MNIST_model/"
 25 # 模型持久化保存文件名称
 26 MODEL_NAME = "mnist_model"
 27
 28 # 输入层节点数(对于数据集,相当于整个图片的像素数目)
 29 INPUT_NODE = 784
 30
 31 # 输出层的节点数(根据10个数字决定的)
 32 OUTPUT_NODE = 10
 33
 34 # 隐藏层的节点数,此例程中,隐藏层为一层。
 35 LAYER1_NODE = 500
 36
 37 # 一个训练batch中的训练数据个数,数字越小的时候,训练过程越接近随机梯度下降。
 38 BATCH_SIZE = 100
 39
 40
 41 def train(mnist):
 42     # 定义输入输出placeholder。
 43     x = tf.placeholder(tf.float32, [None, INPUT_NODE], name=‘x-input‘)
 44     y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name=‘y-input‘)
 45     # 正则化损失函数
 46     regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
 47     # 使用定义的向前传播过程
 48     y = inference(x, regularizer)
 49
 50     # 定义存储训练轮数的变量。这个变量不需要计算滑动的平均值,所以这里指定这个变量为不可训练的变量(trainable=False)。
 51     # 在tensorflow中训练神经网络的时候,一般会将代表训练轮数的变量指定为不可训练的参数。
 52     global_step = tf.Variable(0, trainable=False)
 53
 54     # 定义损失函数、学习率、滑动平均操作以及训练过程。
 55     variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
 56     # 在所有代表神经网络参数的变量上使用滑动平均。其它辅助变量(如global_step)就不需要了
 57     variables_averages_op = variable_averages.apply(tf.trainable_variables())
 58     # 计算交叉熵作为刻画预测值和真实值之间差距的损失函数。(第一个参数是神经网络不包含softmax层的前向传播结果,第二个是训练数据的正确答案)
 59     # 因为标准答案是一个长度为10的一维数组,二该函数需要提供的是一个正确答案的数字,所以需要使用tf.argmax函数来得到正确答案对应的类别编号。
 60     cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
 61     # 计算当前batch中所有样例的交叉熵平均值
 62     cross_entropy_mean = tf.reduce_mean(cross_entropy)
 63     # 总损失等于交叉熵和
 64     loss = cross_entropy_mean + tf.add_n(tf.get_collection(‘losses‘))
 65
 66     # 设置指数衰减的学习率
 67     learning_rate = tf.train.exponential_decay(
 68         LEARNING_RATE_BASE,                                             # 基础的学习率,随着迭代的进行,更新变量时使用的学习率在这个基础上递减
 69         global_step,                                                    # 当前迭代的轮数
 70         mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY,     # 过完所有的训练数据需要的迭代次数
 71         staircase=True)
 72
 73     #  使用tf.train.GradientDescentOptimizer优化算法来优化损失函数。注意这里损失函数包含了交叉熵和正则损失
 74     train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
 75
 76     with tf.control_dependencies([train_step, variables_averages_op]):
 77         train_op = tf.no_op(name=‘train‘)
 78
 79     # 初始化TensorFlow持久化类。
 80     saver = tf.train.Saver()
 81     with tf.Session() as sess:
 82         tf.global_variables_initializer().run()
 83
 84         # 在训练过程中,不在测试模型在验证数据上的表现,验证和测试的过程将会有一个独立的程序来完成。
 85         for i in range(TRAINING_STEPS):
 86             xs, ys = mnist.train.next_batch(BATCH_SIZE)
 87             _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
 88             if i % 1000 == 0:
 89                 # 输出当前的训练情况,这里只输出了模型在当前训练batch上的损失函数大小,通过损失函数的大小可以大概了解训练的情况。在验证数据集上的正确
 90                 # 率信息会有一个单独的程序来生成。
 91                 print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
 92                 # 保存当前的模型。global_step参数,这样可以让每个被保存模型的文件名末尾加上训练的轮数,如model.ckpt-1000表示训练1000轮之后得到的模型
 93                 saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
 94
 95 # 通过tf.get_variable函数来获取变量 在测试是会通过保存的模型加载这些变量的取值。而且更加方便的是,因为可以在变量加载时将滑动平均变量重命名
 96 # 所以可以直接通过同样的名字在训练时使用变量自身,而在测试时使用变量的滑动平均值。这个函数中会将变量的正则化损失加损失集合。
 97 def get_weight_variable(shape, regularizer):
 98     weights = tf.get_variable("weights", shape, initializer=tf.truncated_normal_initializer(stddev=0.1))
 99     # 当给出正则化生产函数时,将当前变量的正则化损失加入名字为Losses的集合。在这里使用了add_to_collection函数将一个张量加入一个集合,
100     # 而这个集合的名称为losses.这是自定义集合,不在Tensorflow自动管理的集合列表中
101     if regularizer != None: tf.add_to_collection(‘losses‘, regularizer(weights))
102     return weights
103
104 # 定义神经网络的前向传播过程(初始化所有参数的辅助函数,给定神经网络中的参数)
105 def inference(input_tensor, regularizer):
106     # 声明第一层神经网络的变量并完成前向传播过程
107     with tf.variable_scope(‘layer1‘):
108         # 通过tf.get_variable 和tf.Variable没有本质区别,因为在训练或是测试中没有在同一个程序中多次调用这个函数。如果在同一个过程多次调用,
109         # 在第一调用的之后需要将resuse参数设置为True
110         weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
111         biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
112         layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)
113
114     # 声明第二层神经网络的变量并完成向前传播的过程
115     with tf.variable_scope(‘layer2‘):
116         weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
117         biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
118         layer2 = tf.matmul(layer1, weights) + biases
119
120     return layer2
121
122 # 2.主程序部分
123 def main(argv=None):
124     # 获取数据集(根据谷歌的例程中相关的获取路径)
125     mnist = input_data.read_data_sets("../../../datasets/MNIST_data", one_hot=True)
126     # 根据数据集训练模型
127     train(mnist)
128
129 # 1 .程序入口
130 if __name__ == ‘__main__‘:
131     main()

对Tensorflow中经典的MNIST模型的学习,程序整个过程进行了注释,摘自《实战google深度学习框架》中代码,并进行修改后注释。

时间: 2024-10-12 13:07:58

Tensorflow学习笔记(对MNIST经典例程的)的代码注释与理解的相关文章

Tensorflow学习笔记2:About Session, Graph, Operation and Tensor

简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节点之间则是由张量(Tensor)作为边来连接在一起的.所以Tensorflow的计算过程就是一个Tensor流图.Tensorflow的图则是必须在一个Session中来计算.这篇笔记来大致介绍一下Session.Graph.Operation和Tensor. Session Session提供了O

TensorFlow学习笔记(UTF-8 问题解决 UnicodeDecodeError: 'utf-8' codec can't decode byte 0xff in position 0: invalid start byte)

我使用VS2013  Python3.5  TensorFlow 1.3  的开发环境 UnicodeDecodeError: 'utf-8' codec can't decode byte 0xff in position 0: invalid start byte 在是使用Tensorflow读取图片文件的情况下,会出现这个报错 代码如下 # -*- coding: utf-8 -*- import tensorflow as tf import numpy as np import mat

Swift学习笔记(一)搭配环境以及代码运行成功

原文:Swift学习笔记(一)搭配环境以及代码运行成功 1.Swift是啥? 百度去!度娘告诉你它是苹果最新推出的编程语言,比c,c++,objc要高效简单.能够开发ios,mac相关的app哦!是苹果以后大力推广的语言哦! 2.Swift给你带来什么机会? 当初你觉得objc太难,学ios学到一半放弃拉,或者进入it行业大家都搞android,你也搞android去了.现在你终于有机会和搞ios的站在一个语言的起跑线上,兄弟!swift传说很容易学哦,搞android的你想不想增加一下本领?提

Node.js学习笔记【3】NodeJS基础、代码的组织和部署、文件操作、网络操作、进程管理、异步编程

一.表 学生表 CREATE TABLE `t_student` ( `stuNum` int(11) NOT NULL auto_increment, `stuName` varchar(20) default NULL, `birthday` date default NULL, PRIMARY KEY  (`stuNum`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8 学生分数表 CREATE TABLE `t_stu_score` ( `id` int(11

Tensorflow学习笔记(一):MNIST机器学习入门

学习深度学习,首先从深度学习的入门MNIST入手.通过这个例子,了解Tensorflow的工作流程和机器学习的基本概念. 一  MNIST数据集 MNIST是入门级的计算机视觉数据集,包含了各种手写数字的图片.在这个例子中就是通过机器学习训练一个模型,以识别图片中的数字. MNIST数据集来自 http://yann.lecun.com/exdb/mnist/ Tensorflow提供了一份python代码用于自动下载安装数据集.Tensorflow官方文档中的url打不开,在CSDN上找到了一

TensorFlow学习笔记(6)读取数据

Overview    之前几次推送的全部例程,使用的都是tensorflow预处理过的数据集,直接载入即可.例如: 然而实际中我们使用的通常不会是这种超级经典的数据集,如果我们有一组图像存储在磁盘上面,如何以mini-batch的形式把它们读取进来然后高效的送进网络训练?这次推送我们首先用tensorflow最底层的API处理这个问题,后面推送介绍高层API.高层API是对底层的进一步封装,用户可以不必关心过多细节.不过了解一下比较底层的API还是有好处的.当你有一组自己的数据的时候,你需要经

Google TensorFlow 学习笔记一 —— TensorFlow简介

"TensorFlow is an Open Source Software Library for Machine INtenlligence" 本笔记参考tensorflow.org的教程,翻译并记录作者的学习过程,仅供参考,如有不当之处,请及时指出并多多包涵. TensorFlow是一款开源的数学计算软件,使用data flow graphs的形式进行计算.这种灵活的架构允许我们使用相同的API在单或多CPUs或GPU,servers设置移动设备上进行计算. Data Flow

tensorflow学习笔记五:mnist实例--卷积神经网络(CNN)

mnist的卷积神经网络例子和上一篇博文中的神经网络例子大部分是相同的.但是CNN层数要多一些,网络模型需要自己来构建. 程序比较复杂,我就分成几个部分来叙述. 首先,下载并加载数据: import tensorflow as tf import tensorflow.examples.tutorials.mnist.input_data as input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=Tru

Tensorflow学习笔记3:TensorBoard可视化学习

TensorBoard简介 Tensorflow发布包中提供了TensorBoard,用于展示Tensorflow任务在计算过程中的Graph.定量指标图以及附加数据.大致的效果如下所示, TensorBoard工作机制 TensorBoard 通过读取 TensorFlow 的事件文件来运行.TensorFlow 的事件文件包括了你会在 TensorFlow 运行中涉及到的主要数据.关于TensorBoard的详细介绍请参考TensorBoard:可视化学习.下面做个简单介绍. Tensorf