tensorflow二进制文件读取与tfrecords文件读取

1、知识点

"""
TFRecords介绍:
    TFRecords是Tensorflow设计的一种内置文件格式,是一种二进制文件,它能更好的利用内存,
    更方便复制和移动,为了将二进制数据和标签(训练的类别标签)数据存储在同一个文件中

CIFAR-10批处理结果存入tfrecords流程:
    1、构造存储器
         a)TFRecord存储器API:tf.python_io.TFRecordWriter(path) 写入tfrecords文件
            参数:
                path: TFRecords文件的路径
                return:写文件
            方法:
                write(record):向文件中写入一个字符串记录
                    record:字符串为一个序列化的Example,Example.SerializeToString()
                close():关闭文件写入器

    2、构造每一个样本的Example协议块
         a)tf.train.Example(features=None)写入tfrecords文件
                features:tf.train.Features类型的特征实例
                return:example格式协议块

         b)tf.train.Features(feature=None)构建每个样本的信息键值对
                feature:字典数据,key为要保存的名字,
                value为tf.train.Feature实例
                return:Features类型

         c)tf.train.Feature(**options)
                **options:例如
                    bytes_list=tf.train.BytesList(value=[Bytes])
                    int64_list=tf.train.Int64List(value=[Value])
                数据类型:
                    tf.train.Int64List(value=[Value])
                    tf.train.BytesList(value=[Bytes])
                    tf.train.FloatList(value=[value]) 

    3、写入序列化的Example
         writer.write(example.SerializeToString())

报错:
        1、ValueError: Protocol message Feature has no "Bytes_list" field.
                因为没有Bytes_list属性字段,只有bytes_list字段

读取tfrecords流程:
    1、构建文件队列
        file_queue = tf.train.string_input_producer([FLAGS.cifar_tfrecords])
    2、构造TFRecords阅读器
        reader = tf.TFRecordReader()
    3、解析Example,获取数据
        a) tf.parse_single_example(serialized,features=None,name=None)解析TFRecords的example协议内存块
            serialized:标量字符串Tensor,一个序列化的Example
            features:dict字典数据,键为读取的名字,值为FixedLenFeature
            return:一个键值对组成的字典,键为读取的名字
        b)tf.FixedLenFeature(shape,dtype) 类型只能是float32,int64,string
            shape:输入数据的形状,一般不指定,为空列表
            dtype:输入数据类型,与存储进文件的类型要一致
    4、转换格式,bytes解码
        image = tf.decode_raw(features["image"],tf.uint8)
        #固定图像大小,有利于批处理操作
        image_reshape = tf.reshape(image,[self.height,self.width,self.channel])
        label = tf.cast(features["label"],tf.int32)
    5、批处理
        image_batch , label_batch = tf.train.batch([image_reshape,label],batch_size=5,num_threads=1,capacity=20)

报错:
    1、ValueError: Shape () must have rank at least 1

"""

2、代码

