pytorch对可变长度序列的处理

主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这三个函数的用法。

1、torch.nn.utils.rnn.PackedSequence()

NOTE: 这个类的实例不能手动创建。它们只能被 pack_padded_sequence() 实例化。

PackedSequence对象包括:

  • 一个data对象:一个torch.Variable(令牌的总数,每个令牌的维度),在这个简单的例子中有五个令牌序列(用整数表示):(18,1)
  • 一个batch_sizes对象:每个时间步长的令牌数列表,在这个例子中为:[6,5,2,4,1]

用pack_padded_sequence函数来构造这个对象非常的简单:

如何构造一个PackedSequence对象(batch_first = True)

PackedSequence对象有一个很不错的特性,就是我们无需对序列解包(这一步操作非常慢)即可直接在PackedSequence数据变量上执行许多操作。特别是我们可以对令牌执行任何操作(即对令牌的顺序/上下文不敏感)。当然,我们也可以使用接受PackedSequence作为输入的任何一个pyTorch模块(pyTorch 0.2)。

2、torch.nn.utils.rnn.pack_padded_sequence()

这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

输入的形状可以是(T×B×* )。T是最长序列长度,Bbatch size*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)

Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

NOTE: 只要是维度大于等于2的input都可以作为这个函数的参数。你可以用它来打包labels,然后用RNN的输出和打包后的labels来计算loss。通过PackedSequence对象的.data属性可以获取 Variable

参数说明:

  • input (Variable) – 变长序列 被填充后的 batch
  • lengths (list[int]) – Variable 中 每个序列的长度。
  • batch_first (bool, optional) – 如果是True,input的形状应该是B*T*size

返回值:

一个PackedSequence 对象。

3、torch.nn.utils.rnn.pad_packed_sequence()

填充packed_sequence

上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。

返回的Varaible的值的sizeT×B×*, T 是最长序列的长度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*

Batch中的元素将会以它们长度的逆序排列。

参数说明:

  • sequence (PackedSequence) – 将要被填充的 batch
  • batch_first (bool, optional) – 如果为True,返回的数据的格式为 B×T×*

返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。

例子:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import utils as nn_utils
batch_size = 2
max_length = 3
hidden_size = 2
n_layers =1

tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step

# pack it
pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)

# initialize
rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))

#forward
out, _ = rnn(pack, h0)

# unpack
unpacked = nn_utils.rnn.pad_packed_sequence(out)
print(‘111‘,unpacked)

输出:

111 (Variable containing:
(0 ,.,.) =
  0.5406  0.3584
 -0.1403  0.0308

(1 ,.,.) =
 -0.6855 -0.9307
  0.0000  0.0000
[torch.FloatTensor of size 2x2x2]
, [2, 1])
时间: 2024-11-03 10:32:16

pytorch对可变长度序列的处理的相关文章

库、教程、论文实现,这是一份超全的PyTorch资源列表(Github 2.2K星)

项目地址:https://github.com/bharathgs/Awesome-pytorch-list 列表结构: NLP 与语音处理 计算机视觉 概率/生成库 其他库 教程与示例 论文实现 PyTorch 其他项目 自然语言处理和语音处理 该部分项目涉及语音识别.多说话人语音处理.机器翻译.共指消解.情感分类.词嵌入/表征.语音生成.文本语音转换.视觉问答等任务,其中有一些是具体论文的 PyTorch 复现,此外还包括一些任务更广泛的库.工具集.框架. 这些项目有很多是官方的实现,其中

python课程第二周 内置数据结构——列表和元组

5种内置数据结构:列表.元组.字典.集合.字符串.列表.字典.字符串三种被称为线性结构. 针对线性结构的操作有:切片.封包和解包.成员运算符.迭代. 针对数据结构的操作有解析式:解析式分为列表解析.生成器解析.集合解析和字典解析. 后面三种是Python3.x特有的. 基本框架如下: 一.列表:Python中最具灵活性的有序集合对象类型 列表可包含任何种类的对象:数字.字符串.字典.集合甚至其他列表,这个特性称为异构.还具有可变长度和任意嵌套的特性,属于可变长度序列. (1)列表的初始化,有两种

