使用tensorflow中的Dataset来读取制作好的tfrecords文件

上一篇我写了如何给自己的图像集制作tfrecords文件,现在我们就来讲讲如何读取已经创建好的文件,我们使用的是Tensorflow中的Dataset来读取我们的tfrecords,网上很多帖子应该是很久之前的了,绝大多数的做法是,先将tfrecords序列化成一个队列,然后使用TFRecordReader这个函数进行解析,解析出来的每一行都是一个record,然后再将每一个record进行还原,但是这个函数你在使用的时候会报出异常,原因就是它已经被dataset中新的读取方式所替代,下个版本中可能就无法使用了,因此不建议大家使用这个函数,好了,下面就来看看是如何进行读取的吧。

 1 import tensorflow as tf
 2 import matplotlib.pyplot as plt
 3
 4 #定义可以一次获得多张图像的函数
 5 def show_image(image_dir):
 6     plt.imshow(image_dir)
 7     plt.axis(‘on‘)
 8     plt.show()
 9
10 #单个record的解析函数
11 def decode_example(example):#,resize_height,resize_width,labels_nums):
12     features=tf.io.parse_single_example(example,features={
13         ‘image_raw‘:tf.io.FixedLenFeature([],tf.string),
14         ‘label‘:tf.io.FixedLenFeature([],tf.int64)
15     })
16     tf_image=tf.decode_raw(features[‘image_raw‘],tf.uint8)#这个其实就是图像的像素模式,之前我们使用矩阵来表示图像
17     tf_image=tf.reshape(tf_image,shape=[224,224,3])#对图像的尺寸进行调整,调整成三通道图像
18     tf_image=tf.cast(tf_image,tf.float32)*(1./255)#对图像进行归一化以便保持和原图像有相同的精度
19     tf_label=tf.cast(features[‘label‘],tf.int32)
20     tf_label=tf.one_hot(tf_label,5,on_value=1,off_value=0)#将label转化成用one_hot编码的格式
21     return tf_image,tf_label
22
23 def batch_test(tfrecords_file):
24     dataset=tf.data.TFRecordDataset(tfrecords_file)
25     dataset=dataset.map(decode_example)
26     dataset=dataset.shuffle(100).batch(4)
27     iterator=tf.compat.v1.data.make_one_shot_iterator(dataset)
28     batch_images,batch_labels=iterator.get_next()
29
30     init_op=tf.compat.v1.global_variables_initializer()
31     with tf.compat.v1.Session() as sess:
32         sess.run(init_op)
33         coord=tf.train.Coordinator()
34         threads=tf.train.start_queue_runners(coord=coord)
35         for i in range(4):
36             images,labels=sess.run([batch_images,batch_labels])
37             show_image(images[1,:,:,:])
38             print(‘shape:{},tpye:{},labels:{}‘.format(images.shape, images.dtype, labels))
39
40         coord.request_stop()
41         coord.join(threads)
42
43 if __name__==‘__main__‘:
44     tfrecords_file=‘D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/record/train.tfrecords‘
45     resize_height=224
46     resize_width=224
47     batch_test(tfrecords_file)

我为了测试,写了batch_test这个函数,因为我想试一试看我做的tfrecords能不能被解析成功,如果你不想测试只想训练,那你直接把images_batch,和labels_batch放到网络中进行训练就可以了,还有一点要注意的,tf.global_variables_initializer()已经被tf.compat.v1.global_variables_initializer()所取代了,我做的时候不知道所以报了一个warning提示,同时tf.Sesssion()已经被tf.compat.v1.Session() 所替代,iterator=dataset.make_one_shot_iterator()已经被tf.compat.v1.data.make_one_shot_iterator(dataset)  所代替,这些异常要注意,然后我只是将每个batch的第二张图片显示出来了,你也可以显示其他的,但是意义不大,反正只是测试一下解析成功与否,成功了我们就不需要纠结别的了。好啦,就是这样,接下来我会把这些东西放到网络中进行训练,再更新我的学习,就酱。

原文地址:https://www.cnblogs.com/daremosiranaihana/p/11444705.html

时间: 2024-11-06 21:30:50

使用tensorflow中的Dataset来读取制作好的tfrecords文件的相关文章

TensorFlow中数据读取之tfrecords

