Tensorflow - tf.split使用

XDeepFM的CIN中第一层实现需要使两个二维矩阵相乘得到一个三维张量,于是来复习下split函数(需要用到):
首先看下函数原理:

tf.split(
    value,
    num_or_size_splits,
    axis=0,
    num=None,
    name=‘split‘
)

这个函数是用来切割张量的:输入切割的张量和参数,返回切割的结果。
value传入的就是需要切割的张量,axis的数值代表切割哪个维度。
这个函数有两种切割的方式:

以三个维度的张量为例,比如说一个20 * 30 * 40的张量my_tensor,就如同一个长20厘米宽30厘米高40厘米的蛋糕,每立方厘米都是一个分量。

有两种切割方式:
1. 如果num_or_size_splits传入的是一个整数,这个整数代表这个张量最后会被切成几个小张量。此时,传入axis的数值就代表切割哪个维度(从0开始计数)。调用tf.split(my_tensor, 2,0)返回两个10 * 30 * 40的小张量。
2. 如果num_or_size_splits传入的是一个向量,那么向量有几个分量就分成几份,切割的维度还是由axis决定。比如调用tf.split(my_tensor, [10, 5, 25], 2),则返回三个张量分别大小为 20 * 30 * 10、20 * 30 * 5、20 * 30 * 25。很显然,传入的这个向量各个分量加和必须等于axis所指示原张量维度的大小 (10 + 5 + 25 = 40)。

一个实例:

import tensorflow as tf
import numpy as np

arr1 = tf.convert_to_tensor(np.arange(1,25).reshape(2,4,3),dtype=tf.int32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    split_arr1 = tf.split(arr1,[1,1,1],2) # 切割成2个2*4*1的张量
   print(sess.run(split_arr1)

可以看到原来的2*4*3的张量被切割为了3个2*4*1的张量

Reference:

https://blog.csdn.net/SangrealLilith/article/details/80272346

原文地址:https://www.cnblogs.com/Jesee/p/11277868.html

时间: 2024-10-12 22:34:47

Tensorflow - tf.split使用的相关文章

TensorFlow 从入门到精通(八):TensorFlow tf.nn.conv2d 一路追查

读者可能还记得本系列博客(二)和(六)中 tf.nn 模块,其中最关心的是 conv2d 这个函数. 首先将博客(二) MNIST 例程中 convolutional.py 关键源码列出: def model(data, train=False): """The Model definition.""" # 2D convolution, with 'SAME' padding (i.e. the output feature map has #

TensorFlow tf.estimator package not installed

在使用 pip install tensorflow 命令安装TensorFlow,在成功安装后,在 import tensorflow是出现 "tf.estimator package not installed" 解决方法如下: 1.确保 pandas, numpy, matplotlib 这些依赖包已经被正确安装 2.使用 pip install -U xxx --no-cache-dir (不使用缓存文件,重新网上下载安装) 3.如果更新相关依赖包后还没解决话的,就需要将 pa

TensorFlow tf.gradients的用法详细解析以及具体例子

tf.gradients 官方定义: tf.gradients( ys, xs, grad_ys=None, name='gradients', stop_gradients=None, ) Constructs symbolic derivatives of sum of ys w.r.t. x in xs. ys and xs are each a Tensor or a list of tensors. grad_ys is a list of Tensor, holding the gr

Tensorflow tf.app.flags 的使用

在执行main函数之前首先进行flags的解析,也就是说TensorFlow通过设置flags来传递tf.app.run()所需要的参数,我们可以直接在程序运行前初始化flags,也可以在运行程序的时候设置命令行参数来达到传参的目的. 下面是一个小demo import tensorflow as tf flags = tf.app.flags FLAGS = flags.FLAGS flags.DEFINE_string("name", "x1aolata", &

TensorFlow 实战(二)—— tf train(优化算法)

Training | TensorFlow tf 下以大写字母开头的含义为名词的一般表示一个类(class) 1. 优化器(optimizer) 优化器的基类(Optimizer base class)主要实现了两个接口,一是计算损失函数的梯度,二是将梯度作用于变量.tf.train 主要提供了如下的优化函数: tf.train.Optimizer tf.train.GradientDescentOptimizer tf.train.AdadeltaOpzimizer Ada delta tf.

TensorFlow分布式计算机制解读:以数据并行为重

Tensorflow 是一个为数值计算(最常见的是训练神经网络)设计的流行开源库.在这个框架中,计算流程通过数据流程图(data flow graph)设计,这为更改操作结构与安置提供了很大灵活性.TensorFlow 允许多个 worker 并行计算,这对必须通过处理的大量训练数据训练的神经网络是有益的.此外,如果模型足够大,这种并行化有时可能是必须的.在本文中,我们将探讨 TensorFlow 的分布式计算机制. TensorFlow 计算图示例 数据并行 VS. 模型并行 当在多个计算节点

Tensorflow - Tutorial (7) : 利用 RNN/LSTM 进行手写数字识别

1. 经常使用类 class tf.contrib.rnn.BasicLSTMCell BasicLSTMCell 是最简单的一个LSTM类.没有实现clipping,projection layer.peep-hole等一些LSTM的高级变种,仅作为一个主要的basicline结构存在,假设要使用这些高级变种,需用class tf.contrib.rnn.LSTMCell这个类. 使用方式: lstm = rnn.BasicLSTMCell(lstm_size, forget_bias=1.0

『TensorFlow』0.x_&_1.x版本框架改动汇总

基本数值运算 除法和模运算符(/,//,%)现在匹配 Python(flooring)语义.这也适用于 [tf.div] 和 [tf.mod].要获取基于强制整数截断的行为,可以使用 [tf.truncatediv] 和 [tf.truncatemod]. 现在推荐使用 [tf.divide()] 作为除法函数.[tf.div()] 将保留,但它的语义不会回应 Python 3 或 [from future] 机制 [tf.mul,tf.sub ] 和 [tf.neg] 不再使用,改为 [tf.

如何基于TensorFlow使用LSTM和CNN实现时序分类任务

https://www.jiqizhixin.com/articles/2017-09-12-5 By 蒋思源2017年9月12日 09:54 时序数据经常出现在很多领域中,如金融.信号处理.语音识别和医药.传统的时序问题通常首先需要人力进行特征工程,才能将预处理的数据输入到机器学习算法中.并且这种特征工程通常需要一些特定领域内的专业知识,因此也就更进一步加大了预处理成本.例如信号处理(即 EEG 信号分类),特征工程可能就涉及到各种频带的功率谱(power spectra).Hjorth 参数