TensorFlow.org教程笔记(二) DataSets 快速入门

本文翻译自www.tensorflow.org的英文教程。

tf.data 模块包含一组类,可以让你轻松加载数据,操作数据并将其输入到模型中。本文通过两个简单的例子来介绍这个API

  • 从内存中的numpy数组读取数据。
  • 从csv文件中读取行

基本输入

对于刚开始使用tf.data,从数组中提取切片(slices)是最简单的方法。

笔记(1)TensorFlow初上手里提到了训练输入函数train_input_fn,该函数将数据传输到Estimator中:

def train_input_fn(features, labels, batch_size):
    """An input function for training"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)

    # Build the Iterator, and return the read end of the pipeline.
    return dataset.make_one_shot_iterator().get_next()

让我们进一步来看看这个过程。

参数

这个函数需要三个参数。期望“array”的参数几乎可以接受任何可以使用numpy.array转换为数组的东西。其中有一个例外是对Datasets有特殊意义的元组(tuple)。

  • features :一个包含原始特征输入的{‘feature_name‘:array}的字典(或者pandas.DataFrame)
  • labels :一个包含每个样本标签的数组
  • batch_size:指示所需批量大小的整数。

在前面的笔记中,我们使用iris_data.load_data()函数加载了鸢尾花的数据。你可以运行下面的代码来获取结果:

import iris_data

# Fetch the data.
train, test = iris_data.load_data()
features, labels = train

然后你可以将数据输入到输入函数中,类似这样:

batch_size = 100
iris_data.train_input_fn(features, labels, batch_size)

我们来看看这个train_input_fn

切片(Slices)

在最简单的情况下,tf.data.Dataset.from_tensor_slices函数接收一个array并返回一个表示array切片的tf.data.Dataset。例如,mnist训练集的shape是(60000, 28, 28)。将这个array传递给from_tensor_slices将返回一个包含60000个切片的数据集对象,每个切片大小为28X28的图像。(其实这个API就是把array的第一维切开)。

这个例子的代码如下:

train, test = tf.keras.datasets.mnist.load_data()
mnist_x, mnist_y = train

mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x)
print(mnist_ds)

将产生下面的结果:显示数据集中项目的type和shape。注意,数据集不知道它含有多少个sample。

<TensorSliceDataset shapes: (28,28), types: tf.uint8>

上面的数据集代表了简单数组的集合,但Dataset的功能还不止如此。Dataset能够透明地处理字典或元组的任何嵌套组合。例如,确保features是一个标准的字典,你可以将数组字典转换为字典数据集。

先来回顾下features,它是一个pandas.DataFrame类型的数据:

SepalLength SepalWidth PetalLength PetalWidth
0.6 0.8 0.9 1
... ... ... ...

dict(features)是一个字典,它的形式如下:

{key:value,key:value...}  # key是string类型的列名,即SepalLength等
            # value是pandas.core.series.Series类型的变量,即数据的一个列,是一个标量

对它进行切片

dataset = tf.data.Dataset.from_tensor_slices(dict(features))
print(dataset)

结果如下:

<TensorSliceDataset

  shapes: {
    SepalLength: (), PetalWidth: (),
    PetalLength: (), SepalWidth: ()},

  types: {
      SepalLength: tf.float64, PetalWidth: tf.float64,
      PetalLength: tf.float64, SepalWidth: tf.float64}
>

这里我们看到,当数据集包含结构化元素时,数据集的形状和类型采用相同的结构。该数据集包含标量字典,所有类型为tf.float64。

train_input_fn的第一行使用了相同的函数,但它增加了一层结构-----创建了一个包含(feature, labels)对的数据集

我们继续回顾labels的结构,它其实是一个pandas.core.series.Series类型的变量,即它与dict(features)的value是同一类型。且维度一致,都是标量,长度也一致。

以下代码展示了这个dataset的形状:

# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
print(dataset)
<TensorSliceDataset
    shapes: (
        {
          SepalLength: (), PetalWidth: (),
          PetalLength: (), SepalWidth: ()},
        ()),

    types: (
        {
          SepalLength: tf.float64, PetalWidth: tf.float64,
          PetalLength: tf.float64, SepalWidth: tf.float64},
        tf.int64)>

操纵

对于目前的数据集,将以固定的顺序遍历数据一次,并且每次只生成一个元素。在它可以被用来训练之前,还需做进一步处理。幸运的是,tf.data.Dataset类提供了接口以便于更好地在训练之前准备数据。输入函数的下一行利用了以下几种方法:

# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)

shuffle方法使用一个固定大小的缓冲区来随机对数据进行shuffle。设置大于数据集中sample数目的buffer_size可以确保数据完全混洗。鸢尾花数据集只包含150个数据。

repeat方法在读取到组后的数据时重启数据集。要限制epochs的数量,可以设置count参数。

batch方法累计样本并堆叠它们,从而创建批次。这个操作的结果为这批数据的形状增加了一个维度。新维度被添加为第一维度。以下代码是早期使用mnist数据集上的批处理方法。这使得28x28的图像堆叠为三维的数据批次。

print(mnist_ds.batch(100))
<BatchDataset
  shapes: (?, 28, 28),
  types: tf.uint8>

请注意,数据集具有未知的批量大小,因为最后一批的元素数量较少。

train_input_fn中,批处理后,数据集包含一维向量元素,其中每个标量先前都是:

print(dataset)
<TensorSliceDataset
    shapes: (
        {
          SepalLength: (?,), PetalWidth: (?,),
          PetalLength: (?,), SepalWidth: (?,)},
        (?,)),

    types: (
        {
          SepalLength: tf.float64, PetalWidth: tf.float64,
          PetalLength: tf.float64, SepalWidth: tf.float64},
        tf.int64)>

返回值

每个Estimatortrainpredictevaluate方法都需要输入函数返回一个包含Tensorflow张量的(features, label)对。train_input_fn使用以下代码将数据集转换为预期的格式:

# Build the Iterator, and return the read end of the pipeline.
features_result, labels_result = dataset.make_one_shot_iterator().get_next()

结果是TensorFlow张量的结构,匹配数据集中的项目层。

print((features_result, labels_result))
({
    ‘SepalLength‘: <tf.Tensor ‘IteratorGetNext:2‘ shape=(?,) dtype=float64>,
    ‘PetalWidth‘: <tf.Tensor ‘IteratorGetNext:1‘ shape=(?,) dtype=float64>,
    ‘PetalLength‘: <tf.Tensor ‘IteratorGetNext:0‘ shape=(?,) dtype=float64>,
    ‘SepalWidth‘: <tf.Tensor ‘IteratorGetNext:3‘ shape=(?,) dtype=float64>},
Tensor("IteratorGetNext_1:4", shape=(?,), dtype=int64))

读取CSV文件

Dataset最常见的实际用例是按流的方式从磁盘上读取文件。tf.data模块包含各种文件读取器。让我们来看看如何使用Dataset从csv文件中分析鸢尾花数据集。

以下对iris_data.maybe_download函数的调用在需要时会下载数据,并返回下载结果文件的路径名称:

import iris_data
train_path, test_path = iris_data.maybe_download()

iris_data.csv_input_fn函数包含使用Dataset解析csv文件的替代实现。

构建数据集

我们首先构建一个TextLineDataset对象,一次读取一行文件。然后,我们调用skip方法跳过文件第一行,它包含一个头部,而不是样本:

ds = tf.data.TextLineDataset(train_path).skip(1)

构建csv行解析器

最终,我们需要解析数据集中的每一行,以产生必要的(features, label)对。

我们将开始构建一个函数来解析单个行。

下面的iris_data.parse_line函数使用tf.decode_csv函数和一些简单的代码完成这个任务:

我们必须解析数据集中的每一行以生成必要的(features, label)对。以下的_parse_line函数调用tf.decode_csv将单行解析为其featureslabel。由于Estimator要求将特征表示为字典,因此我们依靠python的内置字典和zip函数来构建该字典。特征名是该字典的key。然后我们调用字典的pop方法从特征字典中删除标签字段。

# Metadata describing the text columns
COLUMNS = [‘SepalLength‘, ‘SepalWidth‘,
           ‘PetalLength‘, ‘PetalWidth‘,
           ‘label‘]
FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]
def _parse_line(line):
    # Decode the line into its fields
    fields = tf.decode_csv(line, FIELD_DEFAULTS)

    # Pack the result into a dictionary
    features = dict(zip(COLUMNS, fields))

    # Separate the label from the features
    label = features.pop(‘label‘)

    return features, label

解析行

Datasets有很多方法用于在数据传输到模型时处理数据。最常用的方法是map,它将转换应用于Dataset的每个元素。

map方法使用一个map_func参数来描述Dataset中每个项目应该如何转换。

因此为了解析流出csv文件的行,我们将_parse_line函数传递给map方法:

ds = ds.map(_parse_line)
print(ds)
<MapDataset
shapes: (
    {SepalLength: (), PetalWidth: (), ...},
    ()),
types: (
    {SepalLength: tf.float32, PetalWidth: tf.float32, ...},
    tf.int32)>

现在的数据集不是简单的标量字符串,而是包含了(features, label)对。

iris_data.csv_input_fn函数其余部分与基本输入部分中涵盖的iris_data.train_input_fn相同。

试试看

该函数可以用来替代iris_data.train_input_fn。它可以用来提供一个如下的Estimator

train_path, test_path = iris_data.maybe_download()

# All the inputs are numeric
feature_columns = [
    tf.feature_column.numeric_column(name)
    for name in iris_data.CSV_COLUMN_NAMES[:-1]]

# Build the estimator
est = tf.estimator.LinearClassifier(feature_columns,
                                    n_classes = 3)
# Train the estimator
batch_size = 100
est.train(
    steps=1000
    input_fn=lambda:iris_data.csv_input_fn(train_path, batch_size))

Estimator期望input_fn不带任何参数。为了解除这个限制,我们使用lambda来捕获参数并提供预期的接口。

总结

tf.data模块提供了一组用于轻松读取各种来源数据的类和函数。此外,tf.data具有简单强大的方法来应用各种标准和自定义转换。

现在你已经了解如何有效地将数据加载到Estimator中的基本想法。接下来考虑以下文档:

  • 创建自定义估算器,演示如何构建自己的自定义估算器模型。
  • 低层次简介,演示如何使用TensorFlow的低层API直接实验tf.data.Datasets
  • 导入详细了解数据集附加功能的数据。

原文地址:https://www.cnblogs.com/HolyShine/p/8673322.html

时间: 2024-09-29 19:38:26

TensorFlow.org教程笔记(二) DataSets 快速入门的相关文章

git-github-TortoiseGit综合使用教程(二)快速入门

一:建立版本库 在github网站上创建一个版本库,并复制clone地址. [email protected]:jackadam1981/Flask_Base.git https://github.com/jackadam1981/Flask_Base.git 这种结尾是git的就是git协议的地址了. 二:下载下来(空库) 在你准备写程序的目录,鼠标右键,Git 克隆, 三:修改后提交 四:打标签,tag 五:推送 六:回退 七:分支 八:合并 九:暂存 十:总结 原文地址:https://w

【Solr基础教程之一】Solr快速入门

一.Solr学习相关资料 1.官方材料 (1)快速入门:http://lucene.apache.org/solr/4_9_0/tutorial.html,以自带的example项目快速介绍发Solr的基础使用. (2)API:http://lucene.apache.org/solr/4_9_0/index.html (3)reference:PDF格式,apache-solr-ref-guide-4.9.pdf 2.书籍 (1)Solr in Action,基于4.7版本,极力推荐,此书适合

Bmob移动后端云服务平台--Android从零开始--(二)android快速入门

Bmob移动后端云服务平台--Android从零开始--(二)android快速入门 上一篇博文我们简单介绍何为Bmob移动后端服务平台,以及其相关功能和优势.本文将利用Bmob快速实现简单例子,进一步了解它的强大之处. 一.准备工作 1.注册Bmob账号 在网址栏输入www.bmob.cn或者在百度输入Bmob进行搜索,打开Bmob官网后,点击右上角的"注册",在跳转页面填入你的姓名.邮箱.设置密码,确认后到你的邮箱激活Bmob账户,你就可以用Bmob轻松开发应用了. 2.网站后台创

Yii2框架RESTful API教程(一) - 快速入门

前不久做一个项目,是用Yii2框架写一套RESTful风格的API,就去查了下<Yii 2.0 权威指南 >,发现上面写得比较简略.所以就在这里写一篇教程贴,希望帮助刚接触Yii2框架RESTful的小伙伴快速入门. 一.目录结构 实现一个简单地RESTful API只需用到三个文件.目录如下: frontend ├─ config │ └ main.php ├─ controllers │ └ BookController.php └─ models └ Book.php 二.配置URL规则

MyBatis学习笔记(一)——MyBatis快速入门

一.Mybatis介绍 MyBatis是一个支持普通SQL查询,存储过程和高级映射的优秀持久层框架.MyBatis消除了几乎所有的JDBC代码和参数的手工设置以及对结果集的检索封装.MyBatis可以使用简单的XML或注解用于配置和原始映射,将接口和Java的POJO(Plain Old Java Objects,普通的Java对象)映射成数据库中的记录. 二.mybatis快速入门 2.1.准备开发环境 1.创建测试项目,普通java项目或者是JavaWeb项目均可,如下图所示: 2.添加相应

Zabbix最佳实践二:快速入门

一.登录与配置用户 1.1 登陆 这是Zabbix的"欢迎"界面.输入用户名 Admin 以及密码 zabbix 以作为 Zabbix超级用户登陆. 登陆后,你将会在页面右下角看到"以管理员连接(Connected as Admin)".同时会获得访问配置(Configuration) 和 管理(Administration) 菜单的权限. 点击右上角的用户头像,将显示语言设置为中文. 1.2 增加用户 可以在管理(Administration) → 用户(User

[转]Expression Blend实例中文教程(8) - 动画设计快速入门StoryBoard

上一篇,介绍了Silverlight动画设计基础知识,Silverlight动画是基于时间线的,对于动画的实现,其实也就是对对象属性的修改过程. 而Silverlight动画分类两种类型,From/To/By 动画和关键帧动画. 对于Silverlight动画设计,StoryBoard是非常重要的一个功能,StoryBoard不仅仅可以对动画的管理,而且还可以对动画的细节进行控制,例如控制动画的播放,暂停,停止以及跳转动画位置等. 为了简化开发人员和设计人员的设计过程,Blend提供了专门的工具

Expression Blend实例中文教程(7) - 动画基础快速入门Animation

通过前面文章学习,已经对Blend的开发界面,以及控件有了初步的认识.本文将讲述Blend的一个核心功能,动画设计.大家也许注意到,从开篇到现在,所有的文章都是属于快速入门,是因为这些文章,都是我曾经学习的经验和工作中使用到的经验总结出来的.在我个人认为,掌握了这些核心功能也就等于掌握了Blend的开发方法.在以后开发项目中使用Blend开发工具,这些知识应该足够用了.当然,特殊项目也需要特殊对待,如果您在项目开发中,有新的Blend开发经验,希望您能够毫不吝啬的分享,在这里,我表示深深的谢意.

利用python 数据分析入门,详细教程,教小白快速入门

这是一篇的数据的分析的典型案列,本人也是经历一次从无到有的过程,倍感珍惜,所以将其详细的记录下来,用来帮助后来者快速入门! 数据的格式如下: 我们设定 一个trem or  typedef为一条标签,一行为一条记录或者是键值对,以此为标准! 下面我们来对数据进行分析: 数据集中一共包含两种标签[trem] and [typedef]两种标签,每个标签下边有多个键值对,和唯一的标识符id,每行记录以"/n"结尾,且每条标签下下有多个相同的键值对,for examble: is_a,syn