pytorch 的 sum 和 softmax 方法 dim 参数的使用

  在阅读使用 pytorch 实现的代码时,笔者会遇到需要对某一维数据进行求和( sum )或 softmax 的操作。在 pytorch 中,上述两个方法均带有一个指定维度的 dim 参数,这里记录下 dim 参数的用法。

  torch.sum

  在 pytorch 中,提供 torch.sum 的两种形式,一种直接将待求和数据作为参数,则返回参数数据所有维度所有元素的和,另外一种除接收待求和数据作为参数外,还可加入 dim 参数,指定对待求和数据的某一维进行求和。

    out = torch.sum( a )                   #对 a 中所有元素求和
    out = torch.sum( a , dim = 1 )         #对 a 中第 1 维的元素求和

  上述第一种形式比较好理解,但第二种形式,加入 dim 参数后,比较令人疑惑的是到底哪些元素参与了求和?这里通过例子来进行说明。

  1)首先我们生成一个维度为 ( 3, 4, 5, 6 ) 的元素全为 1.0 的 tensor a。

    >>> import torch
    >>> a = torch.ones( 3, 4, 5, 6 )         #生成一个形状为 ( 3, 4, 5, 6 ) 的数据,数据类型默认为 torch.FloatTensor

  2)使用 sum 方法对上述生成的 tensor 进行求和操作。注意 tensor 的维度索引从 0 开始。

    >>> b = torch.sum( a )               #对 a 中所有元素求和, b = 360.0
    >>> c = torch.sum( a, dim = 0 )      #对 a 中 dim = 0 元素求和
    >>> c.shape                          # c 的 shape 为 torch.Size( [ 4, 5, 6 ] ),其中所有元素值为 3.0
    >>> d = torch.sum( a, dim = 3 )      #对 a 中 dim = 3 元素求和
    >>> d.shape                          # d 的 shape 为 torch.Size( [ 3, 4, 5 ] ),其中所有元素值为 6.0

  对上述结果进行解释,b 的结果很好理解,因为 tensor a 的维度为 ( 3, 4, 5, 6 ) 且其中所有元素的值为 1,则对其中所有元素求和的结果为 3 * 4 * 5 * 6 * 1.0 = 360.0 .

  对于 c 和 d 的结果,首先可以观察得到的是, 若在第 i 维进行求和,即 sum( a, dim = i ),则求和结果的每一个元素的值均为该维度的大小。如在 dim = 0 求和,在 dim = 0 上 a 的尺寸为 3,则求和结果 c 的每一个元素值为 3.0 .也就是说每个结果元素值均为是三个求和元素值( 1.0 )相加的结果,求和结果 c 的维度为 ( 4, 5, 6 ),说明待求和数据 a 分为 ( 4, 5, 6 ) 共 4 * 5 * 6 组的元素进行了求和运算。在 dim = 3 上的求和结果 d 现象与 c 保持一致。

  对于输入待求和数据所有数据元素均为 1 时,可以归纳出一个结论,对于维度为 ( s0, s1, s2, s3 ) 的 tensor 的第 i 维进行求和,如第 0 维,则结果的维度为 ( s1, s2, s3 ),其维度为原输入维度去除求和维度。结果的每一个元素值即为 1 * s0 = s0,即为待求和维度的尺寸。

  下面以三维数据即维度为 ( 3, 4, 4 ) 的 tensor a 为例展示 sum 在某一维度的实际计算过程。

                      

  使用 dim = 0 参数计算时,产生的结果维度为 ( 4, 4 ), 对于结果中的每一个位置 ( i, j ) ,由 3 个元素进行计算,实际计算的是 a[ 0 ][ i ][ j ] + a[ 1 ][ i ][ j ] + a[ 2 ][ i ][ j ],当上述三个元素的值均为 1.0 时,计算结果元素即为 3.0 。如上图左侧的图,a[ 0 ][ 3 ][ 3 ] + a[ 1 ][ 3 ][ 3 ] + a[ 2 ][ 3 ][ 3 ] 的结果即为输出 ( 3, 3 ) 位置上的值。上述位置索引 ( i, j ) 的数量由输入的待求和数据的其他维度的尺寸决定。

  使用 dim = 2 参数计算时,产生的结果维度为 ( 3, 4 ),对于结果中的每一个位置( i, j ) ,由 4 个元素进行计算,实际计算的是 a[ i ][ j ][ 0 ] + a[ i ][ j ][ 1 ] + a[ i ][ j ][ 2 ] + a[ i ][ j ][ 3 ],当上述四个元素的值均为 1.0 时,计算结果元素即为 4.0 。如  a[ 0 ][ 0 ][ 0 ] + a[ 0 ][ 0 ][ 1 ] + a[ 0 ][ 0 ][ 2 ] + a[ 0 ][ 0 ][ 3 ] 即为输出 ( 0, 0 ) 位置上的值。

  对于维度为 ( s0, s1, s2, ... , si, ... , sn ) 的待求和向量,使用 dim = i 调用 sum 方法,则实际产生的结果维度为 ( s0, s1, s2, ... , si-1, si+1, ... , sn ),每个结果元素由 si 个元素元素求和获得。这 si 个元素坐标在其他维度索引保持一致,而在待求和维度索引由 0 至 si 变化。可以看到共有 ( s0, s1, s2, ... , si-1, si+1, ... , sn ) 组这样的求和元素( 索引的数量 ),即为结果的维度。

  torch.nn.softmax / torch.nn.functional.softmax

  softmax 是神经网路中常见的一种计算函数,其可将所有参与计算的对象值映射到 0 到 1 之间,并使得计算对象的和为 1. 在 pytorch 中的 softmax 方法在使用时也需要通过 dim 方法来指定具体进行 softmax 计算的维度。这里以 torch.nn.functional.softmax 为例进行说明。

  softmax 在 pytorch 官方文档中的描述如下:

  It is applied to all slices along dim, and will re-scale them so that the elements lie in the range [0, 1] and sum to 1.  

  可以明确的是, softmax 计算获得的数值在 0 - 1 之间,但是同样比较令人疑惑的是,all slices along the dim 具体指代的是那些数据。这里使用一个维度为 ( 2, 2, 2 ) 的 tensor a 作为示例。

    >>> import torch
    >>> import torch.nn.functional as f
    >>> a = torch.ones( 2, 2, 2 )
    >>> b = f.softmax( a, dim=0 )           #对 a 的第 0 维进行 softmax 计算

  与 sum 方法不同,softmax 方法计算获得的结果的维度与输入的待计算的数据的维度保持一致( sum 方法求和后进行指定求和的那一维不会出现在结果维度中 )。

  参与 softmax 计算的元素与 sum 方法很相似,对于 tensor a 在 dim = 0 进行 softmax,输出结果 b 实际上是 b[ 0 ][ i ][ j ] + b[ 1 ][ i ][ j ] 的值为 1.即其他维度索引保持一致,而在进行 softmax 维度索引由 0 至 si 变化,如 b[ 0 ][ 0 ][ 1 ] + a[ 1 ][ 0 ][ 1 ] 的值为1.对于 tensor a 在 dim = 2 进行 softmax,输出结果 b 实际上是 b[ i ][ j ][ 0 ] + a[ i ][ j ][ 1 ] 的值为 1.

    >>> c = b[ 0 ][ 0 ][ 1 ] + b[ 1 ][ 0 ][ 1 ]        #c 的值为 1

  参考

  pytorch - tensor-creation-ops

  pytorch - torch.tensor

  pytorch - torch.nn.functional

