Tensorflow中使用CNN实现Mnist手写体识别

  本文参考Yann LeCun的LeNet5经典架构,稍加ps得到下面适用于本手写识别的cnn结构,构造一个两层卷积神经网络,神经网络的结构如下图所示:

  输入-卷积-pooling-卷积-pooling-全连接层-Dropout-Softmax输出

  

  第一层卷积利用5*5的patch,32个卷积核,可以计算出32个特征。然后进行maxpooling。第二层卷积利用5*5的patch,64个卷积核,可以计算出64个特征。然后进行max pooling。卷积核的个数是我们自己设定,可以增加卷积核数目提高分类精度,但是那样会增加更大参数,提高计算成本。

  这样输入是分辨率为28*28的图片。利用5*5的patch进行卷积。我们的卷积使用1步长(stride size),0填充模块(zero padded),这样得到的输出和输入是同一个大小。经过第一层卷积之后,卷积特征大小为28*28。然后通过ReLU函数激活。我们的pooling用简单传统的2x2大小的模板做max pooling,这样pooling后得到14*14大小的特征。经过第二层卷积后,卷积特征大小为14*14,然后通过ReLU函数激活,再经过pooling后得到特征大小为7*7。

  现在,图片尺寸减小到7x7,我们加入一个有1024个神经元的全连接层,用于处理整个图片。我们把池化层输出的张量展开成一些向量,乘上权重矩阵,加上偏置,然后对其使用ReLU。

  为了避免过拟合,在全连接层输出接上dropout层。Dropout层在训练时屏蔽一半的神经元。

1、输入数据

  直接使用tensorflow中的模块,导入输入数据:

    from tensorflow.examples.tutorials.mnist import input_data

    mnist = input_data.read_data_sets(‘MNIST_data‘, one_hot=True)

  或者使用官方提供的input_data.py文件下载mnist数据

2、启动session

  (1)交互方式启动session

    sess = tf.InteractiveSession()

  (2)一般方式启动session

    sess = tf.Session()

  ps: 使用交互方式不用提前构建计算图,而使用一般方式必须提前构建好计算图才能启动session

3、权重和偏置初始化

  权重初始化的原则:应该加入少量的噪声来打破对称性并且要避免0梯度(初始化为0)

  权重初始化一般选择均匀分布或是正态分布

  定义权重初始化方法

   def weight_variable(shape):
    #截尾正态分布,stddev是正态分布的标准偏差
    initial = tf.truncated_normal(shape=shape, stddev=0.1)
    return tf.Variable(initial)

  定义偏置初始化方法

  def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

4、定义卷积和池化方法

  TensorFlow在卷积和Pooling上有很强的灵活性。我们怎么处理边界?步长应该设多大?在这个实例里,我们的卷积使用1步长(stride size),0填充模块(zero padded),保证输出和输入是同一个大小。我们的pooling用简单传统的2x2大小的模板做maxpooling。为了代码更简洁,我们把这部分抽象成一个函数。

  def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1],  padding=‘SAME‘)
  def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding=‘SAME‘)

5、直接贴完整代码

from tensorflow.examples.tutorials.mnist import input_dataimport tensorflow as tf#加载数据集mnist = input_data.read_data_sets(‘MNIST_data‘, one_hot=True)

#以交互式方式启动session#如果不使用交互式session,则在启动session前必须# 构建整个计算图,才能启动该计算图sess = tf.InteractiveSession()

"""构建计算图"""#通过占位符来为输入图像和目标输出类别创建节点#shape参数是可选的,有了它tensorflow可以自动捕获维度不一致导致的错误x = tf.placeholder("float", shape=[None, 784]) #原始输入y_ = tf.placeholder("float", shape=[None, 10]) #目标值

#为了不在建立模型的时候反复做初始化操作,# 我们定义两个函数用于初始化def weight_variable(shape):    #截尾正态分布,stddev是正态分布的标准偏差    initial = tf.truncated_normal(shape=shape, stddev=0.1)    return tf.Variable(initial)def bias_variable(shape):    initial = tf.constant(0.1, shape=shape)    return tf.Variable(initial)

#卷积核池化,步长为1,0边距def conv2d(x, W):    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding=‘SAME‘)def max_pool_2x2(x):    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],                          strides=[1, 2, 2, 1], padding=‘SAME‘)

