使用tensorflow api生成one-hot标签数据

转自:http://www.terrylmay.com/2017/06/generate-one-hot-data/

使用tensorflow api生成one-hot标签数据


在刚开始学习tensorflow的时候, 会有一个最简单的手写字符识别的程序供新手开始学习,
在tensorflow.example.tutorial.mnist中已经定义好了mnist的训练数据以及测试数据.
并且标签已经从原来的List变成了one-hot的二维矩阵的格式.看了源码的就知道mnist.input_data.read_data()这个方法中使用的是numpy中的方法来实现标签的one-hot矩阵化。那么如何使用tensorflow中自带的api来实现呢?下面我们就来一起看一下需要用到的api吧。

tf.expand_dims 方法
这个函数主要给矩阵或者数组增加一维°, 看代码可能更加清晰:

import tensorflow as tf
# 比如现在有一个列表
x_data = [1, 2, 3, 4]
x_data_expand = tf.expand_dims(x_data, 0) # x_data的shape是[4], 该函数表示在最前面的位置增加一维, 就会变成[1, 4]
# 而对于[1, 4] 的矩阵加上x_data本身的数据, 那么可以猜想到x_data_expand = [[1, 2, 3, 4]]

x_data_expand_axis1 = tf.expand_dims(x_data, axis=1) # x_data的shape是[4], 而axis=1表示在本来的矩阵的第1列加一维, 所以x_data_expand_axis1是[4, 1] 4行一列的矩阵, 并且把原始数据套进去可知: x_data_expand_axis1 = [[1], [2], [3], [4]], 但是这个axis的参数值不能大于矩阵的列数, 比如矩阵shape为[1, 2, 3] 那么axis=0 则会生成[1, 1, 2, 3], axis=1则会生成[1, 1, 2, 3], axis=2则会生成[1, 2, 1, 3], axis=3则会生成[1, 2, 3, 1]。就是在某一个位置插入一列

tf.concat(values, axis)
该函数用于将两个相同维度的数据进行合并, 如果指定axis=0那么只需要列数相同即可.否则需要维度都相同 看如下代码:

import tensorflow as tf
x_data = [[1, 2, 3], [4, 5, 6]]
y_data = [[7, 8, 9], [10, 11, 12]]

concat_result = tf.concat(values=[x_data, y_data], axis=0) # 这样的话, 生成的数据是[[1, 2, 3], [7, 8, 9], [4, 5, 6], [10, 11, 12]]
concat_result = tf.concat(values=[x_data, y_data], axis=1) # 这样的话, 生成的数据是[[ 1  2  3  7  8  9], [ 4  5  6 10 11 12]], 三维的甚至更高维度的数据稍后再尝试

tf.sparse_to_dense()
def sparse_to_dense(sparse_indices,
output_shape,
sparse_values,
default_value=0,
validate_indices=True,
name=None):
该函数指定位置赋值, 并且生成一个维度为output_shape的矩阵;如果output_shape维度为1, 那么sparse_indices只能是一个列表, 如果output_shape为二维矩阵, 那么sparse_indices就可以是矩阵了.

比如如下代码:

import tensorflow as tf

sparse_indices = [1, 2, 6]
output_shape = tf.zeros([10]).shape
sparse_output = tf.sparse_to_dense(sparse_indices, output_shape, 2, default_value=0) # 生成的结果为:sparse_output:[0 2 2 0 0 0 2 0 0 0] 就是在位置1, 2, 6的位置填充2 其余位置填充0

# 对于二维矩阵的填充也是一样的, 比如:
sparse_indices = [[0, 1], [2, 4], [4 ,5], [6, 9]]
output_shape = tf.zeros([6, 10]).shape
sparse_output = tf.sparse_to_dense(sparse_indices, output_shape, 1, default_value=0) #生成的数据如下:# 生成的数据如下:sparse_output:
[[0 1 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 1 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 1 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 1]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]]

下面开始实现对原始标签列表的one-hot化

  import tensorflow as tf

  labels  = [1, 3, 4, 8, 7, 5, 2, 9, 0, 8, 7]
  labels_expand = tf.expand_dims(labels, axis=1) # 这样label_expand为[11, 1]的数据

  index_expand = tf.expand_dims(tf.range(len(labels)), axis=1) # 与label_expand中的元素一一对应

  concat_result = tf.concat(values=[index_expand, labels_expand], axis=1) # 将上述两组数据组合在一起

  one_hot = tf.sparse_to_dense(sparse_indices=concat_result, output_shape=tf.zeros([len(labels), 10]).shape, sparse_values=1.0, default_value=0.0)

  session = tf.InteractiveSession()

  print(‘labels_expand:{}‘.format(session.run(labels_expand)))
  print(‘index_expand:{}‘.format(session.run(index_expand)))

  print(‘concat_result:{}‘.format(session.run(concat_result)))
  print(‘one_hot_of_labels:{}‘.format(session.run(one_hot)))

