PyTorch中scatter和gather的用法

PyTorch中scatter和gather的用法

闲扯

许久没有更新博客了,2019年总体上看是荒废的,没有做出什么东西,明年春天就要开始准备实习了,虽然不找算法岗的工作,但是还是准备在2019年的最后一个半月认真整理一下自己学习的机器学习和深度学习的知识。

scatter的用法

scatter中文翻译为散射,首先看一个例子来直观感受一下这个API的功能,使用pytorch官网提供的例子。

import torch
import torch.nn as nn
x = torch.rand(2,5)
x
tensor([[0.2656, 0.5364, 0.8568, 0.5845, 0.2289],
        [0.0010, 0.8101, 0.5491, 0.6514, 0.7295]])
y = torch.zeros(3,5)
index = torch.tensor([[0,1,2,0,0],[2,0,0,1,2]])
index
tensor([[0, 1, 2, 0, 0],
        [2, 0, 0, 1, 2]])
y.scatter_(dim=0,index=index,src=x)
y
tensor([[0.2656, 0.8101, 0.5491, 0.5845, 0.2289],
        [0.0000, 0.5364, 0.0000, 0.6514, 0.0000],
        [0.0010, 0.0000, 0.8568, 0.0000, 0.7295]])

首先我们可以看到,x的所有值都在y中出现了,且被索引的轴为dim=0,任意一个来自x中的元素,将按照以下公式完成映射。
y[index[i,j],j] = x[i,j],对于x[0,1] = 0.5364,index[0,1] = 1指出这个值将出现在y的第dim=0维,下标为1的位置,因此,y[index[0,1],1] = y[1,1] = x[0,1] = 0.5364.

到这里我们已经对scatter,即散射这个函数有了直观的认识,可用于将一个矩阵映射到一个矩阵,dim指明要映射的轴,index指明要映射的轴的下标,因此对于3D张量,若调用y.scatter_(dim,index,src),那么有:

y[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
y[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
y[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

最后看一个官方文档的关于scatter的英文说明:

Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

意思和直观感受几乎相同,函数可将src映射到目标张量self,在dim维度上,由索引index给出下标,在非dim维度上,直接使用src值所在位置的下标。

self, index and src (if it is a Tensor) should have same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim.

显然self,index,src的ndim应该相同了,否则下标越界了,从公式上看index.size(d) > src.size(d),index.size(d) > self.size(d)没什么问题,index数组可以比src更大,猜测这里是工程上的考虑,因为超出src大小的index数组在这里是没用的,闲置的空间不会被访问。

Moreover, as for gather(), the values of index must be between 0 and self.size(dim) - 1 inclusive, and all values in a row along the specified dimension dim must be unique.

index所有的值需要在[0,self.size(dim) - 1]区间内,这是必须满足的,否则越界了。第二句说沿着dim维的index的所有值需要是唯一的,我测试的结果发现可以重复,看下面的代码:

x = torch.rand(2,5)
x
tensor([[0.6542, 0.6071, 0.7546, 0.4880, 0.1077],
        [0.9535, 0.0992, 0.0594, 0.0641, 0.7563]])
index = torch.tensor([[0,1,2,0,0],[2,0,0,1,2]])
y = torch.zeros(3,5)
y.scatter_(dim=0,index=index,src=x)
tensor([[0.6542, 0.0992, 0.0594, 0.4880, 0.1077],
        [0.0000, 0.6071, 0.0000, 0.0641, 0.0000],
        [0.9535, 0.0000, 0.7546, 0.0000, 0.7563]])
index = torch.tensor([[0,1,2,0,0],[0,1,2,0,0]])
y = torch.zeros(3,5)
y.scatter_(dim=0,index=index,src=x)
tensor([[0.9535, 0.0000, 0.0000, 0.0641, 0.7563],
        [0.0000, 0.0992, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0594, 0.0000, 0.0000]])

可以看到沿着dim=0轴上重复了5次,分别是(0,0),(1,1),(2,2),(0,0),(0,0),代码无报错和警告,只是覆盖掉了原来的值,可能是文档没有更新,但是API更新了。

params:

  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to scatter, can be either empty or the same size of src. When empty, the operation returns identity
  • src (Tensor) – the source element(s) to scatter, incase value is not specified
  • value (float) – the source element(s) to scatter, incase src is not specified

值得注意的是value参数,当没有指明src时,可以指定一个浮点value变量,利用这一点我们实现一个scatter版本的onehot函数。

x = torch.tensor([[1,1,1,1,1]],dtype=torch.float32)
index = torch.tensor([[0,1,2,3,4]],dtype=torch.int64)
y = torch.zeros(5,5,dtype=torch.float32)
x
tensor([[1., 1., 1., 1., 1.]])
y.scatter_(0,index,x)
tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]])
y = torch.zeros(5,5,dtype=torch.float32)
y.scatter_(0,index,1)
tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]])

可以看到指定value=1,和src=[[1,1,1,1,1]]等价。到这里scatter就结束了。

gather的用法

gather是scatter的逆过程,从一个张量收集数据,到另一个张量,看一个例子有个直观感受。

x = torch.tensor([[1,2],[3,4]])
torch.gather(input=x,dim=1,index=torch.tensor([[0,0],[1,0]]))
tensor([[1, 1],
        [4, 3]])

