keras中的mask操作

使用背景

最常见的一种情况, 在NLP问题的句子补全方法中, 按照一定的长度, 对句子进行填补和截取操作. 一般使用keras.preprocessing.sequence包中的pad_sequences方法, 在句子前面或者后面补0. 但是这些零是我们不需要的, 只是为了组成可以计算的结构才填补的. 因此计算过程中, 我们希望用mask的思想, 在计算中, 屏蔽这些填补0值得作用. keras中提供了mask相关的操作方法.

原理

在keras中, Tensor在各层之间传递, Layer对象接受的上层Layer得到的Tensor, 输出的经过处理后的Tensor.

keras是用一个mask矩阵来参与到计算当中, 决定在计算中屏蔽哪些位置的值. 因此mask矩阵其中的值就是True/False, 其形状一般与对应的Tensor相同. 同样与Tensor相同的是, mask矩阵也会在每层Layer被处理, 得到传入到下一层的mask情况.

使用方法

  1. 最直接的, 在NLP问题中, 对句子填补之后, 就要输入到Embedding层中, 将tokenid转换成对应的vector. 我们希望被填补的0值在后续的计算中不产生影响, 就可以在初始化Embedding层时指定参数mask_zeroTrue, 意思就是屏蔽0值, 即填补的0值.

    Embedding层中的compute_mask方法中, 会计算得到mask矩阵. 虽然在Embedding层中不会使用这个mask矩阵, 即0值还是会根据其对应的向量进行查找, 但是这个mask矩阵会被传入到下一层中, 如果下一层, 或之后的层会对mask进行考虑, 那就会起到对应的作用.

  2. 也可以在keras.layers包中引用Masking类, 使用mask_value指定固定的值被屏蔽. 在调用call方法时, 就会输出屏蔽后的结果.

    需要注意的是Masking这种层的compute_mask方法, 源码如下:

    def compute_mask(self, inputs, mask=None):
        output_mask = K.any(K.not_equal(inputs, self.mask_value), axis=-1)
        return output_mask

    可以看到, 这一层输出的mask矩阵, 是根据这层的输入得到的, 具体的说是会比输入第一个维度, 这是因为最后一个维度被K.any(axis=-1)给去掉了. 在使用时需要注意这种操作的意义以及维度的变化.

自定义使用方法

更多的, 我们还是在自定义的层中, 需要支持mask操作, 因此需要对应的逻辑.



首先, 如果我们希望自定义的这个层支持mask操作, 就需要在__init__方法中指定:

self.supports_masking = True

如果在本层计算中需要使用到mask, 则call方法需要多传入一个mask参数, 即:

def call(self, inputs, mask=None):
    pass

然后, 如果还要继续输出mask, 供之后的层使用, 如果不对mask矩阵进行变换, 这不用进行任何操作, 否则就需要实现compute_mask函数:

def compute_mask(self, inputs, mask=None):
    pass

这里的inputs就是输入的Tensor, 与call方法中接收到的一样, mask就是上层传入的mask矩阵.

如果希望mask到此为止, 之后的层不再使用, 则该函数直接返回None即可:

def compute_mask(self, inputs, mask=None):
    return None

参考资料

Keras自定义实现带masking的meanpooling层

Keras实现支持masking的Flatten层

原文地址:https://www.cnblogs.com/databingo/p/9339175.html

时间: 2024-11-05 19:41:53

keras中的mask操作的相关文章

keras中的loss、optimizer、metrics

用keras搭好模型架构之后的下一步,就是执行编译操作.在编译时,经常需要指定三个参数 loss optimizer metrics 这三个参数有两类选择: 使用字符串 使用标识符,如keras.losses,keras.optimizers,metrics包下面的函数 例如: sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True) model.compile(loss='categorical_crossentropy', opt

sql的基础语句-select语句中出现的操作符号

2. select语句中出现的操作符号 2.1 合并操作符select a.ename||' '||to_char(sal) from emp a; 2.2 消除重复的行 select distinct deptno from emp; 2.3 空格.空串.null的区别 select ascii(' '),ascii(null),ascii('') from dual; 区别:  从显式上看,空串跟null在数据库中存储的值是一样的,但是NULL可以赋给任何数据类型,而空串只能赋给字符串类型

C# Winform中执行post操作并获取返回的XML类型的数据

/// <summary> /// 返回指定日期的订单数据 /// </summary> /// <param name="StartDate">起始日期</param> /// <param name="EndDate">结束日期</param> /// <returns>DataTable</returns> public System.Data.DataTable

Hibernate 中对表的操作

Hibernate CRUD testing Hibernate 中对表的操作, add,load,update,delete,list,pager(分页) package org.test.test; import java.text.SimpleDateFormat; import java.util.List; import org.hibernate.Session; import org.junit.Test; import org.zttc.itat.model.User; impo

Python中的切片操作

Python中的切片操作功能十分强大,通常我们利用切片来进行提取信息,进行相关的操作,下面就是一些切片的列子,一起来看看吧,希望对大家学习python有所帮助. 列如我们从range函数1-100中取7的倍数,函数及结果如下所示: >>> for i in range(1,100)[6::7]: print i 7 14 21 28 35 42 49 56 63 70 77 84 91 98 取一个list或tuple的部分元素是非常常见的操作.比如,一个list如下: >>

使用Json.Net解决MVC中各种json操作

最近收集了几篇文章,用于替换MVC中各种json操作,微软mvc当然用自家的序列化,速度慢不说,还容易出问题,自定义性也太差,比如得特意解决循环引用的问题,比如datetime的序列化格式,比如性能.NewtonSoft.json也就是Json.Net性能虽然不是最好的,但是是比较靠前的,其功能是最强大的,包含各种json操作模式.现在来看看mvc中的替换1, Controller.Json方法这个方法最容易出现循环引用,比如EF查出一个一对多集合想序列化,结果a引用了子表b,b中还引用了a,导

第十三篇:multimap容器和multiset容器中的find操作

前言 multimap容器是map容器的“ 增强版 ”,它允许一个键对应多个值.对于map容器来说,find函数将会返回第一个键值匹配元素所在处的迭代器.那么对于multimap容器来说,find函数将如何运作呢?如果要实现和map容器的find函数同样的功能,则它将返回多个迭代器,这样太复杂了.本文将讲解C++中multimap容器的“ find实现 ”. 解决思路一 摒弃find函数,使用另外两个新函数,它们是专家们为了解决multimap中的“ find操作 ”问题专门设计的: 1. lo

git工作中的常用操作

上班开始,打开电脑,git pull:拉取git上最新的代码: 编辑代码,准备提交时,git stash:将自己编辑的代码暂存起来,防止git pull时与库中的代码起冲突,否则自己的代码就白敲了: 然后,git pull:拉取一下代码,与库中代码,做到同步,有冲突则解决冲突,如果省了这一步,别人有提交的代码,没有更新,自己提交就会报错,再走这一步,就会把别人的代码拉取出来,然后一起提交,就相当于你提交了自己的代码,也提交了别人的代码:还有,有时这样会使库中代码乱掉,别人的心血也会丢失,你就是罪

CI中的AR操作

1 /** 2 * CI 中的 AR 操作 3 * @author zhaoyingnan 4 **/ 5 public function mAR() 6 { 7 /*************** 查询 *************/ 8 //select * from mp4ba limit 21,10; 9 //$objResult = $this->db->get('mp4ba', 10, 21); 10 //echo $this->db->last_query();die;