pytorch 对变长序列的处理

使用的主要部分包括:Dateset、 Dateloader、MSELoss、PackedSequence、pack_padded_sequence、pad_packed_sequence

模型包含LSTM模块。

参考了下面两篇博文,总结了一下。

http://www.cnblogs.com/lindaxin/p/8052043.html#commentform

https://blog.csdn.net/lssc4205/article/details/79474735

使用Dateset构建数据集的时候,在__getitem__函数中

def __getitem__(self, index):
    ‘‘‘
get original data
此处省略获取原始数据的代码
input_data,output_data
数据shape是  seq_length * feature_dim
    ‘‘‘
# 当前seq_length小于所有数据中的最长数据长度,则补0到同一长度。
    ori_length = input_data.shape[0]
    if ori_length < self.max_len:
        npi = np.zeros(self.input_feature_dim, dtype=np.float32)
        npi = np.tile(npi, (self.max_len - ori_length,1))
        input_data = np.row_stack((input_data, npi))
        npo = np.zeros(self.output_feature_dim, dtype=np.float32)
        npo = np.tile(npo, (self.max_len - ori_length,1))
        output_data = np.row_stack((output_data, npo))
    return input_data, output_data, ori_length, input_data_path

在模型中,forward的实现中,需要在LSTM之前使用pack_padded_sequence、在LSTM之后使用pad_packed_sequence,中间还涉及到顺序的还原之类的操作。

def forward(self, input_x, length_list, hidden=None):
    if hidden is None:
        # 这里没用 配置中的batch_size,而是直接在input_x中取batch_size是为了防止last_batch的batch_size不是配置中的那个,引发bug
        h_0 = input_x.data.new(self.directional*self.layer_num, input_x.shape[0], self.hidden_dim).fill_(0).float()
        c_0 = input_x.data.new(self.directional*self.layer_num, input_x.shape[0], self.hidden_dim).fill_(0).float()
    else:
        h_0, c_0 = hidden
    ‘‘‘
省略模型其他部分,直接进去LSTM前后的操作
    ‘‘‘
    _, idx_sort = torch.sort(length_list, dim=0, descending=True)
    _, idx_unsort = otrch.sort(idx_sort, dim=0)

    input_x = input_x.index_select(0, Variable(idx_sort))
    length_list = list(length_list[idx_sort])
    pack = nn_utils.rnn.pack_padded_sequence(input_x, length_list, batch_first=self.batch_first)
    output, hidden = self.BiLSTM(pack, (h0, c0))
    un_padded = nn_utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first)
    un_padded = un_padded[0].index_select(0, Variable(idx_unsort))
# 此时的un_padded已经完成了还原,并且补0完成,而且这时的补0到的序列长度是当前batch的最长长度,而不是Dateset中的全局最长长度!# 所以在main train函数中也要对label的seq做处理
    return un_padded

main train中,要对label做相应的截断处理,算loss的时候,MSELoss的reduce参数要设置成false,让loss函数返回一个loss矩阵,再构造一个掩膜矩阵mask,矩阵相乘求和得到真的loss(达到填充0的位置不参与loss的目的)

def train(**kwargs):  train_data = my_dataset()  train_dataloader = DataLoader(train_data, opt.batch_size, shuffle=True, num_workers=opt.num_workers)  model = getattr(models, opt.model)(batchsize=opt.batch_size)  criterion = torch.nn.MSELoss(reduce=False)  lr = opt.lf  optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=opt.weight_decay)  for epoch in range(opt.start_epoch, opt.max_epoch):    for ii, (data, label, length_list,_) in tqdm(enumerate(train_dataloader)):      cur_batch_max_len = length_list.max()      data = Variable(data)      target = Variable(label)

      optimizer.zero_grad()      score = model(data, length_list)      loss_mat = criterion(score, target)      list_int = list(length_list)      mask_mat = Variable(t.ones(len(list_int),cur_batch_max_len,opt.output_feature_dim))      num_element = 0      for idx_sample in range(len(list_int)):        num_element += list_int[idx_sample] * opt.output_feature_dim        if list_int[idx_sample] != cur_batch_max_len:          mask_mat[idx_sample, list[idx_sample]:] = 0.0

      loss = (loss_mat * mask_mat).sum() / num_element      loss.backward()      optimizer.step()