可以猜测到收集过程,根据index和dim将x中的数据挑选出来,放置到y中,满足下面的公式:
y[i,j] = x[i,index[i,j]],因此有y[0,0] = x[0,index[0,0]] = x[0,0] = 1, y[1,0] = x[1,index[1,0]] = x[1,1] = 4,对于3D数据,满足以下公式:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

到这里gather的用法介绍就结束了,因为gather毕竟是scatter的逆过程,理解了scatter,gather并不需要太多说明。

小结

  1. scatter可以将一个张量映射到另一个张量,其中一个应用是onehot函数.
  2. gather和scatter是两个互逆的过程,gather可用于压缩稀疏张量,收集稀疏张量中非0的元素。
  3. 别再荒废时光了,做不出成果也不能全怪自己的。

原文地址:https://www.cnblogs.com/liuzhan709/p/11875743.html

时间: 2024-07-31 02:21:39

PyTorch中scatter和gather的用法的相关文章

[PyTorch]PyTorch中反卷积的用法

文章来源:https://www.jianshu.com/p/01577e86e506 pytorch中的 2D 卷积层 和 2D 反卷积层 函数分别如下: class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, bias=True) class torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_si

pytorch中如何处理RNN输入变长序列padding

一.为什么RNN需要处理变长输入 假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示: 思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练样例长度不同的情况,这样我们就会很自然的进行padding,将短句子padding为跟最长的句子一样. 比如向下图这样: 但是这会有一个问题,什么问题呢?比如上图,句子“Yes”只有一个单词,但是padding了5的pad符号,这样会导致LSTM对它的表示通过了非常多无用的字符,这样得到的句子表示就

关于Java中this和super的用法介绍和区别

1.this&super 什么是this,this是自身的一个对象,代表对象本身,可以理解为:指向对象本身的一个指针.当你想要引用当前对象的某种东西,比如当前对象的某个方法,或当前对象的某个成员,你便可以利用this来实现这个目的.要注意的是this只能在类中的非静态方法中使用,静态方法和静态的代码块中绝对不能出现this.his也可作为构造函数来使用.在后面可以看到 而什么是super,可以理解为是指向自己超(父)类对象的一个指针,而这个超类指的是离自己最近的一个父类.super的作用同样是可

shell中$0,$?,$!等的特殊用法

shell中$0,$?,$!等的特殊用法 变量说明: $$Shell本身的PID(ProcessID)$!Shell最后运行的后台Process的PID$?最后运行的命令的结束代码(返回值)$-使用Set命令设定的Flag一览$*所有参数列表.如"$*"用「"」括起来的情况.以"$1 $2 … $n"的形式输出所有参数.[email protected]所有参数列表.如"[email protected]"用「"」括起来的情况

Oracle中HINT的30个用法

在SQL语句优化过程中,经常会用到hint, 以下是在SQL优化过程中常见Oracle中"HINT"的30个用法 1. /*+ALL_ROWS*/ 表明对语句块选择基于开销的优化方法,并获得最佳吞吐量,使资源消耗最小化. 例如: SELECT /*+ALL+_ROWS*/ EMP_NO,EMP_NAM,DAT_IN FROM BSEMPMS WHERE EMP_NO='SCOTT'; 2. /*+FIRST_ROWS*/ 表明对语句块选择基于开销的优化方法,并获得最佳响应时间,使资源消

js中继承的几种用法总结(apply,call,prototype)

本篇文章主要介绍了js中继承的几种用法总结(apply,call,prototype) 需要的朋友可以过来参考下,希望对大家有所帮助 一,js中对象继承 js中有三种继承方式 1.js原型(prototype)实现继承 <SPAN style="<SPAN style="FONT-SIZE: 18px"><html>   <body>  <script type="text/javascript"> 

C中的时间函数的用法

C中的时间函数的用法    这个类展示了C语言中的时间函数的常用的用法. 源代码: #include <ctime>#include <iostream> using namespace std; class MyTime{public:    MyTime() { mPTime = 0; mStLocalTime = 0; mStGMTTime = 0; }    ~MyTime() {}; //time_t time(time_t * timer) 返回自1970年1月1日00

java中静态代码块的用法 static用法详解

(一)java 静态代码块 静态方法区别一般情况下,如果有些代码必须在项目启动的时候就执行的时候,需要使用静态代码块,这种代码是主动执行的;需要在项目启动的时候就初始化,在不创建对象的情况下,其他程序来调用的时候,需要使用静态方法,这种代码是被动执行的. 静态方法在类加载的时候 就已经加载 可以用类名直接调用比如main方法就必须是静态的 这是程序入口两者的区别就是:静态代码块是自动执行的;静态方法是被调用的时候才执行的.静态方法(1)在Java里,可以定义一个不需要创建对象的方法,这种方法就是

C# 中 PadLeft和PadRight 的用法

在 C# 中可以对字符串使用 PadLeft 和 PadRight 进行轻松地补位. PadLeft(int totalWidth, char paddingChar) //在字符串左边用 paddingChar 补足 totalWidth 长度 PadRight(int totalWidth, char paddingChar) //在字符串右边用 paddingChar 补足 totalWidth 长度 示例: 1.假如想输出AAAAA,可以用string.Empty.PadLeft(5,'