# coding = utf-8
import tensorflow as tf
import  os

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("cifar_dir","./cifar10/", "文件的目录")
tf.app.flags.DEFINE_string("cifar_tfrecords", "./tfrecords/cifar.tfrecords", "存进tfrecords的文件")
class CifarRead(object):
    """
    完成读取二进制文件,写进tfrecords,读取tfrecords
    """
    def __init__(self,file_list):
        self.file_list = file_list
        #图片属性
        self.height = 32
        self.width = 32
        self.channel = 3

        #二进制字节
        self.label_bytes = 1
        self.image_bytes = self.height*self.width*self.channel
        self.bytes = self.label_bytes + self.image_bytes

    def read_and_encode(self):
        """
        读取二进制文件,并进行解码操作
        :return:
        """
        #1、创建文件队列
        file_quque = tf.train.string_input_producer(self.file_list)
        #2、创建阅读器,读取二进制文件
        reader = tf.FixedLengthRecordReader(self.bytes)
        key, value = reader.read(file_quque)#key为文件名,value为文件内容
        #3、解码操作
        label_image = tf.decode_raw(value,tf.uint8)

        #分割图片和标签数据, tf.cast(),数据类型转换   tf.slice()tensor数据进行切片
        label = tf.cast(tf.slice(label_image, [0], [self.label_bytes]), tf.int32)
        image = tf.slice(label_image,[self.label_bytes],[self.image_bytes])

        #对图像进行形状改变
        image_reshape = tf.reshape(image,[self.height,self.width,self.channel])

        # 4、批处理操作
        image_batch , label_batch = tf.train.batch([image_reshape,label],batch_size=5,num_threads=1,capacity=20)
        print(image_batch,label_batch)
        return image_batch,label_batch

    def write_ro_tfrecords(self,image_batch,label_batch):
        """
        将读取的二进制文件写入 tfrecords文件中
        :param image_batch: 图像 (32,32,3)
        :param label_batch: 标签
        :return:
        """
        # 1、构造存储器
        writer = tf.python_io.TFRecordWriter(FLAGS.cifar_tfrecords)

        #循环写入
        for i in range(5):
            image = image_batch[i].eval().tostring()
            label = int(label_batch[i].eval()[0])
            # 2、构造每一个样本的Example
            example = tf.train.Example(features=tf.train.Features(feature={
                "image":tf.train.Feature(bytes_list = tf.train.BytesList(value = [image])) ,
                "label":tf.train.Feature(int64_list = tf.train.Int64List(value = [label])) ,
            }))

            # 3、写入序列化的Example
            writer.write(example.SerializeToString())

        #关闭流
        writer.close()
        return None

    def read_from_tfrecords(self):
        """
        从tfrecords文件读取数据
        :return:
        """
        #1、构建文件队列
        file_queue = tf.train.string_input_producer([FLAGS.cifar_tfrecords])
        #2、构造TFRecords阅读器
        reader = tf.TFRecordReader()
        key , value = reader.read(file_queue)
        #3、解析Example
        features = tf.parse_single_example(value,features={
            "image":tf.FixedLenFeature([],tf.string),
            "label":tf.FixedLenFeature([],tf.int64)
        })
        #4、解码内容, 如果读取的内容格式是string需要解码, 如果是int64,float32不需要解码
        image = tf.decode_raw(features["image"],tf.uint8)
        #固定图像大小,有利于批处理操作
        image_reshape = tf.reshape(image,[self.height,self.width,self.channel])
        label = tf.cast(features["label"],tf.int32)

        #5、批处理
        image_batch , label_batch = tf.train.batch([image_reshape,label],batch_size=5,num_threads=1,capacity=20)
        return image_batch,label_batch

if __name__ == ‘__main__‘:
    #################二进制文件读取###############
    # file_name = os.listdir(FLAGS.cifar_dir)
    # file_list = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[-3:] == "bin"]
    # cf = CifarRead(file_list)
    # image_batch, label_batch = cf.read_and_encode()
    # with tf.Session() as sess:
    #     # 创建协调器
    #     coord = tf.train.Coordinator()
    #     # 开启线程
    #     threads = tf.train.start_queue_runners(sess, coord=coord)
    #
    #     print(sess.run([image_batch, label_batch]))
    #     # 回收线程
    #     coord.request_stop()
    #     coord.join(threads)
    #############################################

    #####二进制文件读取,并写入tfrecords文件######
    # file_name = os.listdir(FLAGS.cifar_dir)
    # file_list = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[-3:] == "bin"]
    # cf = CifarRead(file_list)
    # image_batch, label_batch = cf.read_and_encode()
    # with tf.Session() as sess:
    #     # 创建协调器
    #     coord = tf.train.Coordinator()
    #     # 开启线程
    #     threads = tf.train.start_queue_runners(sess, coord=coord)
    #     #########保存文件到tfrecords##########
    #     cf.write_ro_tfrecords(image_batch, label_batch)
    #     #########保存文件到tfrecords##########
    #
    #     print(sess.run([image_batch, label_batch]))
    #     # 回收线程
    #     coord.request_stop()
    #     coord.join(threads)
    ##############################################

    #############从tfrecords文件读取###############
    file_name = os.listdir(FLAGS.cifar_dir)
    file_list = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[-3:] == "bin"]
    cf = CifarRead(file_list)
    image_batch, label_batch = cf.read_from_tfrecords()
    with tf.Session() as sess:
        # 创建协调器
        coord = tf.train.Coordinator()
        # 开启线程
        threads = tf.train.start_queue_runners(sess, coord=coord)

        print(sess.run([image_batch, label_batch]))
        # 回收线程
        coord.request_stop()
        coord.join(threads)
    ##############################################

原文地址:https://www.cnblogs.com/ywjfx/p/10919461.html

时间: 2024-11-02 22:24:45

tensorflow二进制文件读取与tfrecords文件读取的相关文章

JXLS使用方法(文件上传读取)xlsx文件读取

