LSTM用于MNIST手写数字图片分类

按照惯例,先贴代码

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

#载入数据集
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)

# 输入图片是28*28
n_inputs = 28 #输入一行,一行有28个数据
max_time = 28 #一共28行
lstm_size = 100 #隐层单元
n_classes = 10 # 10个分类
batch_size = 50 #每批次50个样本
n_batch = mnist.train.num_examples // batch_size #计算一共有多少个批次

#这里的none表示第一个维度可以是任意的长度
x = tf.placeholder(tf.float32,[None,784])
#正确的标签
y = tf.placeholder(tf.float32,[None,10])

#初始化权值
weights = tf.Variable(tf.truncated_normal([lstm_size, n_classes], stddev=0.1))
#初始化偏置值
biases = tf.Variable(tf.constant(0.1, shape=[n_classes]))

#定义RNN网络
def RNN(X,weights,biases):
    # inputs=[batch_size, max_time, n_inputs]
    inputs = tf.reshape(X,[-1,max_time,n_inputs])
    #定义LSTM基本CELL
    lstm_cell = tf.nn.rnn_cell.LSTMCell(lstm_size)
    #lstm_cell = tf.contrib.rnn.LSTMCell(lstm_size, name=‘basic_lstm_cell‘)
    # final_state[0]是cell state
    # final_state[1]是hidden_state
    outputs,final_state = tf.nn.dynamic_rnn(lstm_cell,inputs,dtype=tf.float32)
    results = tf.nn.softmax(tf.matmul(final_state[1],weights) + biases)
    return results

#计算RNN的返回结果
prediction= RNN(x, weights, biases)
#损失函数
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y))
#使用AdamOptimizer进行优化
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
#结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))#把correct_prediction变为float32类型
#初始化
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(50):
        for batch in range(n_batch):
            batch_xs,batch_ys =  mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})

        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print ("Iter " + str(epoch) + ", Testing Accuracy= " + str(acc))

原文地址:https://www.cnblogs.com/yqpy/p/11227922.html

时间: 2024-10-05 09:55:48

LSTM用于MNIST手写数字图片分类的相关文章

MNIST手写数字图片识别(线性回归、CNN方法的手工及框架实现)(未完待续)

0-Background 作为Deep Learning中的Hello World 项目无论如何都要做一遍的. 代码地址:Github 练习过程中将持续更新blog及代码. 第一次写博客,很多地方可能语言组织不清,请多多提出意见..谢谢~ 0.1 背景知识: Linear regression CNN LeNet-5 AlexNet ResNet VGG 各种regularization方式 0.2 Catalog 1-Prepare 2-MNIST 3-LinearRegression 1-P

一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)

笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&locationNum=5 Tensorflow官方英文文档地址:https://www.tensorflow.org/get_started/mnist/beginners 本文整理时官方文档最近更新时间:2017年2月15日 1.案例背景 本文是跟着Tensorflow官方文档的第二篇教程–识别手

Tensorflow实践 mnist手写数字识别

minst数据集                                         tensorflow的文档中就自带了mnist手写数字识别的例子,是一个很经典也比较简单的入门tensorflow的例子,非常值得自己动手亲自实践一下.由于我用的不是tensorflow中自带的mnist数据集,而是从kaggle的网站下载下来的,数据集有些不太一样,所以直接按照tensorflow官方文档上的参数训练的话还是踩了一些坑,特此记录. 首先从kaggle网站下载mnist数据集,一份是

tensorflow 基础学习五:MNIST手写数字识别

MNIST数据集介绍: from tensorflow.examples.tutorials.mnist import input_data # 载入MNIST数据集,如果指定地址下没有已经下载好的数据,tensorflow会自动下载数据 mnist=input_data.read_data_sets('.',one_hot=True) # 打印 Training data size:55000. print("Training data size: {}".format(mnist.

基于MNIST手写数字数据集的数字识别小程序

30行代码奉上!(MNIST手写数字的识别,识别率大约在91%,简单尝试的一个程序,小玩具而已) 1 import tensorflow.examples.tutorials.mnist.input_data as input_data 2 import tensorflow as tf 3 mnist = input_data.read_data_sets('/temp/', one_hot=True) 4 5 #设置 6 x = tf.placeholder(tf.float32,[None

简单HOG+SVM mnist手写数字分类

使用工具 :VS2013 + OpenCV 3.1 数据集:minst 训练数据:60000张 测试数据:10000张 输出模型:HOG_SVM_DATA.xml 数据准备 train-images-idx3-ubyte.gz:  training set images (9912422 bytes) train-labels-idx1-ubyte.gz:  training set labels (28881 bytes) t10k-images-idx3-ubyte.gz:   test s

MNIST手写数字数据库

手写数字库很容易建立,但是总会很浪费时间.Google实验室的Corinna Cortes和纽约大学柯朗研究所的Yann LeCun建有一个手写数字数据库,训练库有60,000张手写数字图像,测试库有10,000张. 请访问原站 http://yann.lecun.com/exdb/mnist/ 该数据库在一个文件中包含了所有图像,使用起来有所不便.如果我把每个图像分别保存,成了图像各自独立的数据库. 并在Google Code中托管. 如果你有需要,欢迎在此下载: http://yann.le

Pytorch入门实战一:LeNet神经网络实现 MNIST手写数字识别

记得第一次接触手写数字识别数据集还在学习TensorFlow,各种sess.run(),头都绕晕了.自从接触pytorch以来,一直想写点什么.曾经在2017年5月,Andrej Karpathy发表的一片Twitter,调侃道:l've been using PyTorch a few months now, l've never felt better, l've more energy.My skin is clearer. My eye sight has improved.确实,使用p

MNIST手写数字分类simple版(03-2)

simple版本nn模型 训练手写数字处理 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 mnist=input_data.read_data_sets("MNIST_data", one_hot=True) #每个批次的大小 batch_size=100 #计算一共有多少批次 n_batch=mnist.train.num_examples // ba