原文地址:https://www.cnblogs.com/chengebigdata/p/8993990.html

时间: 2024-10-08 16:44:48

pytorch 对变长序列的处理的相关文章

pytorch中如何处理RNN输入变长序列padding

一.为什么RNN需要处理变长输入 假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示: 思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练样例长度不同的情况,这样我们就会很自然的进行padding,将短句子padding为跟最长的句子一样. 比如向下图这样: 但是这会有一个问题,什么问题呢?比如上图,句子“Yes”只有一个单词,但是padding了5的pad符号,这样会导致LSTM对它的表示通过了非常多无用的字符,这样得到的句子表示就

keras: 在构建LSTM模型时,使用变长序列的方法

众所周知,LSTM的一大优势就是其能够处理变长序列.而在使用keras搭建模型时,如果直接使用LSTM层作为网络输入的第一层,需要指定输入的大小.如果需要使用变长序列,那么,只需要在LSTM层前加一个Masking层,或者embedding层即可. from keras.layers import Masking, Embedding from keras.layers import LSTM model = Sequential() model.add(Masking(mask_value=

scala学习笔记-变长参数(5)

变长参数 在Scala中,有时我们需要将函数定义为参数个数可变的形式,则此时可以使用变长参数定义函数. 1 def sum(nums: Int*) = { 2 var res = 0 3 for (num <- nums) res += num 4 res 5 } 6 7 sum(1, 2, 3, 4, 5) 使用序列调用变长参数 在如果想要将一个已有的序列直接调用变长参数函数,是不对的.比如val s = sum(1 to 5).此时需要使用Scala特殊的语法将参数定义为序列,让Scala解

C++11 新特性之 变长参数模板

template <typename ... ARGS> void fun(ARGS ... args) 首先明确几个概念 1,模板参数包(template parameter pack):它指模板参数位置上的变长参数,例如上面例子中的ARGS 2,函数参数包(function parameter pack):它指函数参数位置上的变长参数,例如上面例子中的args 一般情况下 参数包必须在最后面,例如: template <typename T, typename ... Args>

Java语法糖初探(三)--变长参数

变长参数概念 在Java5 中提供了变长参数(varargs),也就是在方法定义中可以使用个数不确定的参数,对于同一方法可以使用不同个数的参数调用.形如 function(T -args).但是需要明确的一点是,java方法的变长参数只是语法糖,其本质上还是将变长的实际参数 varargs 包装为一个数组. 看下面的例子: 12345678910111213 public class VariVargs { public static void main(String []args) { tes

java 变长參数使用原则

1.java变长參数用...表示,如Print(String... args){  ... }; 2.假设一个调用既匹配一个固定參数方法.又匹配一个变长參数方法,则优先匹配固定參数的方法 3.假设一个调用能匹配两个及以上的变长參数方法,则出现错误--这事实上表示方法设计有问题,编译器会提示The method is ambiguous 4.方法仅仅能有一个变长參数,且必须放在參数列表的最后一个

变长结构体的使用

在分析安卓源码过程中看到几处使用变长结构体的例子,比如下面的结构体: struct command { /* list of commands in an action */ struct listnode clist; int (*func)(int nargs, char **args); int nargs; char *args[1]; }; 下面介绍安卓是如何使用这个结构的,在解析init.rc文件的过程中,会使用这个结构体记录某些命令. static void parse_line_

变长数组_相乘取结果

//变长数组 相乘取结果 #include <stdio.h> int main(void){ // int array_01[3][4] = {1,2,3,4,5,6,7,8,9,10,11,12}; int array_02[4][3] = {12,11,10,9,8,7,6,5,4,3,2,1}; int result[3][3] = {0}; int i, j, k; for (i = 0; i < 3; i ++){ //遍历array_01数组元素 for (j = 0;j

读书笔记:c语言标准库 - 变长参数

· 变长参数(stdarg.h) 变长参数是c语言的特殊参数形式,例如如下函数声明: int printf(const char * format,...); 如此的声明表明,printf函数除了第一个参数类型为const char*之外,其后可以追加任意数量.任意类型的参数. 在函数实现部分,可以使用stdarg.h里的多个宏来访问各个额外的参数:假设lastarg是变长参数函数的最后一个具名参数(printf里的format),那么在函数内容定义类型为va_list的变量: va_list