Tensorflow学习笔记(一):MNIST机器学习入门

学习深度学习,首先从深度学习的入门MNIST入手。通过这个例子,了解Tensorflow的工作流程和机器学习的基本概念。

一  MNIST数据集

MNIST是入门级的计算机视觉数据集,包含了各种手写数字的图片。在这个例子中就是通过机器学习训练一个模型,以识别图片中的数字。

MNIST数据集来自 http://yann.lecun.com/exdb/mnist/

Tensorflow提供了一份python代码用于自动下载安装数据集。Tensorflow官方文档中的url打不开,在CSDN上找到了一个分享:http://download.csdn.net/detail/u010417185/9588647

和官方有点不同的是,我直接把四个数据集下载下来,放在/tmp/mnist下,在项目文件中使用以下代码导入:

import input_data
import tensorflow as tf

mnist = input_data.read_data_sets("/tmp/mnist", one_hot=True)

这里的数据集分为两个部分:60000的训练数据集(mnist.train)和10000的测试数据集(mnist.test),测试集的作用是帮助模型泛化。数据对应包含图片和标签,分别用mnist.train.images,mnist.train.lables,mnist.test.images,mnist.test.lables来表示。每张图片有28×28=784个像素点,因此训练图片mnist.train.images的张量表示为 [60000, 784],第一个纬度用于索引图片,第二纬度用于索引像素点。由于判断10个数字,这里采用热独,即one-hot-vectors,除了一位数字为1外其他纬度数字为0。例如判断数字为0则其表示为[1,0,0,0,0,0,0,0,0,0]。因此训练标签表示为[10000,10],第一纬度索引图片,第二纬度判断数字。

二  softmax回归介绍

softmax模型可以给不同的对象分配概率。根据下图,对输入的x的加权求和,再分别加上一个偏置量,最后输入到softmax函数中:

具体转换为公式,即:

三  实现回归模型

首先进行模型的定义,如下:

x = tf.placeholder(tf.float32, [None, 784]) #使用占位符placeholder,第一维度可指定图片的数量是任意的
W = tf.Variable(tf.zeros([784,10]))  #初始化权值
b = tf.Variable(tf.zeros([10]))      #初始化偏置值
y = tf.nn.softmax(tf.matmul(x,W) + b)  #根据公式计算

四  训练模型

选用的损失函数为交叉熵,其定义如下:

其中y为预测的概率分布,y‘为实际分布。

代码如下:

y_ = tf.placeholder("float", [None,10])  #表示实际的分布
cross_entropy = -tf.reduce_sum(y_*tf.log(y))  #计算损失函数
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  #以梯度下降算法最小化损失函数
init = tf.initialize_all_variables()  #初始化所有变量
sess = tf.Session()  #定义会话
sess.run(init)   #初始化会话

for i in range(1000):   #开始训练,循环训练1000次
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

五  评估模型

选用tf.argmax函数评估,它能给出某个tensor对象在某一维上的其数据最大值所在的索引值。由于标签向量是由0,1组成,因此最大值1所在的索引位置就是类别标签,比如tf.argmax(y,1)返回的是模型对于任一输入x预测到的标签值,而 tf.argmax(y_,1) 代表正确的标签,用 tf.equal 来检测预测是否与真实标签匹配(索引位置一样表示匹配)。

代码如下:

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))  #评估
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))  #将结果转换为浮点数
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})  #输出

六  代码

import input_data
import tensorflow as tf

mnist = input_data.read_data_sets("/tmp/mnist", one_hot=True)

x = tf.placeholder(tf.float32, [None, 784]) #使用占位符placeholder,第一维度可指定图片的数量是任意的
W = tf.Variable(tf.zeros([784,10]))  #初始化权值
b = tf.Variable(tf.zeros([10]))      #初始化偏置值
y = tf.nn.softmax(tf.matmul(x,W) + b)  #根据公式计算
y_ = tf.placeholder("float", [None,10])  #表示实际的分布
cross_entropy = -tf.reduce_sum(y_*tf.log(y))  #计算损失函数
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  #以梯度下降算法最小化损失函数
init = tf.initialize_all_variables()  #初始化所有变量
sess = tf.Session()  #定义会话
sess.run(init)   #初始化会话

for i in range(1000):   #开始训练,循环训练1000次
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))  #评估
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))  #将结果转换为浮点数
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})  #输出

七  实验结果

最终测试结果精确度在91%左右。

时间: 2024-10-07 20:27:06

Tensorflow学习笔记(一):MNIST机器学习入门的相关文章

卷积神经网络(CNN)学习笔记1:基础入门