http2协议翻译(转)

超文本传输协议版本 2 IETF HTTP2草案(draft-ietf-httpbis-http2-13) 摘要 本规范描述了一种优化的超文本传输协议(HTTP).HTTP/2通过引进报头字段压缩以及多路复用来更有效利用网络资源.减少感知延迟.另外还介绍了服务器推送规范. 本文档保持对HTTP/1.1的后向兼容,HTTP的现有的语义保持不变. 1 介绍 The Hypertext Transfer Protocol (HTTP) is a wildly successful protocol.

PyTorch 1.0 中文官方教程:序列模型和LSTM网络

译者:ETCartman 之前我们已经学过了许多的前馈网络. 所谓前馈网络, 就是网络中不会保存状态. 然而有时 这并不是我们想要的效果. 在自然语言处理 (NLP, Natural Language Processing) 中, 序列模型是一个核心的概念. 所谓序列模型, 即输入依赖于时间信息的模型. 一个典型的序列模型是隐马尔科夫模型 (HMM, Hidden Markov Model). 另一个序列模型的例子是条件随机场 (CRF, Conditional Random Field). 循

pytorch中如何处理RNN输入变长序列padding

一.为什么RNN需要处理变长输入 假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示: 思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练样例长度不同的情况,这样我们就会很自然的进行padding,将短句子padding为跟最长的句子一样. 比如向下图这样: 但是这会有一个问题,什么问题呢?比如上图,句子“Yes”只有一个单词,但是padding了5的pad符号,这样会导致LSTM对它的表示通过了非常多无用的字符,这样得到的句子表示就

Pytorch基础——使用 RNN 生成简单序列

一.介绍 内容 使用 RNN 进行序列预测 今天我们就从一个基本的使用 RNN 生成简单序列的例子中,来窥探神经网络生成符号序列的秘密. 我们首先让神经网络模型学习形如 0^n 1^n 形式的上下文无关语法.然后再让模型尝试去生成这样的字符串.在流程中将演示 RNN 及 LSTM 相关函数的使用方法. 实验知识点 什么是上下文无关文法 使用 RNN 或 LSTM 模型生成简单序列的方法 探究 RNN 记忆功能的内部原理 二.什么是上下文无关语法 上下文无关语法 首先让我们观察以下序列: 01 0

pytorch 对变长序列的处理

使用的主要部分包括:Dateset. Dateloader.MSELoss.PackedSequence.pack_padded_sequence.pad_packed_sequence 模型包含LSTM模块. 参考了下面两篇博文,总结了一下. http://www.cnblogs.com/lindaxin/p/8052043.html#commentform https://blog.csdn.net/lssc4205/article/details/79474735 使用Dateset构建数

数据库 day60,61 Oracle入门,单行函数,多表查询,子查询,事物处理,约束,rownum分页,视图,序列,索引

1.    oracle介绍 ORACLE数据库系统是美国ORACLE公司(甲骨文)提供的以分布式数据库为核心的一组软件产品,是目前最流行的客户/服务器(CLIENT/SERVER)或B/S体系结构的数据库之一.比如SilverStream就是基于数据库的一种中间件.ORACLE数据库是目前世界上使用最为广泛的数据库管理系统,作为一个通用的数据库系统,它具有完整的数据管理功能:作为一个关系数据库,它是一个完备关系的产品:作为分布式数据库它实现了分布式处理功能.但它的所有知识,只要在一种机型上学习

IO包中的其他类 打印流,序列流,操作对象,管道流,RandomAccessFile,操作基本数据类型,操作字节数组

打印流,序列流,操作对象,管道流,RandomAccessFile,操作基本数据类型,操作字节数组 一.打印流: 该流提供了打印方法,可以将各种数据类型的数据都原样打印. 字节打印流PrintStream构造函数可以接收的参数类型1.File对象 File2.字符串路径 String3.字节输出流 OutputStream 字符打印流PrintWriter(更常用)1.File对象 File2.字符串路径 String3.字节输出流 OutputStream4.字符输出流 Writer publ