利用Tensorflow读取二进制CIFAR-10数据集

使用Tensorflow读取CIFAR-10二进制数据集

觉得有用的话,欢迎一起讨论相互学习~Follow Me

参考文献

Tensorflow官方文档

tf.transpose函数解析

tf.slice函数解析

CIFAR10/CIFAR100数据集介绍

tf.train.shuffle_batch函数解析

Python urllib urlretrieve函数解析

import os
import tarfile
import tensorflow as tf
from six.moves import urllib
from tensorflow.python.framework import ops

ops.reset_default_graph()

# 更改工作目录
abspath = os.path.abspath(__file__)  # 获取当前文件绝对地址
# E:\GitHub\TF_Cookbook\08_Convolutional_Neural_Networks\03_CNN_CIFAR10\ostest.py
dname = os.path.dirname(abspath)  # 获取文件所在文件夹地址
# E:\GitHub\TF_Cookbook\08_Convolutional_Neural_Networks\03_CNN_CIFAR10
os.chdir(dname)  # 转换目录文件夹到上层

# Start a graph session
# 初始化Session
sess = tf.Session()

# 设置模型超参数
batch_size = 128  # 批处理数量
data_dir = ‘temp‘  # 数据目录
output_every = 50  # 输出训练loss值
generations = 20000  # 迭代次数
eval_every = 500  # 输出测试loss值
image_height = 32  # 图片高度
image_width = 32  # 图片宽度
crop_height = 24  # 裁剪后图片高度
crop_width = 24  # 裁剪后图片宽度
num_channels = 3  # 图片通道数
num_targets = 10  # 标签数
extract_folder = ‘cifar-10-batches-bin‘

# 指数学习速率衰减参数
learning_rate = 0.1  # 学习率
lr_decay = 0.1  # 学习率衰减速度
num_gens_to_wait = 250.  # 学习率更新周期

# 提取模型参数
image_vec_length = image_height*image_width*num_channels  # 将图片转化成向量所需大小
record_length = 1 + image_vec_length  # ( + 1 for the 0-9 label)

# 读取数据
data_dir = ‘temp‘
if not os.path.exists(data_dir):  # 当前目录下是否存在temp文件夹
    os.makedirs(data_dir)  # 如果当前文件目录下不存在这个文件夹,创建一个temp文件夹
#  设定CIFAR10下载路径
cifar10_url = ‘http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz‘

# 检查这个文件是否存在,如果不存在下载这个文件
data_file = os.path.join(data_dir, ‘cifar-10-binary.tar.gz‘)
# temp\cifar-10-binary.tar.gz
if os.path.isfile(data_file):
    pass
else:
    # 回调函数,当连接上服务器、以及相应的数据块传输完毕时会触发该回调,我们可以利用这个回调函数来显示当前的下载进度。
    # block_num已经下载的数据块数目,block_size数据块大小,total_size下载文件总大小

    def progress(block_num, block_size, total_size):
        progress_info = [cifar10_url, float(block_num*block_size)/float(total_size)*100.0]
        print(‘\r Downloading {} - {:.2f}%‘.format(*progress_info), end="")

    # urlretrieve(url, filename=None, reporthook=None, data=None)
    # 参数 finename 指定了保存本地路径(如果参数未指定,urllib会生成一个临时文件保存数据。)
    # 参数 reporthook 是一个回调函数,当连接上服务器、以及相应的数据块传输完毕时会触发该回调,我们可以利用这个回调函数来显示当前的下载进度。
    # 参数 data 指 post 到服务器的数据,该方法返回一个包含两个元素的(filename, headers)元组,filename 表示保存到本地的路径,header 表示服务器的响应头。
    # 此处 url=cifar10_url,filename=data_file,reporthook=progress

    filepath, _ = urllib.request.urlretrieve(cifar10_url, data_file, progress)
    # 解压文件
    tarfile.open(filepath, ‘r:gz‘).extractall(data_dir)