卷积神经网络(CNN)学习笔记1:基础入门 Posted on 2016-03-01   |   In Machine Learning  |   9 Comments  |   14935  Views 概述 卷积神经网络(Convolutional Neural Network, CNN)是深度学习技术中极具代表的网络结构之一,在图像处理领域取得了很大的成功,在国际标准的ImageNet数据集上,许多成功的模型都是基于CNN的.CNN相较于传统的图像处理算法的优点之一在于,避免了对图像复杂的

Tensorflow学习笔记2:About Session, Graph, Operation and Tensor

简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节点之间则是由张量(Tensor)作为边来连接在一起的.所以Tensorflow的计算过程就是一个Tensor流图.Tensorflow的图则是必须在一个Session中来计算.这篇笔记来大致介绍一下Session.Graph.Operation和Tensor. Session Session提供了O

nodejs学习笔记之安装、入门

由于项目需要,最近开始学习nodejs.在学习过程中,记录一些必要的操作和应该注意的点. 首先是如何安装nodejs环境?(我用的是windows 7环境,所以主要是windows 7的例子.如果想看linux下的安装可以参考http://www.cnblogs.com/meteoric_cry/archive/2013/01/04/2844481.html) 1. nodejs提供了一些安装程序,可以去官网(http://nodejs.org/download/)按照自己的机器进行下载,下载完

Node.js学习笔记【1】入门(服务器JS、函数式编程、阻塞与非阻塞、回调、事件、内部和外部模块)

笔记来自<Node入门>@2011 Manuel Kiessling JavaScript与Node.js Node.js事实上既是一个运行时环境,同时又是一个库. 使用Node.js时,我们不仅仅在实现一个应用,同时还实现了整个HTTP服务器. 一个基础的HTTP服务器 server.js:一个可以工作的HTTP服务器 var http = require("http"); http.createServer(function(request, response) { r

TensorFlow学习笔记(UTF-8 问题解决 UnicodeDecodeError: &#39;utf-8&#39; codec can&#39;t decode byte 0xff in position 0: invalid start byte)

我使用VS2013  Python3.5  TensorFlow 1.3  的开发环境 UnicodeDecodeError: 'utf-8' codec can't decode byte 0xff in position 0: invalid start byte 在是使用Tensorflow读取图片文件的情况下,会出现这个报错 代码如下 # -*- coding: utf-8 -*- import tensorflow as tf import numpy as np import mat

Tensorflow学习笔记(对MNIST经典例程的)的代码注释与理解

1 #coding:utf-8 2 # 日期 2017年9月4日 环境 Python 3.5  TensorFlow 1.3 win10开发环境. 3 import tensorflow as tf 4 from tensorflow.examples.tutorials.mnist import input_data 5 import os 6 7 8 # 基础的学习率 9 LEARNING_RATE_BASE = 0.8 10 11 # 学习率的衰减率 12 LEARNING_RATE_DE

tensorflow学习笔记一——just get started

我现在什么都不知道,打算开始学习tensorflow作为机器学习入门的开始. 昨天完成了对tensorflow官方入门介绍的学习,了解了tensorflow的基本原理和编程方法.我们在进行tensorflow编程时,程序的逻辑是:建立数据流图-->初始化变量-->运行程序.下面就每一步进行介绍. 建立数据流图 完成了这一部分的学习,我才了解了tensorflow的意思.在tensorflow中,程序的逻辑可以表示成数据流图,图的节点是一组对tensor(向量或者矩阵)的操作,节点的输出仍是一组

MNIST机器学习入门(一)

一.简介 首先介绍MNIST 数据集.如图1-1 所示, MNIST 数据集主要由一些手写数字的图片和相应的标签组成,图片一共有10 类,分别对应从0-9 ,共10 个阿拉伯数字. 原始的MNIST 数据库一共包含下面4 个文件, 见表1-1 . 在表1 - 1 中,图像数据是指很多张手写字符的图像,图像的标签是指每一张图像实际对应的数字是几,也就是说,在MNIST 数据集中的每一张图像都事先标明了对应的数字.  在MNIST 数据集中有两类图像:一类是训练图像(对应文件train-images

TensFlow框架学习之MNIST机器学习入门

前言:初学TensorFlow和机器学习,MNIST算法的每条语句都不是很清楚,通过查阅资料,将每句代码的基本用法差不多理解了.希望能够帮助正在学习的你 [python] view plain copy <span style="font-size:32px;">from tensorflow.examples.tutorials.mnist  import  input_data mnist = input_data.read_data_sets("MNIST_