机器学习与Tensorflow(7)——tf.train.Saver()、inception-v3的应用

1. tf.train.Saver()

  • tf.train.Saver()是一个类,提供了变量、模型(也称图Graph)的保存和恢复模型方法。
  • TensorFlow是通过构造Graph的方式进行深度学习,任何操作(如卷积、池化等)都需要operator,保存和恢复操作也不例外。
  • 在tf.train.Saver()类初始化时,用于保存和恢复的save和restore operator会被加入Graph。所以,下列类初始化操作应在搭建Graph时完成。
saver = tf.train.Saver()

TensorFlow的保存和恢复分为两种:

  • 保存和恢复变量
  • 保存和恢复模型

saver.save()保存模型

#举例:
保存一个训练好的手写数据集识别模型
保存在当前路径的net文件夹中

 1 import os
 2 os.environ[‘TF_CPP_MIN_LOG_LEVEL‘] = ‘2‘
 3 import tensorflow as tf
 4 from tensorflow.examples.tutorials.mnist import input_data
 5
 6 #载入数据集
 7 mnist = input_data.read_data_sets(‘MNIST_data‘, one_hot=True)
 8
 9 #每个批次100张照片
10 batch_size = 100
11 #计算一个需要多少个批次
12 n_batch = mnist.train.num_examples // batch_size
13
14 #定义两个placeholder
15 x = tf.placeholder(tf.float32, [None, 784])
16 y = tf.placeholder(tf.float32, [None, 10])
17
18 #创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
19 W = tf.Variable(tf.zeros([784, 10]))
20 b = tf.Variable(tf.zeros([10]))
21 prediction = tf.nn.softmax(tf.matmul(x, W) + b)
22 #代价函数
23 loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
24 #使用梯度下降法
25 train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
26
27 #初始化变量
28 init = tf.global_variables_initializer()
29
30 #结果存放在一个布尔型列表中
31 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
32
33 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
34
35 saver = tf.train.Saver()
36
37 with tf.Session() as sess:
38     sess.run(init)
39     for epoch in range(11):
40         for batch in range(n_batch):
41             batch_xs, batch_ys = mnist.train.next_batch(batch_size)
42             sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys})
43         acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})
44         print(‘Iter = ‘ + str(epoch) +‘, Testing Accuracy = ‘ + str(acc))
45     #保存模型
46     saver.save(sess, ‘net/my_net.ckpt‘)

#保存路径中的文件为:
    checkpoint:保存当前网络状态的文件
    my_net.ckpt.data-00000-of-00001
    my_net.ckpt.index
    my_net.ckpt.meta:保存Graph结构的文件

#关于函数saver.save(),常用的参数就是前三个:
save(
    sess,  # 必需参数,Session对象
    save_path,  # 必需参数,存储路径
    global_step=None,  # 可以是Tensor, Tensor name, 整型数
    latest_filename=None,  # 协议缓冲文件名,默认为‘checkpoint‘,不用管
    meta_graph_suffix=‘meta‘,  # 图文件的后缀,默认为‘.meta‘,不用管
    write_meta_graph=True,  # 是否保存Graph
    write_state=True,  # 建议选择默认值True
    strip_default_attrs=False  # 是否跳过具有默认值的节点

saver.restore()加载已经训练好的模型

#举例:
通过加载刚才保存的训练好的手写数据集识别模型进行手写数据集的识别

 1 import os
 2 os.environ[‘TF_CPP_MIN_LOG_LEVEL‘] = ‘2‘
 3 import tensorflow as tf
 4 from tensorflow.examples.tutorials.mnist import input_data
 5
 6 mnist = input_data.read_data_sets(‘MNIST_data‘, one_hot=True)
 7 batch_size = 100
 8 n_batch = mnist.train.num_examples // batch_size
 9
10 x = tf.placeholder(tf.float32, [None, 784])
11 y = tf.placeholder(tf.float32, [None, 10])
12
13 W = tf.Variable(tf.zeros([784, 10]))
14 b = tf.Variable(tf.zeros([10]))
15 prediction = tf.nn.softmax(tf.matmul(x, W) + b)
16
17 loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
18 train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
19
20 init = tf.global_variables_initializer()
21
22 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
23 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
24
25 saver = tf.train.Saver()
26
27 with tf.Session() as sess:
28     sess.run(init)
29     print(sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels}))
30     saver.restore(sess, ‘net/my_net.ckpt‘)
31     print(sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels}))

