TF Boys (TensorFlow Boys ) 养成记(二)

TensorFlow 的 How-Tos,讲解了这么几点:

1. 变量:创建,初始化,保存,加载,共享;

2. TensorFlow 的可视化学习,(r0.12版本后,加入了Embedding Visualization)

3. 数据的读取;

4. 线程和队列;

5. 分布式的TensorFlow;

6. 增加新的Ops;

7. 自定义数据读取;

由于各种原因,本人只看了前5个部分,剩下的2个部分还没来得及看,时间紧任务重,所以匆匆发车了,以后如果有用到的地方,再回过头来研究。学习过程中深感官方文档的繁杂冗余极多多,特别是第三部分数据读取,又臭又长,花了我好久时间,所以我想把第三部分整理如下,方便乘客们。

TensorFlow 有三种方法读取数据:1)供给数据,用placeholder;2)从文件读取;3)用常量或者是变量来预加载数据,适用于数据规模比较小的情况。供给数据没什么好说的,前面已经见过了,不难理解,我们就简单的说一下从文件读取数据。

官方的文档里,从文件读取数据是一段很长的描述,链接层出不穷,看完这个链接还没看几个字,就出现了下一个链接。

自己花了很久才认识路,所以想把这部分总结一下,带带我的乘客们。

首先要知道你要读取的文件的格式,选择对应的文件读取器;

然后,定位到数据文件夹下,用

["file0", "file1"]        # or
[("file%d" % i) for i in range(2)])    # or
tf.train.match_filenames_once

选择要读取的文件的名字,用 tf.train.string_input_producer 函数来生成文件名队列,这个函数可以设置shuffle = Ture,来打乱队列,可以设置epoch = 5,过5遍训练数据。

最后,选择的文件读取器,读取文件名队列并解码,输入 tf.train.shuffle_batch 函数中,生成 batch 队列,传递给下一层。

1)假如你要读取的文件是像 CSV 那样的文本文件,用的文件读取器和解码器就是 TextLineReaderdecode_csv

2)假如你要读取的数据是像 cifar10 那样的 .bin 格式的二进制文件,就用 tf.FixedLengthRecordReadertf.decode_raw 读取固定长度的文件读取器和解码器。如下列出了我的参考代码:

class cifar10_data(object):
    def __init__(self, filename_queue):
        self.height = 32
        self.width = 32
        self.depth = 3
        self.label_bytes = 1
        self.image_bytes = self.height * self.width * self.depth
        self.record_bytes = self.label_bytes + self.image_bytes
        self.label, self.image = self.read_cifar10(filename_queue)

    def read_cifar10(self, filename_queue):
        reader = tf.FixedLengthRecordReader(record_bytes = self.record_bytes)
        key, value = reader.read(filename_queue)
        record_bytes = tf.decode_raw(value, tf.uint8)
        label = tf.cast(tf.slice(record_bytes, [0], [self.label_bytes]), tf.int32)
        image_raw = tf.slice(record_bytes, [self.label_bytes], [self.image_bytes])
        image_raw = tf.reshape(image_raw, [self.depth, self.height, self.width])
        image = tf.transpose(image_raw, (1,2,0))
        image = tf.cast(image, tf.float32)
        return label, image

def inputs(data_dir, batch_size, train = True, name = ‘input‘):

    with tf.name_scope(name):
        if train:
            filenames = [os.path.join(data_dir,‘data_batch_%d.bin‘ % ii)
                        for ii in range(1,6)]
            for f in filenames:
                if not tf.gfile.Exists(f):
                    raise ValueError(‘Failed to find file: ‘ + f)

            filename_queue = tf.train.string_input_producer(filenames)
            read_input = cifar10_data(filename_queue)
            images = read_input.image
            images = tf.image.per_image_whitening(images)
            labels = read_input.label
            num_preprocess_threads = 16
            image, label = tf.train.shuffle_batch(
                                    [images,labels], batch_size = batch_size,
                                    num_threads = num_preprocess_threads,
                                    min_after_dequeue = 20000, capacity = 20192)

            return image, tf.reshape(label, [batch_size])

        else:
            filenames = [os.path.join(data_dir,‘test_batch.bin‘)]
            for f in filenames:
                if not tf.gfile.Exists(f):
                    raise ValueError(‘Failed to find file: ‘ + f)

            filename_queue = tf.train.string_input_producer(filenames)
            read_input = cifar10_data(filename_queue)
            images = read_input.image
            images = tf.image.per_image_whitening(images)
            labels = read_input.label
            num_preprocess_threads = 16
            image, label = tf.train.shuffle_batch(
                                    [images,labels], batch_size = batch_size,
                                    num_threads = num_preprocess_threads,
                                    min_after_dequeue = 20000, capacity = 20192)

            return image, tf.reshape(label, [batch_size])
    

