pytorch中的scatter_()函数

最近在学习pytorch函数时需要做独热码,然后遇到了scatter_()函数,不太明白意思,现在懂了记录一下以免以后忘记。

这个函数是用一个src的源张量或者标量以及索引来修改另一个张量。这个函数主要有三个参数scatter_(dim,index,src)

dim:沿着哪个维度来进行索引(一会儿举个例子就明白了)

index:用来进行索引的张量

src:源张量或者标量

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

这个是官网给出的例子,但是一般在做独热码的时候通常是采用二维张量所以应该是这样

#dim=0
self[index[x][y]][y]=src[x][y]  

#dim=1
self[x][index[x][y]]=src[x][y]

这个是什么意思呢。首先请看下面的程序,程序是我瞎编的,想试试的话可以自己编数据哈

import torch
x=torch.rand(3,5)
print(x)
print(‘-------------------‘)
y=torch.zeros(3,5)
print(y)
print(‘-------------------‘)
inx=torch.tensor([[0,4,3,1,2],[3,2,1,4,3]])
output_y=y.scatter_(dim=1,index=inx,src=x)
print(output_y)

下面是运行的结果

tensor([[0.1380, 0.6030, 0.2396, 0.0066, 0.7116],
        [0.5755, 0.2856, 0.4862, 0.2132, 0.2475],
        [0.5145, 0.4753, 0.2736, 0.2623, 0.8532]])
-------------------
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
-------------------
tensor([[0.1380, 0.0066, 0.7116, 0.2396, 0.6030],
        [0.0000, 0.4862, 0.2856, 0.2475, 0.2132],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

Process finished with exit code 0

那么是什么意思呢,举个例子,这里我要强调一下,index即这个程序中的inx里面的每个数值,不能超过该dim的张量的最大下标,不然的话就会越界,找不到src中的源数据。因为是在dim=1上进行索引,所以采用第二个式子。

我们在索引表中找到index[1][3]=4,那么此时x=1,y=3,即output_y[1][index[1][3]]=src[1][3],即output_y[1][4]=src[1][3]。即x[1][3]。以此类推就可以得到其他的值。

src不仅仅可以是张量,也可以是标量,下面这段代码是模仿怎么生成独热码

import torch
x=torch.zeros(4,8)
label=torch.tensor([[1],[5],[7],[6]])
one_hot=x.scatter_(1,label,1)
print(one_hot)

其中x的第一个参数代表的是batch_size,第二个参数代表的是classnum,而label有batch_size行只有一列,就是将x每一行的label值指向的位置置成1,这就是独热码。当然其他位置都是0啦,下面看一下结果吧。

tensor([[0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1., 0.]])

Process finished with exit code 0

好啦,这就是scatter_()函数的用法。

ps:本来坚持不下去了快,但是把scatter弄清楚了发现还有一点动力学下去,加油吧。

原文地址:https://www.cnblogs.com/daremosiranaihana/p/12538512.html

时间: 2024-10-30 01:20:47

pytorch中的scatter_()函数的相关文章

Pytorch中的数学函数

log_softmax log(softmax(X)) function:torch.nn.functional.log_softmax(x, dim=None) nn:torch.nn.LogSoftmax(dim=None) 如: nll_loss The negative log likelihood loss function:torch.nn.functional.nll_loss(input, target, weight=None, size_average=True, ignor

PyTorch中scatter和gather的用法

PyTorch中scatter和gather的用法 闲扯 许久没有更新博客了,2019年总体上看是荒废的,没有做出什么东西,明年春天就要开始准备实习了,虽然不找算法岗的工作,但是还是准备在2019年的最后一个半月认真整理一下自己学习的机器学习和深度学习的知识. scatter的用法 scatter中文翻译为散射,首先看一个例子来直观感受一下这个API的功能,使用pytorch官网提供的例子. import torch import torch.nn as nn x = torch.rand(2,

[PyTorch]PyTorch中反卷积的用法

文章来源:https://www.jianshu.com/p/01577e86e506 pytorch中的 2D 卷积层 和 2D 反卷积层 函数分别如下: class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, bias=True) class torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_si

[Pytorch]Pytorch中tensor常用语法

原文地址:https://zhuanlan.zhihu.com/p/31494491 上次我总结了在PyTorch中建立随机数Tensor的多种方法的区别. 这次我把常用的Tensor的数学运算总结到这里,以防自己在使用PyTorch做实验时,忘记这些方法应该传什么参数. 总结的方法包括: Tensor求和以及按索引求和:torch.sum() torch.Tensor.indexadd() Tensor元素乘积:torch.prod(input) 对Tensor求均值.方差.极值: torch

pytorch中如何处理RNN输入变长序列padding

一.为什么RNN需要处理变长输入 假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示: 思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练样例长度不同的情况,这样我们就会很自然的进行padding,将短句子padding为跟最长的句子一样. 比如向下图这样: 但是这会有一个问题,什么问题呢?比如上图,句子“Yes”只有一个单词,但是padding了5的pad符号,这样会导致LSTM对它的表示通过了非常多无用的字符,这样得到的句子表示就

Pytorch中的自编码(autoencoder)

Pytorch中的自编码(autoencoder) 本文资料来源:https://www.bilibili.com/video/av15997678/?p=25 什么是自编码 先压缩原数据.提取出最有代表性的信息.然后处理后再进行解压.减少处理压力 通过对比白色X和黑色X的区别(cost函数),从而不断提升自编码模型的能力(也就是还原的准确度) 由于这里只是使用了数据本身,没有使用label,所以可以说autoencoder是一种无监督学习模型. 实际在使用中,我们先训练好一个autoencod

检测某个方法是否属于某个类中--解析php函数method_exists()与is_callable()的区别

php函数method_exists() 与is_callable()的区别在哪?在php面相对象设计过程中,往往我们需要在调用某一个方法是否属于某一个类的时候做出判断,常用的方法有 method_exists()和is_callable() 相比之下,is_callable()函数要高级一些,它接受字符串变量形式的方法名作为 第一个参数,如果类方法存在并且可以调用,则返回true.如果要检测类中的方法是否能被调用,可以给函数传递一个数组而不是类的方法名作为参数.数组必须包含对象或类名,以将其作

delphi中的Format函数详解

首先看它的声明:[[email protected]][@21ki!] function Format(const Format: string; const Args: array of const): string; overload;[[email protected]][@21ki!] 事实上Format方法有两种形式,另外一种是三个参数的,主要区别在于它是线程安全的,[[email protected]][@21ki!]但并不多用,所以这里只对第一个介绍:[[email protect

jquery中的 $(function(){ .. }) 函数

2017-04-29 在讲解jquery中的 $(function(){ .. }) 函数之前,我们先简单了解下匿名函数.匿名函数的形式为:(function(){ ... }),又如 function(arg){ ... };定义了 一个参数为 arg 的匿名函数,然后使用 (function(arg){ ... })(param) 来调用这个函数,其中 param 是传入这个匿名函数的参数. 但需要主要匿名函数与jquery中的 $(function(){ ...}) 函数的区别:$(fun