Pytorch中Module,Parameter和Buffer的区别

下文都将torch.nn简写成nn

  • Module: 就是我们常用的torch.nn.Module类,你定义的所有网络结构都必须继承这个类。
  • Buffer: buffer和parameter相对,就是指那些不需要参与反向传播的参数
    示例如下:
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.my_tensor = torch.randn(1) # 参数直接作为模型类成员变量
        self.register_buffer('my_buffer', torch.randn(1)) # 参数注册为 buffer
        self.my_param = nn.Parameter(torch.randn(1))
    def forward(self, x):
        return x    

model = MyModel()
print(model.state_dict())
>>>OrderedDict([('my_param', tensor([1.2357])), ('my_buffer', tensor([-0.9982]))])
  • Parameter: 是nn.parameter.Paramter,也就是组成Module的参数。例如一个nn.Linear通常由weightbias参数组成。它的特点是默认requires_grad=True,也就是说训练过程中需要反向传播的,就需要使用这个
import torch.nn as nn
fc = nn.Linear(2,2)

# 读取参数的方式一
fc._parameters
>>> OrderedDict([('weight', Parameter containing:
              tensor([[0.4142, 0.0424],
                      [0.3940, 0.0796]], requires_grad=True)),
             ('bias', Parameter containing:
              tensor([-0.2885,  0.5825], requires_grad=True))])

# 读取参数的方式二(推荐这种)
for n, p in fc.named_parameters():
    print(n,p)
>>>weight Parameter containing:
tensor([[0.4142, 0.0424],
        [0.3940, 0.0796]], requires_grad=True)
bias Parameter containing:
tensor([-0.2885,  0.5825], requires_grad=True)

# 读取参数的方式三
for p in fc.parameters():
    print(p)
>>>Parameter containing:
tensor([[0.4142, 0.0424],
        [0.3940, 0.0796]], requires_grad=True)
Parameter containing:
tensor([-0.2885,  0.5825], requires_grad=True)

通过上面的例子可以看到,nn.parameter.Paramterrequires_grad属性值默认为True。另外上面例子给出了三种读取parameter的方法,推荐使用后面两种(这两种的区别可参阅Pytorch: parameters(),children(),modules(),named_*区别),因为是以迭代生成器的方式来读取,第一种方式是一股脑的把参数全丢给你,要是模型很大,估计你的电脑会吃不消。

另外需要介绍的是_parametersnn.Module__init__()函数中就定义了的一个OrderDict类,这个可以通过看下面给出的部分源码看到,可以看到还初始化了很多其他东西,其实原理都大同小异,你理解了这个之后,其他的也是同样的道理。

class Module(object):
    ...
    def __init__(self):
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True

每当我们给一个成员变量定义一个nn.parameter.Paramter的时候,都会自动注册到_parameters,具体的步骤如下:

import torch.nn as nn
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 下面两种定义方式均可
        self.p1 = nn.paramter.Paramter(torch.tensor(1.0))
        print(self._parameters)
        self.p2 = nn.Paramter(torch.tensor(2.0))
        print(self._parameters)
  • 首先运行super(MyModel, self).__init__(),这样MyModel就初始化了_paramters等一系列的OrderDict,此时所有变量还都是空的。
  • self.p1 = nn.paramter.Paramter(torch.tensor(1.0)): 这行代码会触发nn.Module预定义好的__setattr__函数,该函数部分源码如下,:
def __setattr__(self, name, value):
    ...
    params = self.__dict__.get('_parameters')
    if isinstance(value, Parameter):
        if params is None:
            raise AttributeError(
                "cannot assign parameters before Module.__init__() call")
        remove_from(self.__dict__, self._buffers, self._modules)
        self.register_parameter(name, value)
    ...

__setattr__函数作用简单理解就是判断你定义的参数是否正确,如果正确就继续调用register_parameter函数进行注册,这个函数简单概括就是做了下面这件事

def register_parameter(self,name,param):
    ...
    self._parameters[name]=param

下面我们实例化这个模型看结果怎样

model = MyModel()
>>>OrderedDict([('p1', Parameter containing:
tensor(1., requires_grad=True))])
OrderedDict([('p1', Parameter containing:
tensor(1., requires_grad=True)), ('p2', Parameter containing:
tensor(2., requires_grad=True))])

结果和上面分析的一致。

MARSGGBO?原创

如有意合作,欢迎私戳

邮箱:[email protected]

2019-12-20 21:11:02

原文地址:https://www.cnblogs.com/marsggbo/p/12075244.html