关于Tensorflow读取数据,官网给出了三种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据. 预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况). 对于数据量较小而言,可能一般选择直接将数据加载进内存,然后再分batch输入网络进行训练(tip:使用这种方法时,结合yield 使用更为简洁,大家自己

tensorflow二进制文件读取与tfrecords文件读取

1.知识点 """ TFRecords介绍: TFRecords是Tensorflow设计的一种内置文件格式,是一种二进制文件,它能更好的利用内存, 更方便复制和移动,为了将二进制数据和标签(训练的类别标签)数据存储在同一个文件中 CIFAR-10批处理结果存入tfrecords流程: 1.构造存储器 a)TFRecord存储器API:tf.python_io.TFRecordWriter(path) 写入tfrecords文件 参数: path: TFRecords文件的路

tf.train.Saver()-tensorflow中模型的保存及读取

作用:训练网络之后保存训练好的模型,以及在程序中读取已保存好的模型 使用步骤: 实例化一个Saver对象 saver = tf.train.Saver() 在训练过程中,定期调用saver.save方法,像文件夹中写入包含当前模型中所有可训练变量的checkpoint文件 saver.save(sess,FLAGG.train_dir,global_step=step) 之后可以使用saver.restore()方法,重载模型的参数,继续训练或者用于测试数据 saver.restore(sess

第二十二节,TensorFlow中的图片分类模型库slim的使用

Google在TensorFlow1.0,之后推出了一个叫slim的库,TF-slim是TensorFlow的一个新的轻量级的高级API接口.这个模块是在16年新推出的,其主要目的是来做所谓的"代码瘦身".它类似我们在TensorFlow模块中所介绍的tf.contrib.lyers模块,将很多常见的TensorFlow函数进行了二次封装,使得代码变得更加简洁,特别适用于构建复杂结构的深度神经网络,它可以用了定义.训练.和评估复杂的模型. 这里我们为什么要过来介绍这一节的内容呢?主要是

Tensorflow中使用CNN实现Mnist手写体识别

本文参考Yann LeCun的LeNet5经典架构,稍加ps得到下面适用于本手写识别的cnn结构,构造一个两层卷积神经网络,神经网络的结构如下图所示: 输入-卷积-pooling-卷积-pooling-全连接层-Dropout-Softmax输出 第一层卷积利用5*5的patch,32个卷积核,可以计算出32个特征.然后进行maxpooling.第二层卷积利用5*5的patch,64个卷积核,可以计算出64个特征.然后进行max pooling.卷积核的个数是我们自己设定,可以增加卷积核数目提高

第5章分布式系统模式 在 .NET 中使用 DataSet 实现 Data Transfer Object

要在 .NET Framework 中实现分布式应用程序.客户端应用程序需要显示一个窗体,该窗体要求对 ASP.NET Web Service 进行多个调用以满足单个用户请求.基于性能方面的考虑,我们发现,进行多个调用会降低应用程序性能.为了提高性能,需要通过对 Web Service 进行一次调用就能检索到用户请求所需的所有数据. 背景信息 注意:以下是在 .NET 中使用类型化 DataSet 实现 Data Transfer Object 中所描述的同一个示例应用程序. 下面是一个简化的

MapReduce中TextInputFormat分片和读取分片数据源码级分析

InputFormat主要用于描述输入数据的格式(我们只分析新API,即org.apache.hadoop.mapreduce.lib.input.InputFormat),提供以下两个功能: (1)数据切分:按照某个策略将输入数据切分成若干个split,以便确定MapTask个数以及对应的split: (2)为Mapper提供输入数据:读取给定的split的数据,解析成一个个的key/value对,供mapper使用. InputFormat有两个比较重要的方法:(1)List<InputSp

tensorflow中的共享变量(sharing variables)

为什么要使用共享变量? 当训练复杂模型时,可能经常需要共享大量的变量.例如,使用测试集来测试已训练好的模型性能表现时,需要共享已训练好模型的变量,如全连接层的权值. 而且我们还会遇到以下问题: 比如,我们创建了一个简单的图像滤波器模型.如果只使用tf.Variable,那么我们的模型可能如下 def my_image_filter(input_images): conv1_weights = tf.Variable(tf.random_normal([5, 5, 32, 32]), name="

Spring中对资源的读取支持

Resource简单介绍 注:所有操作基于配置好的Spring开发环境中. 在Spring中,最为核心的部分就是applicationContext.xml文件,而此配置文件中字符串的功能发挥到了极致. 在Java里面提供了最为原始的IO处理操作支持,但是传统的java.io包中只提供了inputStream与outputStream,虽然是最为常用的输入输出的处理类,但是用其进行一些复杂的资源读取非常麻烦.所以使用PrintStream,Scanner来改善这样的操作处理.但是即便这样,对网络