pytorch中如何处理RNN输入变长序列padding

一、为什么RNN需要处理变长输入

假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示:

思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练样例长度不同的情况,这样我们就会很自然的进行padding,将短句子padding为跟最长的句子一样。

比如向下图这样:

但是这会有一个问题,什么问题呢?比如上图,句子“Yes”只有一个单词,但是padding了5的pad符号,这样会导致LSTM对它的表示通过了非常多无用的字符,这样得到的句子表示就会有误差,更直观的如下图:

那么我们正确的做法应该是怎么样呢?

这就引出pytorch中RNN需要处理变长输入的需求了。在上面这个例子,我们想要得到的表示仅仅是LSTM过完单词"Yes"之后的表示,而不是通过了多个无用的“Pad”得到的表示:如下图:

二、pytorch中RNN如何处理变长padding

主要是用函数torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这两个函数的用法。

这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

输入的形状可以是(T×B×* )。T是最长序列长度,B是batch size,*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)。

Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后(特别注意需要进行排序)。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

参数说明:

input (Variable) – 变长序列 被填充后的 batch

lengths (list[int]) – Variable 中 每个序列的长度。(知道了每个序列的长度,才能知道每个序列处理到多长停止

batch_first (bool, optional) – 如果是True,input的形状应该是B*T*size。

返回值:

一个PackedSequence 对象。一个PackedSequence表示如下所示:

具体代码如下:

embed_input_x_packed = pack_padded_sequence(embed_input_x, sentence_lens, batch_first=True)
encoder_outputs_packed, (h_last, c_last) = self.lstm(embed_input_x_packed)

此时,返回的h_last和c_last就是剔除padding字符后的hidden state和cell state,都是Variable类型的。代表的意思如下(各个句子的表示,lstm只会作用到它实际长度的句子,而不是通过无用的padding字符,下图用红色的打钩来表示):

但是返回的output是PackedSequence类型的,可以使用:

encoder_outputs, _ = pad_packed_sequence(encoder_outputs_packed, batch_first=True)

将encoderoutputs在转换为Variable类型,得到的_代表各个句子的长度。

三、总结

这样综上所述,RNN在处理类似变长的句子序列的时候,我们就可以配套使用torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来避免padding对句子表示的影响

参考:

原文地址:https://www.cnblogs.com/jfdwd/p/11184861.html

时间: 2024-10-04 04:23:53

pytorch中如何处理RNN输入变长序列padding的相关文章

Pytorch中的RNN之pack_padded_sequence()和pad_packed_sequence()

torch.nn.utils.rnn.pack_padded_sequence() 这里的pack,理解成压紧比较好. 将一个 填充过的变长序列 压紧.(填充时候,会有冗余,所以压紧一下) 其中pack的过程为:(注意pack的形式,不是按行压,而是按列压) (下面方框内为PackedSequence对象,由data和batch_sizes组成) 输入的形状可以是(T×B×* ).T是最长序列长度,B是batch size,*代表任意维度(可以是0).如果batch_first=True的话,那

keras: 在构建LSTM模型时,使用变长序列的方法

众所周知,LSTM的一大优势就是其能够处理变长序列.而在使用keras搭建模型时,如果直接使用LSTM层作为网络输入的第一层,需要指定输入的大小.如果需要使用变长序列,那么,只需要在LSTM层前加一个Masking层,或者embedding层即可. from keras.layers import Masking, Embedding from keras.layers import LSTM model = Sequential() model.add(Masking(mask_value=

pytorch 对变长序列的处理

使用的主要部分包括:Dateset. Dateloader.MSELoss.PackedSequence.pack_padded_sequence.pad_packed_sequence 模型包含LSTM模块. 参考了下面两篇博文,总结了一下. http://www.cnblogs.com/lindaxin/p/8052043.html#commentform https://blog.csdn.net/lssc4205/article/details/79474735 使用Dateset构建数

PyTorch 1.0 中文官方教程:序列模型和LSTM网络

译者:ETCartman 之前我们已经学过了许多的前馈网络. 所谓前馈网络, 就是网络中不会保存状态. 然而有时 这并不是我们想要的效果. 在自然语言处理 (NLP, Natural Language Processing) 中, 序列模型是一个核心的概念. 所谓序列模型, 即输入依赖于时间信息的模型. 一个典型的序列模型是隐马尔科夫模型 (HMM, Hidden Markov Model). 另一个序列模型的例子是条件随机场 (CRF, Conditional Random Field). 循

scala学习手记21 - 传递变长参数

在Java中是可以使用变长参数的,如下面的方法: public void check(String ... args){ for(String tmp : args){ System.out.println(tmp); } } 在scala中也可以使用变长参数.和java一样,也是只有最后一个参数可以接收可变长度的参数.使用方式是在参数类型后使用特殊符号"*",如下面的max()方法: def max(values: Int*) = values.foldLeft(values(0))

Pytorch基础——使用 RNN 生成简单序列

一.介绍 内容 使用 RNN 进行序列预测 今天我们就从一个基本的使用 RNN 生成简单序列的例子中,来窥探神经网络生成符号序列的秘密. 我们首先让神经网络模型学习形如 0^n 1^n 形式的上下文无关语法.然后再让模型尝试去生成这样的字符串.在流程中将演示 RNN 及 LSTM 相关函数的使用方法. 实验知识点 什么是上下文无关文法 使用 RNN 或 LSTM 模型生成简单序列的方法 探究 RNN 记忆功能的内部原理 二.什么是上下文无关语法 上下文无关语法 首先让我们观察以下序列: 01 0

输入某二叉树的前序遍历和中序遍历的结果,请重建出该二叉树。假设输入的前序遍历和中序遍历的结果中都不含重复的数字。例如输入前序遍历序列{1,2,4,7,3,5,6,8}和中序遍历序列{4,7,2,1,5,3,8,6},则重建二叉树并返回。

问题描述: 输入某二叉树的前序遍历和中序遍历的结果,请重建出该二叉树.假设输入的前序遍历和中序遍历的结果中都不含重复的数字.例如输入前序遍历序列{1,2,4,7,3,5,6,8}和中序遍历序列{4,7,2,1,5,3,8,6},则重建二叉树并返回. 思路: 在二叉树的前序遍历序列中,第一个数字总是树的根结点的值.但在中序遍历序列中,根结点的值在序列的中间,左子树的结点的值位于根结点的值的左边,而右子树的结点的值位于根结点的值的右边.因此我们需要扫描中序遍历序列,才能找到根结点的值. 如下图所示,

C 语言变长数组 struct 中 char data[0] 的用法

1.结构体内存布局(padding) 为了让CPU能够更舒服地访问到变量,struct中的各成员变量的存储地址有一套对齐的机制.这个机制概括起来有两点:第一,每个成员变量的首地址,必须是它的类型的对齐值的整数倍,如果不满足,它与前一个成员变量之间要填充(padding)一些无意义的字节来满足:第二,整个struct的大小,必须是该struct中所有成员的类型中对齐值最大者的整数倍,如果不满足,在最后一个成员后面填充. The following typical alignments are va

Objective-C中实现变长参数问题

版权声明:本文为博主原创文章,未经博主允许不得转载. ObjC中没有提供直接的变长参数方法,需要使用C标准库中的av_list方法,简单使用如下: -(void)somethingForyou:(NSString *)vString,....{ va_list varList; id arg; NSMutableArray *argsArray = [[NSMutableArray alloc]inti]; if(vString){ va_start(varList,vString); whil