# Define CIFAR reader
# 定义CIFAR读取器
def read_cifar_files(filename_queue, distort_images=True):
    reader = tf.FixedLengthRecordReader(record_bytes=record_length)
    # 返回固定长度的文件记录 record_length函数参数为一条图片信息即1+32*32*3
    key, record_string = reader.read(filename_queue)
    # 此处调用tf.FixedLengthRecordReader.read函数返回键值对
    record_bytes = tf.decode_raw(record_string, tf.uint8)
    # 读出来的原始文件是string类型,此处我们需要用decode_raw函数将String类型转换成uint8类型
    image_label = tf.cast(tf.slice(record_bytes, [0], [1]), tf.int32)
    # 见slice函数用法,取从0号索引开始的第一个元素。并将其转化为int32型数据。其中存储的是图片的标签

    # 截取图像
    image_extracted = tf.reshape(tf.slice(record_bytes, [1], [image_vec_length]),
                                 [num_channels, image_height, image_width])
    # 从1号索引开始提取图片信息。这和此数据集存储图片信息的格式相关。
    # CIFAR-10数据集中
    """第一个字节是第一个图像的标签,它是一个0-9范围内的数字。接下来的3072个字节是图像像素的值。
       前1024个字节是红色通道值,下1024个绿色,最后1024个蓝色。值以行优先顺序存储,因此前32个字节是图像第一行的红色通道值。
       每个文件都包含10000个这样的3073字节的“行”图像,但没有任何分隔行的限制。因此每个文件应该完全是30730000字节长。"""

    # Reshape image
    image_uint8image = tf.transpose(image_extracted, [1, 2, 0])
    # 详见tf.transpose函数,将[channel,image_height,image_width]转化为[image_height,image_width,channel]的数据格式。
    reshaped_image = tf.cast(image_uint8image, tf.float32)
    # 将图片剪裁或填充至合适大小
    final_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, crop_width, crop_height)

    if distort_images:
        # 将图像水平随机翻转,改变亮度和对比度。
        final_image = tf.image.random_flip_left_right(final_image)
        final_image = tf.image.random_brightness(final_image, max_delta=63)
        final_image = tf.image.random_contrast(final_image, lower=0.2, upper=1.8)

    # 对图片做标准化处理
    """Linearly scales `image` to have zero mean and unit norm.
    This op computes `(x - mean) / adjusted_stddev`, where `mean` is the average
    of all values in image, and `adjusted_stddev = max(stddev, 1.0/sqrt(image.NumElements()))`.
    `stddev` is the standard deviation of all values in `image`.
    It is capped away from zero to protect against division by 0 when handling uniform images."""
    final_image = tf.image.per_image_standardization(final_image)
    return (final_image, image_label)

# Create a CIFAR image pipeline from reader
# 从阅读器中构造CIFAR图片管道
def input_pipeline(batch_size, train_logical=False):
    # train_logical标志用于区分读取训练和测试数据集
    if train_logical:
        files = [os.path.join(data_dir, extract_folder, ‘data_batch_{}.bin‘.format(i)) for i in range(1, 6)]
    #  data_dir=tmp
    # extract_folder=cifar-10-batches-bin
    else:
        files = [os.path.join(data_dir, extract_folder, ‘test_batch.bin‘)]
    filename_queue = tf.train.string_input_producer(files)
    image, label = read_cifar_files(filename_queue)
    print(train_logical, ‘after read_cifar_files ops image‘, sess.run(tf.shape(image)))
    print(train_logical, ‘after read_cifar_files ops label‘, sess.run(tf.shape(label)))
    # min_after_dequeue defines how big a buffer we will randomly sample
    #   from -- bigger means better shuffling but slower start up and more
    #   memory used.
    # capacity must be larger than min_after_dequeue and the amount larger
    #   determines the maximum we will prefetch.  Recommendation:
    #   min_after_dequeue + (num_threads + a small safety margin) * batch_size
    min_after_dequeue = 5000
    capacity = min_after_dequeue + 3*batch_size
    # 批量读取图片数据
    example_batch, label_batch = tf.train.shuffle_batch([image, label],
                                                        batch_size=batch_size,
                                                        capacity=capacity,
                                                        min_after_dequeue=min_after_dequeue)
    print(train_logical, ‘after shuffle_batch ops image‘, sess.run(tf.shape(image)))
    print(train_logical, ‘after shuffle_batch ops example_batch‘, sess.run(tf.shape(example_batch)))
    print(train_logical, ‘after shuffle_batch ops label‘, sess.run(tf.shape(label)))
    print(train_logical, ‘after shuffle_batch ops label_batch‘, sess.run(tf.shape(label_batch)))

    return (example_batch, label_batch)

# 获取数据
print(‘Getting/Transforming Data.‘)
# 初始化数据管道获取训练数据和对应标签
images, targets = input_pipeline(batch_size, train_logical=True)
# 获取测试数据和对应标签
test_images, test_targets = input_pipeline(batch_size, train_logical=False)

sess.close()

# True after read_cifar_files ops image [24 24  3]
# True after read_cifar_files ops label [1]
# True after shuffle_batch ops image [24 24  3]
# True after shuffle_batch ops example_batch [128  24  24   3]
# True after shuffle_batch ops label [1]
# True after shuffle_batch ops label_batch [128   1]
# False after read_cifar_files ops image [24 24  3]
# False after read_cifar_files ops label [1]
# False after shuffle_batch ops image [24 24  3]
# False after shuffle_batch ops example_batch [128  24  24   3]
# False after shuffle_batch ops label [1]
# False after shuffle_batch ops label_batch [128   1]

原文地址:https://www.cnblogs.com/cloud-ken/p/8458192.html

时间: 2024-10-29 12:03:25

利用Tensorflow读取二进制CIFAR-10数据集的相关文章

利用Tensorflow进行自然语言处理(NLP)系列之一Word2Vec