"""第一层卷积"""#由一个卷积和一个最大池化组成。滤波器5x5中算出32个特征,是因为使用32个滤波器进行卷积#卷积的权重张量形状是[5, 5, 1, 32],1是输入通道的个数,32是输出通道个数W_conv1 = weight_variable([5, 5, 1, 32])#每一个输出通道都有一个偏置量b_conv1 = bias_variable([32])

#位了使用卷积,必须将输入转换成4维向量,2、3维表示图片的宽、高#最后一维表示图片的颜色通道(因为是灰度图像所以通道数维1,RGB图像通道数为3)x_image = tf.reshape(x, [-1, 28, 28, 1])

#第一层的卷积结果,使用Relu作为激活函数h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1))#第一层卷积后的池化结果h_pool1 = max_pool_2x2(h_conv1)

"""第二层卷积"""W_conv2 = weight_variable([5, 5, 32, 64])b_conv2 = bias_variable([64])h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)h_pool2 = max_pool_2x2(h_conv2)

"""全连接层"""#图片尺寸减小到7*7,加入一个有1024个神经元的全连接层W_fc1 = weight_variable([7*7*64, 1024])b_fc1 = bias_variable([1024])#将最后的池化层输出张量reshape成一维向量h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])#全连接层的输出h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

"""使用Dropout减少过拟合"""#使用placeholder占位符来表示神经元的输出在dropout中保持不变的概率#在训练的过程中启用dropout,在测试过程中关闭dropoutkeep_prob = tf.placeholder("float")h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

"""输出层"""W_fc2 = weight_variable([1024, 10])b_fc2 = bias_variable([10])#模型预测输出y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

#交叉熵损失cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))

#模型训练,使用AdamOptimizer来做梯度最速下降train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

#正确预测,得到True或False的Listcorrect_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1))#将布尔值转化成浮点数,取平均值作为精确度accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

#在session中先初始化变量才能在session中调用sess.run(tf.initialize_all_variables())

#迭代优化模型for i in range(20000):    #每次取50个样本进行训练    batch = mnist.train.next_batch(50)    if i%100 == 0:        train_accuracy = accuracy.eval(feed_dict={            x: batch[0], y_:batch[1], keep_prob:1.0}) #模型中间不使用dropout        print("step %d, training accuracy %g" % (i, train_accuracy))    train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5})print("test accuracy %g" % accuracy.eval(feed_dict={            x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))6、input_data.py文件  注:python3中没有xrange,其range与python2中的xrange作用相同
#!/urs/bin/env python# -*- coding:utf-8 -*-# Copyright 2015 Google Inc. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at##     http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# =============================================================================="""Functions for downloading and reading MNIST data."""from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport gzipimport osimport tensorflow.python.platformimport numpyimport urllibimport tensorflow as tf

SOURCE_URL = ‘http://yann.lecun.com/exdb/mnist/‘

def maybe_download(filename, work_directory):    """Download the data from Yann‘s website, unless it‘s already here."""    if not os.path.exists(work_directory):        os.mkdir(work_directory)    filepath = os.path.join(work_directory, filename)    if not os.path.exists(filepath):        filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)        statinfo = os.stat(filepath)        print(‘Successfully downloaded‘, filename, statinfo.st_size, ‘bytes.‘)    return filepath

def _read32(bytestream):    dt = numpy.dtype(numpy.uint32).newbyteorder(‘>‘)    return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]

