Pointer-network的tensorflow实现-1

pointer-network是最近seq2seq比较火的一个分支,在基于深度学习的阅读理解,摘要系统中都被广泛应用。

感兴趣的可以阅读原paper 推荐阅读

https://medium.com/@devnag/pointer-networks-in-tensorflow-with-sample-code-14645063f264

?
?

这个思路也是比较简单
就是解码的预测限定在输入的位置上
这在很多地方有用

比如考虑机器翻译的大词典问题,词汇太多了很多词是长尾的,词向量训练是不充分的,那么seq2seq翻译的时候很难翻译出这些词
另外专名什么的
很多是可以copy到
解码输出的

另外考虑文本摘要,很多时候就是要copy输入原文中的词,特别是长尾专名
更好的方式是copy而不是generate

?
?

网络上有一些pointer-network的实现,比较推荐

?https://github.com/ikostrikov/TensorFlow-Pointer-Networks

这个作为入门示例比较好,使用简单的static rnn 实现更好理解,当然 dynamic速度更快,但是从学习角度

先实现static更好一些。

Dynamic rnn的 pointer network实现

https://github.com/devsisters/pointer-network-tensorflow?

这里对static rnn实现的做了一个拷贝并做了小修改,改正了其中的一些问题
参见
https://github.com/chenghuige/hasky/tree/master/applications/pointer-network/static

?
?

这个小程序对应的应用是输入一个序列
比如,输出排序结果

?
?

我们的构造数据

python dataset.py

EncoderInputs: [array([[ 0.74840968]]), array([[ 0.70166106]]), array([[ 0.67414996]]), array([[ 0.9014052]]), array([[ 0.72811645]])]

DecoderInputs: [array([[ 0.]]), array([[ 0.67414996]]), array([[ 0.70166106]]), array([[ 0.72811645]]), array([[ 0.74840968]]), array([[ 0.9014052]])]

TargetLabels: [array([[ 3.]]), array([[ 2.]]), array([[ 5.]]), array([[ 1.]]), array([[ 4.]]), array([[ 0.]])]

?
?

训练过程中的eval展示:

2017-06-07 22:35:52 0:28:19 eval_step: 111300 eval_metrics:

[‘eval_loss:0.070‘, ‘correct_predict_ratio:0.844‘]

label--: [ 2 6 1 4 9 7 10 8 5 3 0]

predict: [ 2 6 1 4 9 7 10 8 5 3 0]

label--: [ 1 6 2 5 8 3 9 4 10 7 0]

predict: [ 1 6 2 5 3 3 9 4 10 7 0]

?
?

大概是这样
第一个我们认为是预测完全正确了,
第二个预测不完全正确

?
?

原程序最主要的问题是 Feed_prev 设置为True的时候 原始代码有问题的 因为inp使用的是decoder_input这是不正确的因为

预测的时候其实是没有decoder_input输入的,原代码预测的时候decoder input强制copy/feed了encoder_input

这在逻辑是是有问题的。 实验效果也证明修改成训练也使用encoder_input来生成inp效果好很多。

?
?

那么关于feed_prev我们知道在预测的时候是必须设置为True的因为,预测的时候没有decoder_input我们的下一个输出依赖

上一个预测的输出。

训练的时候我们是用decoder_input序列训练(feed_prev==False)还是也使用自身预测产生的结果进行下一步预测feed_prev==True呢

参考tensorflow官网的说明

In the above invocation, we set?feed_previous?to False. This means that the decoder will use?decoder_inputstensors as provided. If we set?feed_previous?to True, the decoder would only use the first element of?decoder_inputs. All other tensors from this list would be ignored, and instead the previous output of the decoder would be used. This is used for decoding translations in our translation model, but it can also be used during training, to make the model more robust to its own mistakes, similar to?Bengio et al., 2015?(pdf).

?
?

来自 <https://www.tensorflow.org/tutorials/seq2seq>

?
?

这里使用

train.sh

train-no-feed-prev.sh
做了对比实验

训练时候使用feed_prev==True效果稍好(红色) 特别是稳定性方差小一些

?
?

?
?

时间: 2024-08-07 02:41:05

Pointer-network的tensorflow实现-1的相关文章

Convolutional Neural Network in TensorFlow

翻译自Build a Convolutional Neural Network using Estimators TensorFlow的layer模块提供了一个轻松构建神经网络的高端API,它提供了创建稠密(全连接)层和卷积层,添加激活函数,应用dropout regularization的方法.本教程将介绍如何使用layer来构建卷积神经网络来识别MNIST数据集中的手写数字. MNIST数据集由60,000训练样例和10,000测试样例组成,全部都是0-9的手写数字,每个样例由28x28大小

(转)The Road to TensorFlow

Stephen Smith's Blog All things Sage 300… The Road to TensorFlow – Part 7: Finally Some Code leave a comment » Introduction Well after a long journey through Linux, Python, Python Libraries, the Stock Market, an Introduction to Neural Networks and tr

TensorFlow tutorial

代码示例来自https://github.com/aymericdamien/TensorFlow-Examples tensorflow先定义运算图,在run的时候才会进行真正的运算. run之前需要先建立一个session 常量用constant 如a = tf.constant(2) 变量用placeholder 需要指定类型 如a = tf.placeholder(tf.int16) 矩阵相乘 matrix1 = tf.constant([[3., 3.]]) #1*2矩阵 matrix

Seq2SQL :使用强化学习通过自然语言生成SQL

论文: https://einstein.ai/static/images/layouts/research/seq2sql/seq2sql.pdf 数据集:https://github.com/salesforce/WikiSQL Seq2SQL属于natural language interface (NLI)的领域,方便普通用户接入并查询数据库中的内容,即用户不需要了解SQL语句,只需要通过自然语言,就可查询所需内容. Seq2SQL借鉴的是Seq2Seq的思想,与Seq2Seq应用于机器

Abstractive Summarization

Abstractive Summarization A Neural Attention Model for Abstractive Sentence Summarization Alexander M. Rush et al., Facebook AI Research/Harvard EMNLP2015 sentence level seq2seq模型在2014年提出,这篇论文是将seq2seq模型应用在abstractive summarization任务上比较早期的论文.同组的人还发表了

问答系统总结

最近在研究问答系统,但是在查找资料的过程中一直处于懵逼状态,因为问答系统分类比较多,根据不同的依据可以分为不同种类,总是搞混,也没有找到资料详细全面的介绍,所以在一边学习查找资料的同时,自己也整理出一份总结,用于以后学习过程不至于思路混乱,如有错误请帮忙指出. 19世纪60年代最早:基于模板和规则 19世纪90年代:基于检索(IR)匹配-从问题中提取关键词,根据关键词在文本库中搜索相关文档,并进行降序排序,然后从文档中提取答案.        主要模型有:            单轮:DSSM,

组合神经优化涉及的一些知识

---恢复内容开始--- 1/ 注意力机制(attention mechanism) https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html 2/Pointer network https://www.translatoruser-int.com/translate?&to=en&csId=fa917e53-e092-4c2a-88c3-450a2fe4bf4e&usId=b194361b-a

[C4] Andrew Ng - Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization

About this Course This course will teach you the "magic" of getting deep learning to work well. Rather than the deep learning process being a black box, you will understand what drives performance, and be able to more systematically get good res

Tutorial: Implementation of Siamese Network on Caffe, Torch, Tensorflow

1. caffe version:  If you want to try this network, just do as the offical document said, like the following codes:   1 --- 2 title: Siamese Network Tutorial 3 description: Train and test a siamese network on MNIST data. 4 category: example 5 include