TensorFlow基础入门(四)

注意:本部分的ppt来源于中国大学mooc网站:https://www.icourse163.org/learn/ZUCC-1206146808?tid=1206445215&from=study#/learn/content?type=detail&id=1211168244&cid=1213754001

#MNIST手写数字识别数据集
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import numpy as np
import matplotlib.pyplot as plt

mnist=input_data.read_data_sets("MNST_data/",one_hot=True)
#了解MNIST手写数字识别数据集
print("训练集train数量:",mnist.train.num_examples,
      ",验证集 validation数量:",mnist.validation.num_examples,
      ",测试集 test 数量:",mnist.test.num_examples)
print("train image shape:",mnist.train.images.shape,
      "labels shape:",mnist.train.labels.shape)

全部源码:

#MNIST手写数字识别数据集
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import numpy as np
import matplotlib.pyplot as plt
import os
#读取相关的数据
mnist=input_data.read_data_sets("MNST_data/",one_hot=True)
#定义待输入数据的占位符
#mnist中每张图片共有28*28=784个像素点
x=tf.placeholder(tf.float32,[None,784],name="X")
#0-9一共10个数字====》10个类别
y=tf.placeholder(tf.float32,[None,10],name="y")
#定义模型变量
‘‘‘
在本案例中,以正态分布的随机数初始化权重W,以常数0初始化偏置b
‘‘‘
#定义变量
w=tf.Variable(tf.random_normal([784,10]),name="w")
b=tf.Variable(tf.zeros([10]),name="b")
#用单个神经元构建神经网络
forward=tf.matmul(x,w)+b#前向计算
pred=tf.nn.softmax(forward)#softmax分类
#设置训练参数
train_epochs=100#训练轮数
batch_size=100#单次训练样本数(批次大小)
total_batch=int(mnist.train.num_examples/batch_size)#一轮训练有多少批次
display_step=1#显示粒度
learning_rate=0.01#学习率
#定义损失函数(定义交叉商的损失函数)
loss_function=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))
#梯度下降优化器
optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)
#检查预测类别tf.argmax(ored,1)与实际类别tf.argmax(y,1)的匹配情况
correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
#准确率,将布尔值转化为浮点数,并计算平均值
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

sess=tf.Session()#声明会话
init=tf.global_variables_initializer()#变量初始化
sess.run(init)

#训练模型的保存
#储存模型的粒子
save_step=5
#创建保存模型文件的目录
ckpt_dir="./ckpt_dir/"
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)

#声明完所有 变量之后,使用tf.train.Saver()
saver=tf.train.Saver()

#模型训练
#开始训练
for epoch in range(train_epochs):
    for batch in range(total_batch):
        xs,ys=mnist.train.next_batch(batch_size)#读取批次数据
        sess.run(optimizer,feed_dict={x:xs,y:ys})#执行批次训练
    #total_batch个批次训练完成后,使用验证数据计算误差与准确率:验证没有分批
    loss,acc=sess.run([loss_function,accuracy],feed_dict={x:mnist.validation.images,y:mnist.validation.labels})

    #打印训练过程中的详细信息
    if(epoch+1)%display_step==0:
        print("Train Epoch:",‘%02d‘%(epoch+1),"Loss=","{:.9}".format(loss),"Accuracy=","{:.4f}".format(acc))
    if(epoch+1)%save_step==0:
        saver.save(sess,os.path.join(ckpt_dir,‘mnist_h256_model_{:06d}.ckpt‘.format(epoch+1)))
        print(‘mnist_h256_model_{:06d}.ckpt‘.format(epoch+1))
#对训练的模型进行保存
saver.save(sess,os.path.join(ckpt_dir,‘mnist_h256_model_ckpt‘))
print("Train Finished")

#评估模型
#完成训练之后,在测试集上评估模型的准确率
def accu_test():
    accu_test=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
    print("Test Accuracy:",accu_test)

def acc_validation():
    #完成训练之后在验证集上评估模型的准确率
    acc_validation=sess.run(accuracy,feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
    print("Validation Accuracy:",acc_validation)

def acc_train():
    #完成训练之后,在训练集上评估模型的准确率
    acc_train=sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels})
    print("Train Accuracy:",acc_train)

#定义数据可视化
def plot_image_labels_prediction(images,labels,prediction,index,num=10):
    ‘‘‘
    image:图像列表
    labels:标签列表
    prediction:预测值列表
    index:从第index个开始显示
    num:一次显示多少副图片,缺省的话一次显示10个
    ‘‘‘
    fig=plt.gcf()#获取当前图表,Get Current Figure
    fig.set_size_inches(10,12)#1英寸等于1.54cm
    if num>25:
        num=25#设置最多显示25个子图
    for i in range(0,num):
        ax=plt.subplot(5,5,i+1)#获取当前要处理的子图
        ax.imshow(np.reshape(images[index],(28,28)),cmap="binary")
        title="label="+str(np.argmax(labels[index]))#构建该图上要显示的title信息
        if len(prediction)>0:
            title+=",predict="+str(prediction[index])

        ax.set_title(title,fontsize=10)#显示图上的title信息
        ax.set_xticks([])#不显示坐标轴
        ax.set_yticks([])
        index+=1
    plt.show()

独热标码:

#MNIST手写数字识别数据集
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import numpy as np
import matplotlib.pyplot as plt