#执行结果:

0.098
0.9178

#直接得到的准确率相当低,通过加载训练好的模型,识别准确率大大提升。

2. 下载google图像识别网络inception-v3并查看结构

模型背景:
  Inception(v3) 模型是Google 训练好的最新一个图像识别模型,我们可以利用它来对我们的图像进行识别。

下载地址:
  https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip

文件描述:

  • classify_image_graph_def.pb 文件就是训练好的Inception-v3模型。
  • imagenet_synset_to_human_label_map.txt是类别文件,包含人类标签和uid之间的映射的文件。
  • imagenet_2012_challenge_label_map_proto.pbtxt是包含类号和uid之间的映射的文件。

代码实现

 1 import os
 2 os.environ[‘TF_CPP_MIN_LOG_LEVEL‘] = ‘2‘
 3 import tensorflow as tf
 4 import tarfile
 5 import requests
 6
 7 #inception模型下载地址
 8 inception_pretrain_model_url = ‘http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz‘
 9
10 #inception模型存放地址
11 inception_pretrain_model_dir = ‘inception_model‘
12 if not os.path.exists(inception_pretrain_model_dir):
13     os.makedirs(inception_pretrain_model_dir)
14 #获取文件名,以及文件路径
15 filename = inception_pretrain_model_url.split(‘/‘)[-1]
16 filepath = os.path.join(inception_pretrain_model_dir, filename)
17
18 #下载模型
19 if not os.path.exists(filepath):
20     print(‘download: ‘, filename)
21     r = requests.get(inception_pretrain_model_url, stream=True)
22     with open(filepath, ‘wb‘) as f:
23         for chunk in r.iter_content(chunk_size=1024):
24             if chunk:
25                 f.write(chunk)
26     print(‘finish: ‘, filename)
27 #解压文件
28 tarfile.open(filepath, ‘r:gz‘).extractall(inception_pretrain_model_dir)
29
30 #模型结构存放文件
31 log_dir = ‘inception_log‘
32 if not os.path.exists(log_dir):
33     os.makedirs(log_dir)
34
35 #classify_image_graph_def.pb为google训练好的模型
36 inception_graph_def_file = os.path.join(inception_pretrain_model_dir, ‘classify_image_graph_def.pb‘)
37 with tf.Session() as sess:
38     #创建一个图来存放google训练好的模型
39     with tf.gfile.FastGFile(inception_graph_def_file, ‘rb‘) as f:
40         graph_def = tf.GraphDef()
41         graph_def.ParseFromString(f.read())
42         tf.import_graph_def(graph_def, name=‘‘)
43     #保存图的结构
44     writer = tf.summary.FileWriter(log_dir, sess.graph)
45     writer.close()

#在下载过程中,下的特别慢,不知道是网络原因还是什么
#程序总卡着不动
#所以我就手动下载压缩包并进行解压

下载结果

3. 使用inception-v3做各种图像的识别

#代码实现:

 1 import os
 2 os.environ[‘TF_CPP_MIN_LOG_LEVEL‘] = ‘2‘
 3 import tensorflow as tf
 4 import numpy as np
 5 import re
 6 from PIL import Image
 7 import matplotlib.pyplot as plt
 8
 9 #这部分是对标签号和类别号文件进行一个预处理