1.官方文档:http://jxls.sourceforge.net/reference/reader.html 2.demo git地址:https://bitbucket.org/leonate/jxls-demo 3.maven添加 <dependency>    <groupId>org.jxls</groupId>    <artifactId>jxls-reader</artifactId>    <version>2.0

TensorFlow中数据读取之tfrecords

关于Tensorflow读取数据,官网给出了三种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据. 预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况). 对于数据量较小而言,可能一般选择直接将数据加载进内存,然后再分batch输入网络进行训练(tip:使用这种方法时,结合yield 使用更为简洁,大家自己

小谈——读取web资源文件的方式和路径问题

读取web资源文件的方式 a): 采用servletContext对象获得. 优点: 任意文件,任意路径都可获得 缺点: 必须在web环境下 // 拿到全局对象 ServletContext sc = this.getServletContext(); // 获取p1.properties文件的路径 String path = sc.getRealPath("/WEB-INF/classes/p1.properties"); b): 采用resourceBundle获得 优点: 非we

使用tensorflow中的Dataset来读取制作好的tfrecords文件

上一篇我写了如何给自己的图像集制作tfrecords文件,现在我们就来讲讲如何读取已经创建好的文件,我们使用的是Tensorflow中的Dataset来读取我们的tfrecords,网上很多帖子应该是很久之前的了,绝大多数的做法是,先将tfrecords序列化成一个队列,然后使用TFRecordReader这个函数进行解析,解析出来的每一行都是一个record,然后再将每一个record进行还原,但是这个函数你在使用的时候会报出异常,原因就是它已经被dataset中新的读取方式所替代,下个版本中

深度学习_1_Tensorflow_2_数据_文件读取

tensorflow 数据读取 队列和线程 文件读取, 图片处理 问题:大文件读取,读取速度, 在tensorflow中真正的多线程 子线程读取数据 向队列放数据(如每次100个),主线程学习,不用全部数据读取后,开始学习 队列与对垒管理器,线程与协调器 tf.FIFOQueue(capacity=None,dtypes=None,name="fifo_queue") # 先进先出队列 dequeue() 出队方法 enqueue(vals,name=None) 入队方法 enqueu

TensorFlow学习笔记(6)读取数据

Overview    之前几次推送的全部例程,使用的都是tensorflow预处理过的数据集,直接载入即可.例如: 然而实际中我们使用的通常不会是这种超级经典的数据集,如果我们有一组图像存储在磁盘上面,如何以mini-batch的形式把它们读取进来然后高效的送进网络训练?这次推送我们首先用tensorflow最底层的API处理这个问题,后面推送介绍高层API.高层API是对底层的进一步封装,用户可以不必关心过多细节.不过了解一下比较底层的API还是有好处的.当你有一组自己的数据的时候,你需要经

使用js-xlsx库,前端读取Excel报表文件

在实际开发中,经常会遇到导入Excel文件的需求,有的产品人想法更多,想要在前端直接判断文件内容格式是否正确,必填项是否已填写 依据HTML5的FileReader,可以使用新的API打开本地文件(参考这篇文章) FileReader.readAsBinaryString(Blob|File) FileReader.readAsText(Blob|File, opt_encoding) FileReader.readAsDataURL(Blob|File) FileReader.readAsAr

Java 读取、写入文件——解决乱码问题

读取文件流时,经常会遇到乱码的现象,造成乱码的原因当然不可能是一个,这里主要介绍因为文件编码格式而导致的乱码的问题.首先,明确一点,文本文件与二进制文件的概念与差异. 文本文件是基于字符编码的文件,常见的编码有ASCII编码,UNICODE编码.ANSI编码等等.二进制文件是基于值编码的文件,你可以根据具体应用,指定某个值是什么意思(这样一个过程,可以看作是自定义编码.) 因此可以看出文本文件基本上是定长编码的(也有非定长的编码如UTF-8).而二进制文件可看成是变长编码的,因为是值编码嘛,多少

Java文件读取大全

在此本人只搜集了四种文件读取的方法,分别是:按字节读取文件内容.按字符读取文件内容.按行读取文件内容.随机读取文件内容 以及给文件追加内容: 废话不多说,直接贴代码,希望能帮到一些人!如果有看不懂的可以加我QQ592652578,详聊. public class ReadFromFile {    1.按字节读取文件内容 /** * 以字节为单位读取文件,常用于读二进制文件,如图片.声音.影像等文件. */ public static void readFileByBytes(String fi