Neural Turing Machines-NTM系列(三)ntm-lasagne源码分析

Neural Turing Machines-NTM系列(三)ntm-lasagne源码分析

在NTM系列文章(二)中,我们已经成功运行了一个ntm工程的源代码。在这一章中,将对它的源码实现进行分析。

1.网络结构

1.1 模块结构图

在图中可以看到,输入的数据在经过NTM的处理之后,输出经过NTM操作后的,跟之前大小相同的数据块。来看下CopyTask的完整输出图:

图中右侧的Input是输入数据,Output是目标数据,Prediction是通过NTM网络预测出来的输出数据,可以看出预测数据与目标数据只在区域上大致相同,具体到每个白色的块差距较大。(这里只迭代训练了100次)

训练次数可以在这里调整(task-copy.py):

其中的参数max_iter就是训练时的迭代次数,size是输入的数据宽度(即上图中Input/Output小矩形的“高”-1,多出来的维度用作结束标记)

输入数据如下,从上到下对应上图中的从左到右,最后一行是结束标志,只有最后一个元素为1:

array( [[

[ 0., 1., 1., 0., 1., 1., 1., 1., 0.],

[ 0., 1., 1., 0., 0., 1., 0., 0., 0.],

[ 0., 0., 1., 0., 1., 1., 1., 0., 0.],

[ 1., 1., 1., 1., 1., 1., 1., 0., 0.],

[ 0., 0., 0., 0., 0, 0., 0., 0., 1.],

[ 0., 0., 0., 0., 0, 0., 0., 0., 0.],

[ 0., 0., 0., 0., 0, 0., 0., 0., 0.],

[ 0., 0., 0., 0., 0, 0., 0., 0., 0.],

[ 0., 0., 0., 0., 0, 0., 0., 0., 0.],

[ 0., 0., 0., 0., 0, 0., 0., 0., 0.],

]]

目标数据和预测数据的格式相似,就不详细介绍了。需要注意的是,由于输出层使用的是sigmoid函数,所以预测数据的范围在0和1之间。

1.2 Head对象内部的计算流

上图对应的实现在ntm-lasagne/ntm/heads.py中的Head基类中的get_output_for函数

    def get_output_for(self, h_t, w_tm1, M_t, **kwargs):
        if self.sign is not None:
            sign_t = self.sign.get_output_for(h_t, **kwargs)
        else:
            sign_t = 1.
        k_t = self.key.get_output_for(h_t, **kwargs)
        beta_t = self.beta.get_output_for(h_t, **kwargs)
        g_t = self.gate.get_output_for(h_t, **kwargs)
        s_t = self.shift.get_output_for(h_t, **kwargs)
        gamma_t = self.gamma.get_output_for(h_t, **kwargs)

        # Content Adressing (3.3.1)
        beta_t = T.addbroadcast(beta_t, 1)
        betaK = beta_t * similarities.cosine_similarity(sign_t * k_t, M_t)
        w_c = lasagne.nonlinearities.softmax(betaK)

        # Interpolation (3.3.2)
        g_t = T.addbroadcast(g_t, 1)
        w_g = g_t * w_c + (1. - g_t) * w_tm1

        # Convolutional Shift (3.3.2)
        w_g_padded = w_g.dimshuffle(0, ‘x‘, ‘x‘, 1)
        conv_filter = s_t.dimshuffle(0, ‘x‘, ‘x‘, 1)
        pad = (self.num_shifts // 2, (self.num_shifts - 1) // 2)
        w_g_padded = padding.pad(w_g_padded, [pad], batch_ndim=3)
        convolution = T.nnet.conv2d(w_g_padded, conv_filter,
            input_shape=(self.input_shape[0], 1, 1, self.memory_shape[0] + pad[0] + pad[1]),
            filter_shape=(self.input_shape[0], 1, 1, self.num_shifts),
            subsample=(1, 1),
            border_mode=‘valid‘)
        w_tilde = convolution[:, 0, 0, :]

        # Sharpening (3.3.2)
        gamma_t = T.addbroadcast(gamma_t, 1)
        w = T.pow(w_tilde + 1e-6, gamma_t)
        w /= T.sum(w)

        return w

其中的传入参数解释如下:

h_t:controller的隐层输出;

w_tm1:前一时刻的输出值,即wt?1;

M_t:Memory矩阵

1.3 NTMLayer结构图

NTM层的数据处理实现在ntm-lasagne/ntm/layers.py中的NTMLayer.get_output_for函数中:

注意到其中还有一个内部函数step,这个函数中实现了每一次数据输入后NTM网络要进行的操作逻辑。

其中的参数解释如下:

x_t:当前的网络输入,即1.1中输入矩阵中的一行;

M_tm1:前一时刻的Memory矩阵,即Mt?1

h_tm1:前一时刻的controller隐层输出

state_tm1:前一时刻的controller隐层状态,当controller为前馈网络时,等于前一时刻的输出

params:存放write heads和read heads上一时刻的输出即wt?1,顺序如下:

[write_head1_w,write_head2_w,…,write_headn1_w,read_head1_w,read_head2_w,…,read_headn2_w]

1.每次网络接收到输入后,会进入step迭代函数,先走write(erase+add)流程,更新Memory,然后再执行read操作,生成rt向量。这部分代码如下:

最后的r_t就是读取出来的rt向量,注意这里有个比较特殊的参数W_hid_to_sign_add,这是一个开关参数,类似于LSTM中的“门”。这个参数默认为None。

2.read vector生成后,将作为输入参数被传入Controller:

3.step函数结束,返回值为一list,代码如下:

list中的元素依次为:[M_t, h_t, state_t + write_weights_t + read_weights_t]

step函数通过 theano.scan来进行迭代调用,每次的输入即为当前的input及上一时刻的list值

4.最后NTMLayer.get_out_for函数的返回值为:

hid_out = hids[1],正好对应了Controller隐层最近一次的输出值。

1.4 NTM网络结构图

2.公式及主要Class说明

αt=σalpha(htWalpha+balpha)

kt=σkey(htWkey+bkey)

βt=σbeta(htWbeta+bbeta)

gt=σgate(htWgate+bgate)

st=σshift(htWshift+bshift)

γt=σgamma(htWgamma+bgamma)

wct=softmax(βt?K(αt?kt,Mt))

wgt=gt?wct+(1?gt)?wt?1

w? t=st?wgt

wt∝w? γtt

NTMLayer:父类为 lasagne.layers.Layer

功能:Neural Turing Machine的框架层

字段:memory:即Memory

controller:控制器,父类为Layer,默认100个节点

controller.hid_init:隐层的状态集合,大小为:(1,100)

heads:读写取Head集合

write_heads:写入Head集合

read_heads:读取Head集合

函数:get_output_for:在给定的输入input下,返回对应的输出值

Head:父类为lasagne.layers.Layer

功能:读写头的基类

字段:sign:DenseLayer(全连接网络),输出为αt,激活函数为ClippedLinear(-1,1),节点数:20;

key:DenseLayer,输出为kt,激活函数为ClippedLinear(0,1),节点数:20,输入层为controller;

beta:DenseLayer,输出为βt,激活函数为rectify,节点数:1,输入层为controller;

gate:DenseLayer,输出为gt,激活函数为hard_sigmoid,节点数:1,输入层为controller;

shift:DenseLayer,输出为st,激活函数为softmax,节点数:3(等于num_shifts,默认为3),输入层为controller,最终将输出3个概率值,分别对应st(?1),st(0),st(1),s_{t}长度为N,除softmax输出的3个位置非0之外,其余位置为0;

gamma:DenseLayer,输出为γt,激活函数为1+rectify,节点数:1,输入层为controller;

num_shifts:卷积shifts的操作宽度(奇数),当宽度为n时,移位向量为:[-n/2,…,-1,0,1,…,n/2],比如,当n=3时,为:[-1,0,1]

weights_init:输出为OneHot1×128的权值向量,其初始值为除第一个元素为1之外,其余元素为0.

gate:DenseLayer,输出为eraset,激活函数为hard_sigmoid,节点数:20,输入层为controller;

add:DenseLayer,输出为addt,激活函数为ClippedLinear(0,1),节点数:20,输入层为controller;

rectify:f(x)=max(0,x)

sign_add:DenseLayer,输出为signAddt,激活函数为ClippedLinear(-1,1),节点数:20,输入层为controller;

rectify:f(x)=max(0,x)

softmax:f(x)=exj∑Kk=1exk

hard_sigmoid:

f(x)=?????x=0,x<0x=0.2x+0.5,x∈[0,1]x=1,x>1

ClippedLinear(a,b):

f(x)={x=a,x<ax=b,x>b

3.copy-task实验

(待续)

参考文章:

http://blog.csdn.net/niuwei22007/article/details/49208643

https://medium.com/snips-ai/ntm-lasagne-a-library-for-neural-turing-machines-in-lasagne-2cdce6837315

http://lasagne.readthedocs.org/en/latest/user/tutorial.html

http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.imshow

时间: 2024-08-03 15:40:10

Neural Turing Machines-NTM系列(三)ntm-lasagne源码分析的相关文章

RMI 系列(02)源码分析

目录 RMI 系列(02)源码分析 1. 架构 2. 服务注册 2.1 服务发布整体流程 2.2 服务暴露入口 exportObject 2.3 生成本地存根 2.4 服务监听 2.5 ObjectTable 注册与查找 2.6 服务绑定 2.7 总结 3. 服务发现 3.1 注册中心 Stub 3.2 普通服务 Stub RMI 系列(02)源码分析 1. 架构 RMI 中有三个重要的角色:注册中心(Registry).客户端(Client).服务端(Server). 图1 RMI 架构图 在

u-boot学习(三):u-boot源码分析

建立域模型和关系数据模型有着不同的出发点: 域模型: 由程序代码组成, 通过细化持久化类的的粒度可提高代码的可重用性, 简化编程 在没有数据冗余的情况下, 应该尽可能减少表的数目, 简化表之间的参照关系, 以便提高数据的访问速度 Hibernate 把持久化类的属性分为两种: 值(value)类型: 没有 OID, 不能被单独持久化, 生命周期依赖于所属的持久化类的对象的生命周期 实体(entity)类型: 有 OID, 可以被单独持久化, 有独立的生命周期(如果实体类型包含值类型,这个值类型就

计划在CSDN学院推出系列视频课程《源码分析教程5部曲》

?? 计划在CSDN学院推出系列视频课程<源码分析教程5部曲> 源码分析教程5部曲之1--漫游C语言 源码分析教程5部曲之2--C标准库概览 源码分析教程5部曲之3--libevent源码分析 源码分析教程5部曲之4--memcached源码分析 源码分析教程5部曲之5--redis源码分析

Java 序列化和反序列化(三)Serializable 源码分析 - 2

目录 Java 序列化和反序列化(三)Serializable 源码分析 - 2 1. ObjectStreamField 1.1 数据结构 1.2 构造函数 2. ObjectStreamClass Java 序列化和反序列化(三)Serializable 源码分析 - 2 在上一篇文章中围绕 ObjectOutputStream#writeObject 讲解了一下序列化的整个流程,这中间很多地方涉及到了 ObjectStreamClass 和 ObjectStreamField 这两个类.

Webpack-源码三,从源码分析如何写一个plugin

经过上一篇博客分析webpack从命令行到打包完成的整体流程,我们知道了webpage的plugin是基于事件机制工作的,这样最大的好处是易于扩展.社区里很多webpack的plugin,但是具体到我们的项目并不一定适用,这篇博客告诉你如何入手写一个plugin,然后分析源码相关部分告诉你你的plugin是如何工作.知其然且知其所以然. 该系列博客的所有测试代码. 从黑盒角度学习写一个plugin 所谓黑盒,就是先不管webpack的plugin如何运作,只去看官网介绍. Compiler和Co

Java入门系列之集合HashMap源码分析(十四)

前言 我们知道在Java 8中对于HashMap引入了红黑树从而提高操作性能,由于在上一节我们已经通过图解方式分析了红黑树原理,所以在接下来我们将更多精力投入到解析原理而不是算法本身,HashMap在Java中是使用比较频繁的键值对数据类型,所以我们非常有必要详细去分析背后的具体实现原理,无论是C#还是Java原理解析,从不打算一行行代码解释,我认为最重要的是设计思路,重要的地方可能会多啰嗦两句. HashMap原理分析 我们由浅入深,循序渐进,首先了解下在HashMap中定义的几个属性,稍后会

Struts2【三】 StrutsPrepareAndExecuteFilter 源码分析&lt;一&gt;

先把关键的类总体一览一下 用JadClipse反编译debug源码 都知道Filter三个方法,init,doFilter,destory 先看init方法初始化了什么 先按名字记住几个关键类,initOperation初始化处理器,Dispatcher派发器,PrepareOperations预处理器,ExecuteOperations执行处理器 55.FilterHostConfig包装了FilterConfig 56.nit.initLogging不用管,这个貌似是过滤器初始化参数指定的日

三)CodeIgniter源码分析之Common.php

1 <?php if ( ! defined('BASEPATH')) exit('No direct script access allowed'); 2 3 // ------------------------------------------------------------------------ 4 5 /** 6 * Common Functions 7 */ 8 9 /** 10 * 为什么还要定义这些全局函数呢?比如说,下面有很多函数,如get_config().confi

TCP三次握手源码分析

TCP握手分为三个阶段,在握手开始之前,通信双方的套接字状态均为“TCP_CLOSE”,以下是这三个阶段: (1)客户端发送一个标志位中SYN位为1的报文给服务端,并设套接字状态为“TCP_SYNSENT” (2)服务端接到SYN报文,设套接字状态为“TCP_SYNRCV”,并回送一个SYN+ACK位均为1的报文 (3)客户端接到SYN+ACK报文,回送一个ACK位为1的报文,设套接字状态为“TCP_ESTABLISHED”,服务端接到ACK报文后,同样设置为“TCP_ESTABLISHED”