时间: 2024-10-18 15:30:50

Pytorch中Module,Parameter和Buffer的区别的相关文章

node.js中module.export与export的区别。

对module.exports和exports的一些理解 可能是有史以来最简单通俗易懂的有关Module.exports和exports区别的文章了. exports = module.exports = {}; 所以module.exports和exports的区别就是var a={}; var b=a;,a和b的区别 看起来木有什么太大区别,但实际用起来的时候却又有区别,这是为啥呢,请听我细细道来 关于Module.exports和exports有什么区别,网上一搜一大把,但是说的都太复杂了

学习笔记--【转】Parameter与Attribute的区别&servletContext与ServletConfig区别

原文链接http://blog.csdn.net/saygoodbyetoyou/article/details/9006001 Parameter与Attribute的区别 request.getParameter取得Web客户端到web服务端的http请求数据(get/post),只能是string类型的,而且HttpServletRequest没有对应的setParameter()方法. 如利用href(url)和form请求服务器时,表单数据通过parameter传递到服务器,且只能为字

parameter和argument的区别

在对一个函数写一个注释时,我在考虑到底该用parameter还是用argument来描述其参数呢.根据网上一些资料,对parameter和argument的区别,做如下的简单说明. 1. parameter是指函数定义中参数,而argument指的是函数调用时的实际参数. 2. 简略描述为:parameter=形参(formal parameter), argument=实参(actual parameter). 3. 在不很严格的情况下,现在二者可以混用,一般用argument,而parame

Linux Free命令每个数字的含义 和 cache 、buffer的区别

Linux Free命令每个数字的含义 和 cache .buffer的区别 我们按照图中来一细细研读(数字编号和图对应)1,total:物理内存实际总量2,used:这块千万注意,这里可不是实际已经使用了的内存哦,这里是总计分配给缓存(包含buffers 与cache )使用的数量,但其中可能部分缓存并未实际使用.3,free:未被分配的内存4,shared:共享内存5,buffers:系统分配的,但未被使用的buffer剩余量.注意这不是总量,而是未分配的量6,cached:系统分配的,但未

(原)CNN中的卷积、1x1卷积及在pytorch中的验证

转载请注明处处: http://www.cnblogs.com/darkknightzh/p/9017854.html 参考网址: https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d https://www.cnblogs.com/chuantingSDU/p/8120065.html https://blog.csdn.net/chaolei3/article/details/79374563 1x1

module.exports和exports得区别

对module.exports和exports的一些理解 可能是有史以来最简单通俗易懂的有关Module.exports和exports区别的文章了. exports = module.exports = {}; 所以module.exports和exports的区别就是var a={}; var b=a;,a和b的区别 看起来木有什么太大区别,但实际用起来的时候却又有区别,这是为啥呢,请听我细细道来 关于Module.exports和exports有什么区别,网上一搜一大把,但是说的都太复杂了

[Pytorch]Pytorch中tensor常用语法

原文地址:https://zhuanlan.zhihu.com/p/31494491 上次我总结了在PyTorch中建立随机数Tensor的多种方法的区别. 这次我把常用的Tensor的数学运算总结到这里,以防自己在使用PyTorch做实验时,忘记这些方法应该传什么参数. 总结的方法包括: Tensor求和以及按索引求和:torch.sum() torch.Tensor.indexadd() Tensor元素乘积:torch.prod(input) 对Tensor求均值.方差.极值: torch

Pytorch中的自编码(autoencoder)

Pytorch中的自编码(autoencoder) 本文资料来源:https://www.bilibili.com/video/av15997678/?p=25 什么是自编码 先压缩原数据.提取出最有代表性的信息.然后处理后再进行解压.减少处理压力 通过对比白色X和黑色X的区别(cost函数),从而不断提升自编码模型的能力(也就是还原的准确度) 由于这里只是使用了数据本身,没有使用label,所以可以说autoencoder是一种无监督学习模型. 实际在使用中,我们先训练好一个autoencod

SQL中varchar和nvarchar有什么区别?

varchar(n)长度为 n 个字节的可变长度且非 Unicode 的字符数据.n 必须是一个介于 1 和 8,000 之间的数值.存储大小为输入数据的字节的实际长度,而不是 n 个字节. nvarchar(n)包含 n 个字符的可变长度 Unicode 字符数据.n 的值必须介于 1 与 4,000 之间.字节的存储大小是所输入字符个数的两倍. 两字段分别有字段值:我和coffee那么varchar字段占2×2+6=10个字节的存储空间,而nvarchar字段占8×2=16个字节的存储空间.