10
11 class NodeLookup(object):
12     def __init__(self):
13         label_lookup_path = ‘inception_model/imagenet_2012_challenge_label_map_proto.pbtxt‘
14         uid_lookup_path = ‘inception_model/imagenet_synset_to_human_label_map.txt‘
15         self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
16     def load(self, label_lookup_path, uid_lookup_path):
17         #加载分类字符串n********对应分类名称的文件
18         proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
19         uid_to_human={}
20         #一行一行读取数据
21         for line in proto_as_ascii_lines:
22             #去掉换行符
23             line = line.strip(‘\n‘)
24             #按照‘\t’进行分割
25             parsed_items = line.split(‘\t‘)
26             #获取分类编号
27             uid = parsed_items[0]
28             #获取分类名称
29             human_string = parsed_items[1]
30             #保存编号字符串n********与分类名称的映射关系
31             uid_to_human[uid] = human_string
32
33         #加载分类字符串n********对应分类编号1-1000的文件
34         proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
35         node_id_to_uid = {}
36         for line in proto_as_ascii:
37             if line.startswith(‘  target_class:‘):
38                 #获取分类编号1-1000
39                 target_class = int(line.split(‘: ‘)[1])
40             if line.startswith(‘  target_class_string:‘):
41                 #获取编号字符串nn********
42                 target_class_string = line.split(‘: ‘)[1]
43                 # 保存分类编号1-1000与编号字符串n********映射关系
44                 node_id_to_uid[target_class] = target_class_string[1:-2]
45         # 建立分类编号1-1000对应分类名称的映射关系
46         node_id_to_name = {}
47         for key, val in node_id_to_uid.items():
48             #获取分类名称
49             name = uid_to_human[val]
50             # 建立分类编号1-1000到分类名称的映射关系
51             node_id_to_name[key] = name
52         return node_id_to_name
53     # 传入分类编号1-1000返回分类名称
54     def id_to_string(self, node_id):
55         if node_id not in self.node_lookup:
56             return ‘‘
57         return self.node_lookup[node_id]
58
59 #创建一个图来存放google训练好的模型
60
61 with tf.gfile.FastGFile(‘inception_model/classify_image_graph_def.pb‘, ‘rb‘) as f:
62     graph_def = tf.GraphDef()
63     graph_def.ParseFromString(f.read())
64     tf.import_graph_def(graph_def, name=‘‘)
65
66 with tf.Session() as sess:
67     softmax_tensor = sess.graph.get_tensor_by_name(‘softmax:0‘)
68     #遍历目录
69     for root, dirs, files in os.walk(‘images/‘):
70         for file in files:
71             #载入图片
72             image_data = tf.gfile.FastGFile(os.path.join(root, file), ‘rb‘).read()
73             predictions = sess.run(softmax_tensor, {‘DecodeJpeg/contents:0‘: image_data})#图片格式是jpg格式
74             predictions = np.squeeze(predictions)#把结果转为1维数据
75
76             #打印图片路径及名称
77             image_path = os.path.join(root, file)
78             print(image_path)
79
80             # 显示图片
81             img = Image.open(image_path)
82             plt.imshow(img)
83             plt.axis(‘off‘)
84             plt.show()
85
86             #排序
87             top_k = predictions.argsort()[-5:][::-1]
88             node_lookup = NodeLookup()
89             for node_id in top_k:
90                 # 获取分类名称
91                 human_string = node_lookup.id_to_string(node_id)
92                 # 获取该分类的置信度
93                 score = predictions[node_id]
94                 print(‘%s(score = %.5f)‘ % (human_string, score))
95             print()

#执行结果:

images/1.jpg
giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca(score = 0.87265)
badger(score = 0.00260)
lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens(score = 0.00205)
brown bear, bruin, Ursus arctos(score = 0.00102)
ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus(score = 0.00099)

images/2.jpg
French bulldog(score = 0.94474)
bull mastiff(score = 0.00559)
pug, pug-dog(score = 0.00352)
Staffordshire bullterrier, Staffordshire bull terrier(score = 0.00165)
boxer(score = 0.00116)

images/3.jpg
zebra(score = 0.94011)
tiger, Panthera tigris(score = 0.00080)
pencil box, pencil case(score = 0.00066)
hartebeest(score = 0.00059)
tiger cat(score = 0.00042)

images/4.jpg
hare(score = 0.87019)
wood rabbit, cottontail, cottontail rabbit(score = 0.04802)
Angora, Angora rabbit(score = 0.00612)
wallaby, brush kangaroo(score = 0.00181)
fox squirrel, eastern fox squirrel, Sciurus niger(score = 0.00056)

