代码: tensorflow/examples/tutorials/mnist/
本文的目的是来展示如何使用Tensorflow训练和评估手写数字识别问题。本文的观众是那些对使用Tensorflow进行机器学习感兴趣的人。
本文的目的并不是讲解机器学习。
请确认您已经安装了Tensorflow。
教程文件
文件 | 作用 |
mnist.py |
用来创建一个完全连接的MNIST模型。 |
fully_connected_feed.py |
使用下载的数据集训练模型。 |
运行fully_connected_feed.py文件开始训练。
python fully_connected_feed.py
准备数据
MNIST是机器学习的一个经典问题。这个问题是识别28*28像素图片上的数字,从0到9。
更多信息,请参考Yann LeCun‘s MNIST page 或者 Chris Olah‘s visualizations of MNIST。
数据下载
在run_training()方法之前,input_data.read_data_sets()方法可以让数据下载到本机训练文件夹,解压数据并返回一个DataSet实例。
data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)
注意:fake_data是用来进行单元测试的,读者可以忽略。
数据集 | 作用 |
data_sets.train | 55000图片和标签,用来训练。 |
data_sets.validation | 5000图片和标签,用来在迭代中校验模型准确度。 |
data_sets.test | 10000图片和标签,用来测试训练模型准确度。 |
输入和占位符
placeholder_inputs()函数创建两个tf.placeholder,用来定义输入的形状,包括fetch_size。
images_placeholder = tf.placeholder(tf.float32, shape=(batch_size, mnist.IMAGE_PIXELS)) labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
在训练循环中,图片和标签数据集会被切分成batch_size大小,跟占位符匹配,然后通过feed_dict参数传递到sess.run()方法中。
创建图
创建占位符后,mnist.py文件中会通过三个步骤来创建图:inference(), loss()
, 和training()。
- inference() - 运行网络来进行预测。
loss()
- 用来计算损失值。training()
- 计算梯度。
inference层
inference()函数创建图,返回预测结果。
它把图片占位符当作输入,并在上面构建一对完全连接的层,使用ReLU激活后,连接一个10个节点的线性层。
每一层都位于tf.name_scope
声明的命名空间中。
with tf.name_scope(‘hidden1‘):
在该命名空间中,权重和偏置会产生tf.Variable实例,并具有所需的形状。
weights = tf.Variable(tf.truncated_normal([IMAGE_PIXELS, hidden1_units], stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))), name=‘weights‘) biases = tf.Variable(tf.zeros([hidden1_units]), name=‘biases‘)
待续...
原文:《TensorFlow Mechanics 101》:https://www.tensorflow.org/get_started/mnist/mechanics