def extract_images(filename):    """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""    print(‘Extracting‘, filename)    with gzip.open(filename) as bytestream:        magic = _read32(bytestream)        if magic != 2051:            raise ValueError(                ‘Invalid magic number %d in MNIST image file: %s‘ %                (magic, filename))        num_images = _read32(bytestream)        rows = _read32(bytestream)        cols = _read32(bytestream)        buf = bytestream.read(rows * cols * num_images)        data = numpy.frombuffer(buf, dtype=numpy.uint8)        data = data.reshape(num_images, rows, cols, 1)        return data

def dense_to_one_hot(labels_dense, num_classes=10):    """Convert class labels from scalars to one-hot vectors."""    num_labels = labels_dense.shape[0]    index_offset = numpy.arange(num_labels) * num_classes    labels_one_hot = numpy.zeros((num_labels, num_classes))    labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1    return labels_one_hot

def extract_labels(filename, one_hot=False):    """Extract the labels into a 1D uint8 numpy array [index]."""    print(‘Extracting‘, filename)    with gzip.open(filename) as bytestream:        magic = _read32(bytestream)        if magic != 2049:            raise ValueError(                ‘Invalid magic number %d in MNIST label file: %s‘ %                (magic, filename))        num_items = _read32(bytestream)        buf = bytestream.read(num_items)        labels = numpy.frombuffer(buf, dtype=numpy.uint8)        if one_hot:            return dense_to_one_hot(labels)        return labels

class DataSet(object):    def __init__(self, images, labels, fake_data=False, one_hot=False,                 dtype=tf.float32):        """Construct a DataSet.        one_hot arg is used only if fake_data is true.  `dtype` can be either        `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into        `[0, 1]`.        """        dtype = tf.as_dtype(dtype).base_dtype        if dtype not in (tf.uint8, tf.float32):            raise TypeError(‘Invalid image dtype %r, expected uint8 or float32‘ %                            dtype)        if fake_data:            self._num_examples = 10000            self.one_hot = one_hot        else:            assert images.shape[0] == labels.shape[0], (                ‘images.shape: %s labels.shape: %s‘ % (images.shape,                                                       labels.shape))            self._num_examples = images.shape[0]            # Convert shape from [num examples, rows, columns, depth]            # to [num examples, rows*columns] (assuming depth == 1)            assert images.shape[3] == 1            images = images.reshape(images.shape[0],                                    images.shape[1] * images.shape[2])            if dtype == tf.float32:                # Convert from [0, 255] -> [0.0, 1.0].                images = images.astype(numpy.float32)                images = numpy.multiply(images, 1.0 / 255.0)        self._images = images        self._labels = labels        self._epochs_completed = 0        self._index_in_epoch = 0

@property    def images(self):        return self._images

@property    def labels(self):        return self._labels

@property    def num_examples(self):        return self._num_examples

@property    def epochs_completed(self):        return self._epochs_completed

def next_batch(self, batch_size, fake_data=False):        """Return the next `batch_size` examples from this data set."""        if fake_data:            fake_image = [1] * 784            if self.one_hot:                fake_label = [1] + [0] * 9            else:                fake_label = 0            return [fake_image for _ in range(batch_size)], [                fake_label for _ in range(batch_size)]        start = self._index_in_epoch        self._index_in_epoch += batch_size        if self._index_in_epoch > self._num_examples:            # Finished epoch            self._epochs_completed += 1            # Shuffle the data            perm = numpy.arange(self._num_examples)            numpy.random.shuffle(perm)            self._images = self._images[perm]            self._labels = self._labels[perm]            # Start next epoch            start = 0            self._index_in_epoch = batch_size            assert batch_size <= self._num_examples        end = self._index_in_epoch        return self._images[start:end], self._labels[start:end]def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32):    class DataSets(object):        pass    data_sets = DataSets()    if fake_data:        def fake():            return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)        data_sets.train = fake()        data_sets.validation = fake()        data_sets.test = fake()        return data_sets    TRAIN_IMAGES = ‘train-images-idx3-ubyte.gz‘    TRAIN_LABELS = ‘train-labels-idx1-ubyte.gz‘    TEST_IMAGES = ‘t10k-images-idx3-ubyte.gz‘    TEST_LABELS = ‘t10k-labels-idx1-ubyte.gz‘    VALIDATION_SIZE = 5000    local_file = maybe_download(TRAIN_IMAGES, train_dir)    train_images = extract_images(local_file)    local_file = maybe_download(TRAIN_LABELS, train_dir)    train_labels = extract_labels(local_file, one_hot=one_hot)    local_file = maybe_download(TEST_IMAGES, train_dir)    test_images = extract_images(local_file)    local_file = maybe_download(TEST_LABELS, train_dir)    test_labels = extract_labels(local_file, one_hot=one_hot)    validation_images = train_images[:VALIDATION_SIZE]    validation_labels = train_labels[:VALIDATION_SIZE]    train_images = train_images[VALIDATION_SIZE:]    train_labels = train_labels[VALIDATION_SIZE:]    data_sets.train = DataSet(train_images, train_labels, dtype=dtype)    data_sets.validation = DataSet(validation_images, validation_labels,                                   dtype=dtype)    data_sets.test = DataSet(test_images, test_labels, dtype=dtype)    return data_sets
 
时间: 2024-12-21 11:13:19

Tensorflow中使用CNN实现Mnist手写体识别的相关文章

机器学习入门实践——线性回归&非线性回归&mnist手写体识别

把一本<白话深度学习与tensorflow>给啃完了,了解了一下基本的BP网络,CNN,RNN这些.感觉实际上算法本身不是特别的深奥难懂,最简单的BP网络基本上学完微积分和概率论就能搞懂,CNN引入的卷积,池化等也是数字图像处理中比较成熟的理论,RNN使用的数学工具相对而言比较高深一些,需要再深入消化消化,最近也在啃白皮书,争取从数学上把这些理论吃透 当然光学理论不太行,还是得要有一些实践的,下面是三个入门级别的,可以用来辅助对BP网络的理解 环境:win10 WSL ubuntu 18.04

深度学习-mnist手写体识别

mnist手写体识别 Mnist数据集可以从官网下载,网址: http://yann.lecun.com/exdb/mnist/ 下载下来的数据集被分成两部分:55000行的训练数据集(mnist.train)和10000行的测试数据集(mnist.test).每一个MNIST数据单元有两部分组成:一张包含手写数字的图片和一个对应的标签.我们把这些图片设为“xs”,把这些标签设为“ys”.训练数据集和测试数据集都包含xs和ys,比如训练数据集的图片是 mnist.train.images ,训练

CNN练手——手写体识别

# 卷积层的实现函数 def convolutional_layer(input, num_input_channels, filter_size, num_filters, use_pooling=True): # 前两个参数是过滤器的尺寸,第三个参数是输入的通道,第四个参数是输出的通道,也就是过滤器的个数 shape = [filter_size, filter_size, num_input_channels, num_filters] weights = tf.Variable(tf.t

pytorch实现MNIST手写体识别(全连接神经网络)

环境: pytorch1.1 cuda9.0 ubuntu16.04 该网络有3层,第一层input layer,有784个神经元(MNIST数据集是28*28的单通道图片,故有784个神经元).第二层为hidden_layer,设置为500个神经元.最后一层是输出层,有10个神经元(10分类任务).在第二层之后还有个ReLU函数,进行非线性变换. #!/usr/bin/env python # encoding: utf-8 ''' @author: liualex @contact: [em

MNIST数据集手写体识别(CNN实现)

github博客传送门 csdn博客传送门 本章所需知识: 没有基础的请观看深度学习系列视频 tensorflow Python基础 资料下载链接: 深度学习基础网络模型(mnist手写体识别数据集) MNIST数据集手写体识别(CNN实现) import tensorflow as tf import tensorflow.examples.tutorials.mnist.input_data as input_data # 导入下载数据集手写体 mnist = input_data.read

MNIST数据集手写体识别(SEQ2SEQ实现)

github博客传送门 csdn博客传送门 本章所需知识: 没有基础的请观看深度学习系列视频 tensorflow Python基础 资料下载链接: 深度学习基础网络模型(mnist手写体识别数据集) MNIST数据集手写体识别(CNN实现) import tensorflow as tf import tensorflow.examples.tutorials.mnist.input_data as input_data # 导入下载数据集手写体 mnist = input_data.read

MNIST数据集手写体识别(MLP实现)

github博客传送门 csdn博客传送门 本章所需知识: 没有基础的请观看深度学习系列视频 tensorflow Python基础 资料下载链接: 深度学习基础网络模型(mnist手写体识别数据集) MNIST数据集手写体识别(MLP实现) import tensorflow as tf import tensorflow.examples.tutorials.mnist.input_data as input_data # 导入下载数据集手写体 mnist = input_data.read

写个神经网络,让她认得我`(?????)(Tensorflow,opencv,dlib,cnn,人脸识别)

这段时间正在学习tensorflow的卷积神经网络部分,为了对卷积神经网络能够有一个更深的了解,自己动手实现一个例程是比较好的方式,所以就选了一个这样比较有点意思的项目. 项目的github地址:github 喜欢的话就给个Star吧. 想要她认得我,就需要给她一些我的照片,让她记住我的人脸特征,为了让她区分我和其他人,还需要给她一些其他人的照片做参照,所以就需要两组数据集来让她学习,如果想让她多认识几个人,那多给她几组图片集学习就可以了.下面就开始让我们来搭建这个能认识我的"她".

tensorflow 基础学习五:MNIST手写数字识别

MNIST数据集介绍: from tensorflow.examples.tutorials.mnist import input_data # 载入MNIST数据集,如果指定地址下没有已经下载好的数据,tensorflow会自动下载数据 mnist=input_data.read_data_sets('.',one_hot=True) # 打印 Training data size:55000. print("Training data size: {}".format(mnist.