tensor维度变换

维度变换是tensorflow中的重要模块之一,前面mnist实战模块我们使用了图片数据的压平操作,它就是维度变换的应用之一。

在详解维度变换的方法之前,这里先介绍一下View(视图)的概念。所谓View,简单的可以理解成我们对一个tensor不同维度关系的认识。举个例子,一个[ b,28,28,1 ]的tensor(可以理解为mnist数据集的一组图片),对于这样一组图片,我们可以有一下几种理解方式:

(1)按照物理设备储存结构,即一整行的方式(28*28)储存,这一行有连续的784个数据,这种理解方式可以用[ b,28*28 ]表示

(2)按照图片原有结构储存,即保留图片的行列关系,以28行28列的数据理解,这种方式可以用[ b,28,28 ]表示

(3)将图片分块(比如上下两部分),这种理解方式与第二种类似,只是将一张图变为两张,这种方式可以用[ b,2,14*28 ]表示

(4)增加channel通道,这种理解方式也与第二种类似,只是这种对rgb三色图区别更明显,可以用[ b,28 28,1 ]表示

通过维度的等价变换,就可以实现思维上View的转换

维度变换的方式:

方式1:tf.reshape(可通过破坏维度之间的关系改变tensor的维度,但不会改变原有数据的存储顺序)

a = tf.random.normal([4,28,28,3])
print(a.shape)
print(tf.reshape(a,[4,784,3]).shape)
print(tf.reshape(a,[4,-1,3]).shape)
print(tf.reshape(a,[4,784*3]).shape)
print(tf.reshape(a,[4,-1]).shape)

但是reshape在恢复已经reshape的数据时会出现问题,比如[ 4,28,28,3 ]的数据reshape成[ 4,784,3 ]的数据要想再恢复成以前的样子,就需要记录下以前的content(内容)信息,如果记录过程出现错误(如width和height维度记反或者数值记错),就会导致恢复不成想要的样子。

方式2:tf.transpose  (content的变换)

a = tf.random.normal([4,3,2,1])
print(a.shape)
print(tf.transpose(a).shape)
print(tf.transpose(a,perm=[0,1,3,2]).shape)

通过这种变换方式会彻底改变原来图片数据的维度关系,在经过transpose之后,再用reshape变换得到的数据是基于新的content(transpose之后)进行的变换,所以reshape时要记录新的content信息,不然会导致数据混乱甚至程序异常。

方式3:tf.expand_dims、tf.squeeze (增加和减少维度)

a = tf.random.normal([4,35,8])
# tf.expand_dims增加维度
# 若给定axis>0,则在给定轴前增加维度,若给定axis<0,则在给定轴后增加维度
print(tf.expand_dims(a,axis=0).shape)
print(tf.expand_dims(a,axis=3).shape)
print(tf.expand_dims(a,axis=-1).shape)
print(tf.expand_dims(a,axis=-4).shape)

# tf.squeeze用于减少维度
print(tf.squeeze(tf.zeros([1,2,1,1,3])).shape)
a = tf.zeros([1,2,1,3])
print(tf.squeeze(a,axis=0).shape)
print(tf.squeeze(a,axis=2).shape)
print(tf.squeeze(a,axis=-2).shape)
print(tf.squeeze(a,axis=-4).shape)

需要注意的是,squeeze只能减少维度值为1的维度,且axis必须为已存在的轴索引

当前主流的神经网络之一SE-NET就通过巧妙的使用expand和squeeze模块,使得模型准确率更上一个台阶

SE-net的github源码地址:https://github.com/hujie-frank/SENet

原文地址:https://www.cnblogs.com/zdm-code/p/12208146.html

时间: 2024-11-05 20:34:35

tensor维度变换的相关文章

PyTorch中Tensor的维度变换实现

对于 PyTorch 的基本数据对象 Tensor (张量),在处理问题时,需要经常改变数据的维度,以便于后期的计算和进一步处理,本文旨在列举一些维度变换的方法并举例,方便大家查看. 维度查看:torch.Tensor.size() 查看当前 tensor 的维度 举个例子: >>> import torch >>> a = torch.Tensor([[[1, 2], [3, 4], [5, 6]]]) >>> a.size() torch.Size