最后的结果如下打印:
python
labels_expand:[[1]
[3]
[4]
[8]
[7]
[5]
[2]
[9]
[0]
[8]
[7]]
index_expand:[[ 0]
[ 1]
[ 2]
[ 3]
[ 4]
[ 5]
[ 6]
[ 7]
[ 8]
[ 9]
[10]]
concat_result:[[ 0 1]
[ 1 3]
[ 2 4]
[ 3 8]
[ 4 7]
[ 5 5]
[ 6 2]
[ 7 9]
[ 8 0]
[ 9 8]
[10 7]]
one_hot_of_labels:[[ 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
[ 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[ 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
[ 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]]

这样就实现了labels的one-hot化。

使用Numpy来实现Label的one-hot化

    import numpy as np

    labels  = [1, 3, 4, 8, 7, 5, 2, 9, 0, 8, 7]
    one_hot_index = np.arange(len(labels)) * 10 + labels

    print (‘one_hot_index:{}‘.format(one_hot_index))

    one_hot = np.zeros((len(labels), 10))
    one_hot.flat[one_hot_index] = 1

    print(‘one_hot:{}‘.format(one_hot))

原文地址:https://www.cnblogs.com/shyzh/p/10982356.html

时间: 2024-07-29 16:43:31

使用tensorflow api生成one-hot标签数据的相关文章

一个PHP脚本,通过curl先获取百度地图api生成的经纬度,然后改数据库内的数据。

今天写一个PHP脚本,目的是让先从数据库拿取响应的地区名  然后通过幼儿园的名字来查询准确的经纬度.此间每次生成的经纬度进入数据库内的更改. 7万多条数据用时一个小时执行完毕. 不得不说 用curl结果还是比file_getcoents快的多.话不多说直接上代码 <?php date_default_timezone_set('Asia/Chongqing'); header('content-type:text/html; charset=utf-8'); ini_set('display_e

C# API: 生成和读取Excel文件

我们想为用户提供一些数据,考虑再三, 大家认为对于用户(人,而非机器)的可读性, Excel文件要好一些. 因为相比csv,xml等文件, Excel中我们可以运用自动筛选, 窗口锁定, 还可以控制背景颜色, 前景颜色, 字体, 网格等等... 业务逻辑并不复杂, 文件的内容和格式也比较固定,所以大家决定直接拿C#去创建这些文件. 于是一搜索,首先来到了这个链接:C# Excel Tutorial 里面包含了下面这些主题的代码示例, 示例很详细, 编译可直接运行. How to create E

后台动态生成静态select标签的option项

以下为代码示例: <select id="Category_<%#Eval("BTUserID") %>" name="Category_<%#Eval("BTUserID") %>" disabled on onchange=setHourlyCost("<%#Eval("BTUserID") %>") style="width:20

毕业设计---jQuery动态生成的a标签的事件绑定

这几天在毕业设计的前端设计阶段,准备放弃使用jsp,完全通过html+ajax+SSH进行网站的编写,在前端的页面显示我准备使用jQuery来实现数据的动态绑定.但是遇到动态添加的a标签无法直接通过$(element).click();来添加点击事件,通过网上的查询,在动态添加的标签绑定事件需要通过事件委托而非事件绑定. $("body").on("click", ".delete", function (){ del($(this).paren

利用google api生成二维码名片

利用google api生成二维码名片 二维条码/二维码可以分为堆叠式/行排式二维条码和矩阵式二维条码.堆叠式/行排式二维条码形态上是由多行短截的一维条码堆叠而成:矩阵式二维条码以矩阵的形式组成,在矩阵相应元素位置上用“点”表示二进制“1”,用“空”表示二进制“0”,“点”和“空”的排列组成代码. 堆叠式/行排式二维条码,如,Code 16K.Code 49.PDF417等. 矩阵式二维码,最流行莫过于QR CODE. 矩阵式二维码存储的数据量更大:可以包含数字.字符,及中文文本等混合内容:有一

如何用 Python 和 API 收集与分析网络数据?

摘自 https://www.jianshu.com/p/d52020f0c247 本文以一款阿里云市场历史天气查询产品为例,为你逐步介绍如何用 Python 调用 API 收集.分析与可视化数据.希望你举一反三,轻松应对今后的 API 数据收集与分析任务. 市场 我们尝试的,是他们找到的阿里云市场的一款 API 产品,提供天气数据. 它来自于易源数据,链接在 https://market.aliyun.com/products/57096001/cmapi010812.html?spm=517

Matlab生成二类线性可分数据

%% 生成二类线性可分数据 function [feature, category]=generate_sample(step,error) aa=3; %斜率 bb=3; %截距 b1=1; rr =error; s=step; x1(:,1) = -1:s:1; n = length(x1(:,1)); x1(:,2) = aa.*x1(:,1) + bb + b1 + rr*abs(randn(n,1)); y1 = -ones(n,1); x2(:,1) = -1:s:1; x2(:,2

mahout贝叶斯算法拓展篇3---分类无标签数据

代码测试环境:Hadoop2.4+Mahout1.0 前面博客:mahout贝叶斯算法开发思路(拓展篇)1和mahout贝叶斯算法开发思路(拓展篇)2 分析了Mahout中贝叶斯算法针对数值型数据的处理.在前面这两篇博客中并没有关于如何分类不带标签的原始数据的处理.下面这篇博客就针对这样的数据进行处理. 最新版(适合Hadoop2.4+mahout1.0环境)源码以及jar包可以在这里下载Mahout贝叶斯分类不含标签数据: 下载后参考使用里面的jar包中的fz.bayes.model.Baye

fastjson生成和解析json数据,序列化和反序列化数据

本文讲解2点: 1. fastjson生成和解析json数据 (举例:4种常用类型:JavaBean,List<JavaBean>,List<String>,List<Map<String,Object>) 2.通过一个android程序测试fastjson的用法. fastjson简介: Fastjson是一个Java语言编写的高性能功能完善的JSON库.fastjson采用独创的算法,将parse的速度提升到极致,超过所有json库,包括曾经号称最快的jack