mnist=input_data.read_data_sets("MNST_data/",one_hot=True)
#独热编码如何取值
print(mnist.train.labels[1])
#argmax()取出独热编码中最大值的下标
print(np.argmax(mnist.train.labels[1]))

原文地址:https://www.cnblogs.com/byczyz/p/12079660.html

时间: 2024-10-08 17:07:08

TensorFlow基础入门(四)的相关文章

C#基础入门 四

C#基础入门 四 方法参数 值参数:不附加任何修饰符: 输出参数:以out修饰符声明,可以返回一个或多个给调用者: 如果想要一个方法返回多个值,可以用输出参数来处理,输出参数由out关键字标识,如static void Car(out int x,out int y,int z){},与引用参数区别在于:调用方法前无需对输出参数进行初始化,输出型参数用于传递方法返回的数值. 计算矩形面积的方法:(图8) static void rectangle(int length,int width, ou

Python基础入门 (四)

一.迭代器&生成器 1.迭代器仅仅是一容器对象,它实现了迭代器协议.它有两个基本方法: 1)next 方法 返回容器的下一个元素 2)_iter_方法 返回迭代器自身.迭代器可以使用内建的iter方法创建 ts = iter(['asd','sds','qweq']) #创建iter方法 print(ts.__next__()) #使用_next_方法返回下一个元素 print(ts.__next__()) print(ts.__next__()) #运行结果 asd sds qweq#需要注意

TensorFlow基础入门(六)--基础总结(TensorFlow框架基础)

本节课目标:搭建神经网络,总结搭建八股 一.基本概念 基于TensorFlow的NN:用张量表示数据,用计算图搭建神经网络,用会话执行计算图,优化线上的权重(参数),得到模型 张量:张量就是多维数组(列表),用阶表示张量的维度. 0阶称为标量,表示一个单独的数                        举例S=123 1阶张量称作向量,表示一个一维数组              举例 V=[1,2,3] 2阶张量称作矩阵,表示一个二维数组,它可以有i行j列个元素,每个元素可以用行号和列号共同

React.js 基础入门四--要点总结

JSX语法,像是在Javascript代码里直接写XML的语法,实质上这只是一个语法糖,每一个XML标签都会被JSX转换工具转换成纯Javascript代码,React 官方推荐使用JSX, 当然你想直接使用纯Javascript代码写也是可以的,只是使用JSX,组件的结构和组件之间的关系看上去更加清晰. 1. HTML 标签 和 React 组件 在JS中写HTML标签,也许小伙伴们都惊呆了,那么React又是怎么区分HTML标签,React组件标签? HTML标签: var myDivEle

芝麻HTTP:TensorFlow基础入门

本篇内容基于 Python3 TensorFlow 1.4 版本. 本节内容 本节通过最简单的示例 -- 平面拟合来说明 TensorFlow 的基本用法. 构造数据 TensorFlow 的引入方式是: ?import tensorflow as tf 接下来我们构造一些随机的三维数据,然后用 TensorFlow 找到平面去拟合它,首先我们用 Numpy 生成随机三维点,其中变量 x 代表三维点的 (x, y) 坐标,是一个 2×100 的矩阵,即 100 个 (x, y),然后变量 y 代

TensorFlow基础入门(五)--单隐层与双隐层的神经网络结构

注意:本部分的ppt来源于中国大学mooc网站:https://www.icourse163.org/learn/ZUCC-1206146808?tid=1206445215&from=study#/learn/content?type=detail&id=1211168244&cid=1213754001 原文地址:https://www.cnblogs.com/byczyz/p/12079731.html

Android基础入门教程——8.3.7 Paint API之—— Xfermode与PorterDuff详解(四)

Android基础入门教程--8.3.7 Paint API之-- Xfermode与PorterDuff详解(四) 标签(空格分隔): Android基础入门教程 本节引言: 上节我们写了关于Xfermode与PorterDuff使用的第一个例子:圆角&圆形图片ImageView的实现, 我们体会到了PorterDuff.Mode.DST_IN给我们带来的好处,本节我们继续来写例子练练手, 还记得Android基础入门教程--8.3.2 绘图类实战示例给大家带来的拔掉美女衣服的实现吗? 当时我

FPGA基础入门篇(四) 边沿检测电路

FPGA基础入门篇(四)--边沿检测电路 一.边沿检测 边沿检测,就是检测输入信号,或者FPGA内部逻辑信号的跳变,即上升沿或者下降沿的检测.在检测到所需要的边沿后产生一个高电平的脉冲.这在FPGA电路设计中相当的广泛. 没有复位的情况下,正常的工作流程如下: (1)D触发器经过时钟clk的触发,输出trigger信号,保存了t0时刻的信号. (2)同时由trigger通过非门输出信号,保留了当前时刻t1的触发信号 (3)经过与门输出信号pos_edge,neg_edge a) 只有t0时刻为高

Linux从入门到放弃、零基础入门Linux(第四篇):在虚拟机vmware中安装centos7.7

如果是新手,建议安装带图形化界面的centos,这里以安装centos7.7的64位为例 一.下载系统镜像 镜像文件下载链接https://wiki.centos.org/Download 阿里云官网:https://mirrors.aliyun.com 现更新为:https://opsx.alibaba.com/mirror 清华软件镜像:https://mirrors.tuna.tsinghua.edu.cn/ 都可以, 下载centos7.7的64位版本镜像文件种子,然后用下载软件下载即可