pytorch张量数据索引切片与维度变换操作大全(非常全)

(1-1)pytorch张量数据的索引与切片操作1.对于张量数据的索引操作主要有以下几种方式:a=torch.rand(4,3,28,28):DIM=4的张量数据a(1)a[:2]:取第一个维度的前2个维度数据(不包括2):(2)a[:2,:1,:,:]:取第一个维度的前两个数据,取第2个维度的前1个数据,后两个维度全都取到:(3)a[:2,1:,:,:]:取第一个维度的前两个数据,取第2个维度的第1个索引到最后索引的数据(包含1),后两个维度全都取到:(4)a[:2,-3:]:负号表示第2个维

[TensorFlow]Tensor维度理解

http://wossoneri.github.io/2017/11/15/[Tensorflow]The-dimension-of-Tensor/ Tensor维度理解 Tensor在Tensorflow中是N维矩阵,所以涉及到Tensor的方法,也都是对矩阵的处理.由于是多维,在Tensorflow中Tensor的流动过程就涉及到升维降维,这篇就通过一些接口的使用,来体会Tensor的维度概念.以下是个人体会,有不准确的请指出. tf.reduce_mean reduce_mean( inp

pytorch 数据维度变换

1.view和reshape: * 两者功能完全一样.只是pytorch 0.3版本默认是view,为了和Numpy一致,后来增加了reshape的api * 注意:变换前后的数据大小必须一样 2.squeeze v.s. unsqueeze 1)unsqueeze,将维度展开. 如下图所示: 一维数据插入-1以后,将数据变成二维(1行2列,变为2行1列) 如果插入0,会将1行2列,变成1行1列(1列里面又是1行1列) 2)squeeze:删减维度 与unsqueeze作用刚好相反. a.squ

Tensorflow之Tensor形状变换和剪切组合

https://www.tensorflow.org/versions/r0.12/api_docs/python/array_ops/slicing_and_joining

tensor的维度扩张的手段--Broadcasting

broadcasting是tensorflow中tensor维度扩张的最常用的手段,指对某一个维度上重复N多次,虽然它呈现数据已被扩张,但不会复制数据. 可以这样理解,对 [b,784]@[784,10]+[10]这样一个操作([10]可以理解为偏置项),那么原式可以化为[b,10]+[10],但是[b,10]和[10]这两个tensor是不能直接相加的,两者必须化为相一致维度的单元才能相加,即,把[10]扩张为[b,10],两者才能相加,而broadcasting做的就是这样一件事. 如果上面

ndarray数组变换

1 import numpy as np 维度变换 1 a = np.arange(24) 2 a array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]) reshape(),视图,不修改原数组 1 a.reshape(4,6) array([[ 0, 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10, 11], [12, 13, 14, 15,

利用 TFLearn 快速搭建经典深度学习模型

利用 TFLearn 快速搭建经典深度学习模型 使用 TensorFlow 一个最大的好处是可以用各种运算符(Ops)灵活构建计算图,同时可以支持自定义运算符(见本公众号早期文章<TensorFlow 增加自定义运算符>).由于运算符的粒度较小,在构建深度学习模型时,代码写出来比较冗长,比如实现卷积层:5, 9 这种方式在设计较大模型时会比较麻烦,需要程序员徒手完成各个运算符之间的连接,像一些中间变量的维度变换.运算符参数选项.多个子网络连接处极易发生问题,肉眼检查也很难发现代码中潜伏的 bu

pytorch实现yolov3(3) 实现forward

之前的文章里https://www.cnblogs.com/sdu20112013/p/11099244.html实现了网络的各个layer. 本篇来实现网络的forward的过程. 定义网络 class Darknet(nn.Module): def __init__(self, cfgfile): super(Darknet, self).__init__() self.blocks = parse_cfg(cfgfile) self.net_info, self.module_list =