3)如果你要读取的数据是图片,或者是其他类型的格式,那么可以先把数据转换成 TensorFlow 的标准支持格式 tfrecords ,它其实是一种二进制文件,通过修改 tf.train.Example 的Features,将 protocol buffer 序列化为一个字符串,再通过 tf.python_io.TFRecordWriter 将序列化的字符串写入 tfrecords,然后再用跟上面一样的方式读取tfrecords,只是读取器变成了tf.TFRecordReader,之后通过一个解析器tf.parse_single_example ,然后用解码器 tf.decode_raw 解码。

例如,对于生成式对抗网络GAN,我采用了这个形式进行输入,部分代码如下:

def _int64_feature(value):
    return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
def _bytes_feature(value):
    return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))

def convert_to(data_path, name):

    """
    Converts s dataset to tfrecords
    """

    rows = 64
    cols = 64
    depth = DEPTH
    for ii in range(12):
        writer = tf.python_io.TFRecordWriter(name + str(ii) + ‘.tfrecords‘)
        for img_name in os.listdir(data_path)[ii*16384 : (ii+1)*16384]:
            img_path = data_path + img_name
            img = Image.open(img_path)
            h, w = img.size[:2]
            j, k = (h - OUTPUT_SIZE) / 2, (w - OUTPUT_SIZE) / 2
            box = (j, k, j + OUTPUT_SIZE, k+ OUTPUT_SIZE)

            img = img.crop(box = box)
            img = img.resize((rows,cols))
            img_raw = img.tobytes()
            example = tf.train.Example(features = tf.train.Features(feature = {
                                    ‘height‘: _int64_feature(rows),
                                    ‘weight‘: _int64_feature(cols),
                                    ‘depth‘: _int64_feature(depth),
                                    ‘image_raw‘: _bytes_feature(img_raw)}))
            writer.write(example.SerializeToString())
        writer.close()

def read_and_decode(filename_queue):

    """
    read and decode tfrecords
    """

#    filename_queue = tf.train.string_input_producer([filename_queue])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(serialized_example,features = {
                        ‘image_raw‘:tf.FixedLenFeature([], tf.string)})
    image = tf.decode_raw(features[‘image_raw‘], tf.uint8)

    return image

这里,我的data_path下面有16384*12张图,通过12次写入Example操作,把图片数据转化成了12个tfrecords,每个tfrecords里面有16384张图。

4)如果想定义自己的读取数据操作,请参考https://www.tensorflow.org/how_tos/new_data_formats/

好了,今天的车到站了,请带好随身物品准备下车,明天老司机还有一趟车,请记得准时乘坐,车不等人。

参考文献:

1. https://www.tensorflow.org/how_tos/

2. 没了

时间: 2024-10-14 13:16:27

TF Boys (TensorFlow Boys ) 养成记(二)的相关文章

TF Boys (TensorFlow Boys ) 养成记(一)

本资料是在Ubuntu14.0.4版本下进行,用来进行图像处理,所以只介绍关于图像处理部分的内容,并且默认TensorFlow已经配置好,如果没有配置好,请参考官方文档配置安装,推荐用pip安装.关于配置TensorFlow,官方已经说得很详细了,我这里就不啰嗦了.官方教程看这里:https://www.tensorflow.org/get_started/os_setup 如果安装了GPU版本的TensorFlow,还需要配置Cuda,关于Cuda安装看这里:https://www.tenso

TF Boys (TensorFlow Boys ) 养成记(五)

