TensorFlow实现Softmax回归(模型存储与加载)

 1 # -*- coding: utf-8 -*-
 2 """
 3 Created on Thu Oct 18 18:02:26 2018
 4
 5 @author: zhen
 6 """
 7
 8 from tensorflow.examples.tutorials.mnist import input_data
 9 import tensorflow as tf
10
11 # mn.SOURCE_URL = "http://yann.lecun.com/exdb/mnist/"
12 my_mnist = input_data.read_data_sets("C:/Users/zhen/MNIST_data_bak/", one_hot=True)
13
14 # The MNIST data is split into three parts:
15 # 55,000 data points of training data (mnist.train)
16 # 10,000 points of test data (mnist.test), and
17 # 5,000 points of validation data (mnist.validation).
18
19 # Each image is 28 pixels by 28 pixels
20
21 # 输入的是一堆图片,None表示不限输入条数,784表示每张图片都是一个784个像素值的一维向量
22 # 所以输入的矩阵是None乘以784二维矩阵
23 x = tf.placeholder(dtype=tf.float32, shape=(None, 784))
24 # 初始化都是0,二维矩阵784乘以10个W值
25 W = tf.Variable(tf.zeros([784, 10]))
26 b = tf.Variable(tf.zeros([10]))
27
28 y = tf.nn.softmax(tf.matmul(x, W) + b)
29
30 # 训练
31 # labels是每张图片都对应一个one-hot的10个值的向量
32 y_ = tf.placeholder(dtype=tf.float32, shape=(None, 10))
33 # 定义损失函数,交叉熵损失函数
34 # 对于多分类问题,通常使用交叉熵损失函数
35 # reduction_indices等价于axis,指明按照每行加,还是按照每列加
36 cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),
37                                               reduction_indices=[1]))
38 train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
39
40 # 评估
41
42 # tf.argmax()是一个从tensor中寻找最大值的序号,tf.argmax就是求各个预测的数字中概率最大的那一个
43
44 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
45
46 # 用tf.cast将之前correct_prediction输出的bool值转换为float32,再求平均
47 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
48
49 # 初始化变量
50 sess = tf.InteractiveSession()
51 tf.global_variables_initializer().run()
52 # 创建Saver节点,用于保存训练的模型
53 saver = tf.train.Saver()
54 for i in range(100):
55     batch_xs, batch_ys = my_mnist.train.next_batch(100)
56     sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
57     # 每隔一段时间保存一次中间结果
58     if i % 10 == 0:
59         save_path = saver.save(sess, "C:/Users/zhen/MNIST_data_bak/saver/softmax_middle_model.ckpt")
60
61     # print("TrainSet batch acc : %s " % accuracy.eval({x: batch_xs, y_: batch_ys}))
62     # print("ValidSet acc : %s" % accuracy.eval({x: my_mnist.validation.images, y_: my_mnist.validation.labels}))
63
64 # 测试
65 print("TestSet acc : %s" % accuracy.eval({x: my_mnist.test.images, y_: my_mnist.test.labels}))
66 # 保存最终的模型
67 save_path = saver.save(sess, "C:/Users/zhen/MNIST_data_bak/saver/softmax_final_model.ckpt")
68
69 # 使用训练好的模型直接进行预测
70 with tf.Session() as sess_back:
71     saver.restore(sess_back, "C:/Users/zhen/MNIST_data_bak/saver/softmax_final_model.ckpt")
72     # 评估
73     correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
74     accruary = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
75     # 测试
76     print(accuracy.eval({x : my_mnist.test.images, y_ : my_mnist.test.labels}))
77 # 总结
78 # 1,定义算法公式,也就是神经网络forward时的计算
79 # 2,定义loss,选定优化器,并指定优化器优化loss
80 # 3,迭代地对数据进行训练
81 # 4,在测试集或验证集上对准确率进行评测

结果:

  

解析:

  把训练好的模型存储落地磁盘,有利于多次使用和共享,也便于当训练出现异常时能恢复模型而不是重新训练!

原文地址:https://www.cnblogs.com/yszd/p/9822365.html

时间: 2024-12-24 18:28:50