原文地址:https://www.cnblogs.com/yhjoker/p/12701601.html

时间: 2024-11-13 04:53:15

pytorch 的 sum 和 softmax 方法 dim 参数的使用的相关文章

main方法的参数

敲例子的时候无意中把主方法的参数给落下了,当时没有发现,保存之后就去编译,运行了,通常情况下编译没有错误那胜利就在掌握之中了,没想到这次我竟然在"不一般"的行列中,编译无误,运行出错,错误信息如下: "找不到主方法?记得我写了main()方法了啊?回到代码处看了一下,也是static的啊?没问题啊,算了把错误信息拿出来与代码对照着看吧,发现唯一不一样的地方就是我的方法中没有参数, 立刻将参数添进去,编译,果然能运行了,回头想想,我也没传参啊,为什么还非得把它添进去啊?平时自己

C#中方法的参数的四种类型

C#中方法的参数有四种类型: 1. 值参数类型  (不加任何修饰符,是默认的类型) 2. 引用型参数  (以ref 修饰符声明) 3. 输出型参数  (以out 修饰符声明) 4. 数组型参数  (以params 修饰符声明) =================================================== 1. 值传递: 值类型是方法默认的参数类型,采用的是值拷贝的方式.也就是说,如果使用的是值类型,则可以在方法中更改该值,但当控制传递回调用过程时,不会保留更改的值.使用