写在前面的话(可略过): 一直想写下.整理下利用Tensorflow或Keras工具进行自然语言处理(NLP)方面的文章,对比和纠结了一段时间,发现博众家之长不如静下心来一步一个脚印地去看一本书来得更实在,虽然慢但是心里相对踏实些.近期刚把Thushan Ganegedara写的<Natural Language Processing with TensorFlow>(2018年5月第一次出版),目前没看到中文版.讲真,看原版书确实很耗费精力,但原版书的好处是可以原汁原味地探索.写博文的过程中

在C#下使用TensorFlow.NET训练自己的数据集

在C#下使用TensorFlow.NET训练自己的数据集 今天,我结合代码来详细介绍如何使用 SciSharp STACK 的 TensorFlow.NET 来训练CNN模型,该模型主要实现 图像的分类 ,可以直接移植该代码在 CPU 或 GPU 下使用,并针对你们自己本地的图像数据集进行训练和推理.TensorFlow.NET是基于 .NET Standard 框架的完整实现的TensorFlow,可以支持 .NET Framework 或 .NET CORE , TensorFlow.NET

利用keras搭建CNN进行mnist数据集分类

当接触深度学习算法的时候,大家都很想自己亲自实践一下这个算法,但是一想到那些复杂的程序,又感觉心里面很累啊,又要学诸如tensorflow.theano这些框架.那么,有没有什么好东西能够帮助我们快速搭建这个算法呢?当然是有咯!,现如今真不缺少造轮子的大神,so,我强烈向大家推荐keras,Keras是一个高层神经网络API,Keras由纯Python编写而成并基Tensorflow或Theano.Keras为支持快速实验而生,能够把你的idea迅速转换为结果. 具体keras的安装与使用,请参

利用TensorFlow实现多元逻辑回归

利用TensorFlow实现多元逻辑回归,代码如下: import tensorflow as tf import numpy as np from sklearn.linear_model import LogisticRegression from sklearn import preprocessing # Read x and y x_data = np.loadtxt("ex4x.dat").astype(np.float32) y_data = np.loadtxt(&qu

利用XPath读取Xml文件

之所以要引入XPath的概念,目的就是为了在匹配XML文档结构树时能够准确地找到某一个节点元素.可以把XPath比作文件管理路径:通过文件管理路 径,可以按照一定的规则查找到所需要的文件:同样,依据XPath所制定的规则,也可以很方便地找到XML结构文档树中的任何一个节点. 不过,由于XPath可应用于不止一个的标准,因此W3C将其独立出来作为XSLT的配套标准颁布,它是XSLT以及我们后面要讲到的XPointer的重要组成部分. 在介绍XPath的匹配规则之前,我们先来看一些有关XPath的基

已知rand7()&#160;可以产生1~7的7个数(均匀概率),利用rand7()产生rand10()1~10(均匀概率)

题目:已知rand7() 可以产生1~7的7个数(均匀概率),利用rand7()产生rand10()1~10(均匀概率). 解析:首先利用rand7()产生1-5的5个数,每个数的概率为1/5,然后在这5个数的基础上再以1/2的概率加上5,这样就能以1/10的概率产生每一个数. 答案: int rand10() { int tmp1,tmp2; do { tmp1=rand7(); }while(tmp1>5); do { tmp2=rand7(); }while(tmp2>2); retur

利用 Linux 系统生成随机密码的10种方法

通常情况下大家生成密码都好困惑,一来复杂程度不够会不安全,复杂程度够了又不能手动随便敲击键盘打出一同字符(但通常情况下这些字符是有规律的),使用1password 或者 keepass 这种软件生成也可以,不过貌似1password 要收费,既然这样我们就玩一下好玩的用 linux 来生成随机密码玩玩吧; Linux操作系统的一大优点是对于同样一件事情,你可以使用高达数百种方法来实现它.例如,你可以通过数十种方法来生成随机密码.本文将介绍生成随机密码的十种方法. 1. 使用SHA算法来加密日期,

数据库读取二进制图片显示到PictureBox中

1.已知路径,加载本地图片到Image中 Image img = Image.FromFile("路径"); 2.数据库中读取二进制图片 string strSql = "Select Top 1 ImageContent From TT_ImageFileSave)"; Byte[] byteImage = new Byte[0]; byteImage = (Byte[])(DbHelperSQL.GetSingle(strSql)); MemoryStream

利用OpenXml读取、导出Excel

OpenXml是通过 XML 文档提供行集视图.由于OPENXML 是行集提供程序,因此可在会出现行集提供程序(如表.视图或 OPENROWSET 函数)的 Transact-SQL 语句中使用 OPENXML. 效果图: 使用它的时候,首选的下载安装这个程序集,下载地址:http://www.microsoft.com/en-us/download/details.aspx?id=30425 安装好了在项目当中引用如下2个 前台弹出框用的是 jBox这个js插件,我用了ajax请求的方式来上传