TensorFlow实现Softmax回归(模型存储与加载)的相关文章

转 tensorflow模型保存 与 加载

使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我们可能也需要用到别人训练好的模型,并在这个基础上再次训练.这时候我们需要掌握如何操作这些模型数据.看完本文,相信你一定会有收获! 1 Tensorflow模型文件 我们在checkpoint_dir目录下保存的文件结构如下: |--checkpoint_dir | |--checkpoint | |--MyModel.meta | |--MyModel.data-00000-of-00001 | |--MyModel.in

linux平台学x86汇编(十二):字符串的存储与加载

[版权声明:尊重原创,转载请保留出处:blog.csdn.net/shallnet,文章仅供学习交流,请勿用于商业用途] 字符串的存储与加载是指,将字符串的值加载到寄存器和将其传回内存位置中.其使用指令lods指令和stos指令. lods指令用于把内存中的字符串值传送到eax寄存器中,该指令有三种不同格式:lodsb(1字节).lodsw(2字节).lodsl(4字节).lods指令使用esi寄存器作为隐含的源操作数.esi寄存器必须包含要加载的字符串所在的内存地址. 在使用lods指令把字符

TensorFlow的模型保存与加载

import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf #tensorboard --logdir="./" def linearregression(): with tf.variable_scope("original_data"): X = tf.random_normal([100,1],mean=0.0,stddev=1.0) y_true = tf.matmul

字典转模型和懒加载

1.字典转模型 创建一个类,继承自NSObject,属性名和字典的键一致 可以实现字典转模型 @implementation TZMessage +(instancetype)messageWithDict(NSDictioary*)dict{ TZMessage *mode = [[TZMessage alloc] init]; [mode setValuesForKeysWithDictonary:dict]; return mode; } 2.懒加载 +(NSArray *)message

第 17 章 存储与加载本地文件

请参考教材,全面理解和完成本章节内容... ... 复制工程ch16,将工程目录改名为ch17. 在手机上完全退出你的"陋习手记"App(不是把应用隐藏起来),再重新执行"陋习手记"App,哇!我的之前的手记哪里去了? 几乎所有应用都需要有个地方存储数据.本章,我们将升级CriminalIntent应用,实现保存并加载存储在设备上的JSON文件数据. Android设备上的所有应用都有一个放置在沙盒中的文件目录.将文件保存在沙盒中可阻止其他应用的访问.甚至是其他用户

Tensorflow学习第1课——从本地加载MNIST以及FashionMNIST数据

很多Tensorflow第一课的教程都是使用MNIST或者FashionMNIST数据集作为示例数据集,但是其给的例程基本都是从网络上用load_data函数直接加载,该函数封装程度比较高,如果网络出现问题,数据集很难实时从网上下载(笔者就多次遇到这种问题,忍无可忍),而且数据是如何解码的也一无所知,不利于后续的学习和理解,因此本文主要介绍对下载到本地的MNIST或FashionMNIST数据集如何加载解析的问题. 下载到本地的数据集一般有两种格式:numpy的压缩格式.npz,以及gzip压缩

C++数据文件存储与加载(利用opencv)

首先请先确认已经安装好了opencv3及以上版本. #include <opencv2/opencv.hpp>#include <iostream>#include <string>using namespace cv;using namespace std;存储then int main(){//创造一些要存的数据先 string words = "hello, my guys!"; float n = 3.1415926; Mat m = Mat

15 数据在安卓设备上的存储,加载

存储和加载: public void ConnectToSqlite (string DBName) { //判断名字是否规范,如果不规范就加上后缀 if (!DBName.Contains (".sqlite")) { DBName += ".sqlite"; } //如果运行在编辑器中 #if UNITY_EDITOR //获取路径 sqlitePath = "Data Source =" + Application.streamingAss

tensorflow 之模型的保存与加载(一)

怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. 1 #!/usr/bin/env python3 2 #-*- coding:utf-8 -*- 3 ############################ 4 #File Name: saver.py 5 #Brief: 6 #Author: frank 7 #Mail: [email protected] 8 #Created Time:2018-06-22 22:12:52 9 ###