详解SpringMVC中Controller的方法中参数的工作原理

前言 SpringMVC是目前主流的Web MVC框架之一. 如果有同学对它不熟悉,那么请参考它的入门blog:http://www.cnblogs.com/fangjian0423/p/springMVC-introduction.html SpringMVC中Controller的方法参数可以是Integer,Double,自定义对象,ServletRequest,ServletResponse,ModelAndView等等,非常灵活.本文将分析SpringMVC是如何对这些参数进行处理的,

Spring中的AOP(五)——在Advice方法中获取目标方法的参数

摘要: 本文介绍使用Spring AOP编程中,在增强处理方法中获取目标方法的参数,定义切点表达式时使用args来快速获取目标方法的参数. 获取目标方法的信息 访问目标方法最简单的做法是定义增强处理方法时,将第一个参数定义为JoinPoint类型,当该增强处理方法被调用时,该JoinPoint参数就代表了织入增强处理的连接点.JoinPoint里包含了如下几个常用的方法: Object[] getArgs:返回目标方法的参数 Signature getSignature:返回目标方法的签名 Ob

详解SpringMVC中Controller的方法中参数的工作原理[附带源码分析] good

目录 前言 现象 源码分析 HandlerMethodArgumentResolver与HandlerMethodReturnValueHandler接口介绍 HandlerMethodArgumentResolver与HandlerMethodReturnValueHandler接口的具体应用 常用HandlerMethodArgumentResolver介绍 常用HandlerMethodReturnValueHandler介绍 本文开头现象解释以及解决方案 编写自定义的HandlerMet

Effective Java - 方法的参数声明

给方法的参数加上限制是很常见的,比如参数代表索引时不能为负数.对于某个关键对象引用不能为null,否则会进行一些处理,比如抛出相应的异常信息. 对于这些参数限制,方法的提供者必须在文档中注明,并且在方法开头时检查参数,并在失败时提供明确的信息,即: detect errors as soon as possible after they occur 这将成为准确定位错误的一大保障. 如果没有做到这一点,最好的情况是方法在处理过程中失败并抛出了莫名其妙的异常,错误的源头变得难以定位,但这是最好的情

Effective Item 17 - 关于方法的参数声明

给方法的参数加上限制是很常见的,比如参数代表索引时不能为负数.对于某个关键对象引用不能为null,否则会进行一些处理,比如抛出相应的异常信息. 对于这些参数限制,方法的提供者必须在文档中注明,并且在方法开头时检查参数,并在失败时提供明确的信息,即detect errors as soon as possible after they occur,这将成为准确定位错误的一大保障. 如果没有做到这一点,最好的情况是方法在处理过程中失败并抛出了莫名其妙的异常,错误的源头变得难以定位,但这是最好的情况.

c# 方法传递参数

一.参数的使用方法: 1.值参数(Value Parameter ) 格式:方法名称(参数类型 参数名称[,参数类型 参数名称]) 2.引用参数(Reference Parameter ) 格式:方法名称(ref 参数类型 参数名称[,ref 参数类型 参数名称]) 3.输出参数(Out Parameter) 格式:方法名称(out 参数类型 参数名称[,out 参数类型 参数名称]) 二.值参数与引用参数及输出参数的区别: 2.1 值参数中实参的值不随形参值变更而变更: 形参与实参值互不影响,

C#中方法的参数修饰符

做项目久了,有的时候真的需要静下心来认真的总结一下自己所用到的技术,而不是每天依葫芦画瓢,每天忙忙碌碌,到头来不知道自己忙了个啥,学了什么,自己到底掌握了多少知识.所以我想回顾一下C#的基础知识,把重要的知识总结成点记录下来,方便以后快速阅读. 方法的参数及参数修饰符: 1.(无).如果一个参数没有用参数修饰符标记,则认为它将按值进行传递,这将意味着被调用的方法收到原始数据的一份副本. 2. out:输出参数由被调用的方法赋值,因此按引用传递,如果被调用的方法没有给输出参数赋值,就会出现编译错误