images/5.jpg
fox squirrel, eastern fox squirrel, Sciurus niger(score = 0.95047)
marmot(score = 0.00265)
mongoose(score = 0.00217)
weasel(score = 0.00201)
mink(score = 0.00199)

原文地址:https://www.cnblogs.com/guoruxin/p/10238018.html

时间: 2024-10-04 17:08:55

机器学习与Tensorflow(7)——tf.train.Saver()、inception-v3的应用的相关文章

跟我学算法- tensorflow模型的保存与读取 tf.train.Saver()

save =  tf.train.Saver() 通过save. save() 实现数据的加载 通过save.restore() 实现数据的导出 第一步: 数据的载入 import tensorflow as tf #创建变量 v1 = tf.Variable(tf.random_normal([1, 2], name='v1')) v2 = tf.Variable(tf.random_normal([2, 3], name='v2')) #初始化变量 init_op = tf.global_v

图融合之加载子图:Tensorflow.contrib.slim与tf.train.Saver之坑

import tensorflow as tf import tensorflow.contrib.slim as slim import rawpy import numpy as np import tensorflow as tf import struct import glob import os from PIL import Image import time __sony__ = 0 __huawei__ = 1 __blackberry__ = 2 __stage_raw2ra

TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model

TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model Checkmate is designed to be a simple drop-in solution for a very common Tensorflow use-case: keeping track of the best model checkpoints during training. The BestCheckpointSaver is a wrapper arou

tf.train.Saver()-tensorflow中模型的保存及读取

作用:训练网络之后保存训练好的模型,以及在程序中读取已保存好的模型 使用步骤: 实例化一个Saver对象 saver = tf.train.Saver() 在训练过程中,定期调用saver.save方法,像文件夹中写入包含当前模型中所有可训练变量的checkpoint文件 saver.save(sess,FLAGG.train_dir,global_step=step) 之后可以使用saver.restore()方法,重载模型的参数,继续训练或者用于测试数据 saver.restore(sess

tensorflow-训练检查点tf.train.Saver

#!/usr/bin/env python2 # -*- coding: utf-8 -*- """ Created on Thu Sep 6 10:16:37 2018 @author: myhaspl @email:[email protected] """ import tensorflow as tf g1=tf.Graph() with g1.as_default(): with tf.name_scope("input_Va

tf.train.Saver()

1. 实例化对象 saver = tf.train.Saver(max_to_keep=1) max_to_keep: 表明保存的最大checkpoint文件数.当一个新文件创建的时候,旧文件就会被删掉.如果值为None或0, 表示保存所有的checkpoint文件.默认值5(也就是说,保存最近的5个checkpoint文件). keep_checkpoint_every_n_hour: 除了保存最近的max_to_keep_checkpoint文件,你还可能想每训练N小时保存一个checkpo

TF:利用TF的train.Saver载入曾经训练好的variables(W、b)以供预测新的数据

import tensorflow as tf import numpy as np W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights") b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases") saver = tf.train.Saver() with tf.

TF:利用TF的train.Saver将训练好的variables(W、b)保存到指定的index、meda文件

import tensorflow as tf import numpy as np W = tf.Variable([[2,1,8],[1,2,5]], dtype=tf.float32, name='weights') b = tf.Variable([[1,2,5]], dtype=tf.float32, name='biases') init= tf.global_variables_initializer() saver = tf.train.Saver() with tf.Sessi

机器学习之TensorFlow介绍

TensorFlow的概念很简单:使用python定义一个计算图,然后TensorFlow根据计算图生成高性能的c++代码. 如上图所示,使用图的方式实现了函数$f(x,y)=x^2y+y+2$的计算,在图中可以定义操作符和输入输出变量,基于此特性,TensorFlow能够实现分布式的计算,可以实现大量特征和实例的训练任务. 上图,显示了多个GPU计算的过程,TensorFlow有一下几个优点: 支持多平台,Windows, Linux,macOS,iOS,Android 提供了简单的pytho