郑重声明:此文为本人原创,转载请注明出处:http://www.cnblogs.com/Charles-Wan/p/6207039.html 有了数据,有了网络结构,下面我们就来写 cifar10 的代码. 首先处理输入,在 /home/your_name/TensorFlow/cifar10/ 下建立 cifar10_input.py,输入如下代码: from __future__ import absolute_import # 绝对导入 from __future__ import div

TF Boys (TensorFlow Boys ) 养成记(三)

上次说到了 TensorFlow 从文件读取数据,这次我们来谈一谈变量共享的问题. 为什么要共享变量?我举个简单的例子:例如,当我们研究生成对抗网络GAN的时候,判别器的任务是,如果接收到的是生成器生成的图像,判别器就尝试优化自己的网络结构来使自己输出0,如果接收到的是来自真实数据的图像,那么就尝试优化自己的网络结构来使自己输出1.也就是说,生成图像和真实图像经过判别器的时候,要共享同一套变量,所以TensorFlow引入了变量共享机制. 变量共享主要涉及到两个函数: tf.get_variab

TF Boys (TensorFlow Boys ) 养成记(六)

圣诞节玩的有点嗨,差点忘记更新.祝大家昨天圣诞节快乐,再过几天元旦节快乐. 来继续学习,在/home/your_name/TensorFlow/cifar10/ 下新建文件夹cifar10_train,用来保存训练时的日志logs,继续在/home/your_name/TensorFlow/cifar10/ cifar10.py中输入如下代码: def train(): # global_step global_step = tf.Variable(0, name = 'global_step'

前端工程师养成记:开发环境搭建(Sublime Text必备插件推荐)

为了让自己更像一个前端工程师,决定从开发环境开始武装自己.本文将介绍前段工程师开发的一些利器的安装步骤,主要包括了: 1.Node.js的安装 2.Grunt的安装及常用插件 3.Sublime Text的安装及必备插件 一.Node.js的安装 Node.js就是一堆前端工程师捧红的,所以装上这个嘛,主要不是自己需要使用Node.js而是一堆工具对他的依赖. Windows下安装步骤很简单: 1.去到http://nodejs.org/下载最新的安装包,安装. 2.在CMD下运行,node和n

中产阶级养成记:现代人需要的8点能力素养(一)(不服来战,欢迎勾搭)

首先,要说明"中产阶级养成记",这个确实有点"标题党"了.我自认为,关于以下几点的能力素养,对从贫穷晋升到中产阶级方面,有很大帮助,至少我现在是这么想的,也认为是可行的. 自己的家庭或者说家族,本来就是那种平民百姓,在早期属于"农民",最近些年,属于"半农半工","全工"的状态. 作为整个家庭,甚至是家族,几代人中间唯一的一个有较高含金量的"大学生" ,我最想做的事情之一,就是想在经济方

【活动】DevOps直播技术架构养成记

背景 半月前,参加了UCloud直播云的活动,主题"DevOps|直播技术架构养成记",很是不错的.能够整理出本篇博文,非常感谢参加会议的朋友们在微信群中提供的非常好的资料,以作分享. Now, go into! 低延迟.秒开? 网络视频直播存在已有很长一段时间,随着移动上下行带宽提升及资费的下调,视频直播被赋予了更多娱乐和社交的属性,人们享受随时随地进行直播和观看,主播不满足于单向的直播,观众则更渴望互动,直播的打开时间和延迟变成了影响产品功能发展重要指标.那么,问题来了:如何实现低

硅谷行记二:走进百度美国研发中心

硅谷行记二:走进百度美国研发中心 牛智超02月01日 12:44 分享到:                                                                                                                                                      4 百度                                       百家                  

2016级算法第六次上机-C.AlvinZH的学霸养成记II

1032 AlvinZH的学霸养成记II 思路 中等题,贪心. 所有课程按照DDL的大小来排序. 维护一个当前时间curTime,初始为0. 遍历课程,curTime加上此课程持续时间d,如果这时curTime大于此课程DDL,表示无法学习此课程,但是我们不减去此课程,而是减去用时最长的那门课程(优先队列队首,课时最长). 贪心: 假设当前课程为B,被替换课程为A,则有A.d≥B.d,A.e≤B.e.既然curTime+A.d≤A.e,那么curTime+B.d≤B.e绝对成立,保证了B的合法性