(数据科学学习手札39)RNN与LSTM基础内容详解

一、简介

  循环神经网络(recurrent neural network,RNN),是一类专门用于处理序列数据(时间序列、文本语句、语音等)的神经网络,尤其是可以处理可变长度的序列;在与传统的时间序列分析进行比较的过程之中,RNN因为其梯度弥散等问题对长序列表现得不是很好,而据此提出的一系列变种则展现出很明显的优势,最具有代表性的就是LSTM(long short-term  memory),而本文就从标准的循环神经网络结构和原理出发,再到LSTM的网络结构和原理,对其有一个基本的认识和阐述;

二、关于基本的RNN

基本结构:

  循环神经网络又叫递归神经网络,因为其向前传播过程中折叠了一个循环计算的重复结构,这里我们先观察一个经典的动态系统,即:

其中s(t)为系统在t时刻的状态,和传统时间序列分析中的模型类似,在有限时间步τ的条件下,经过τ-1次上述展开过程就可以完全展开这个有限时间步内的过程,以τ=3为例:

上述过程可以用图论中的有向无环计算图来表示:

每一个时刻的状态都经由函数f映射到下一个时刻,而这是仅有自我状态驱动的系统,我们再考虑引入外部信号x(t)的系统:

即对于一个序列,其当前状态包含了过去所有时刻状态对其的影响,以及当前时刻外部信号的影响,我们的循环神经网络就是建立在上述知识的基础上,因为RNN中的状态即是网络的隐藏单元,我们用h来重新定义上式:

则一个最简单典型的RNN架构如下(未包含输出层部分),左边是循环计算部分未展开的结构,右边是展开后的结构:

其中左边的黑色方块表示单个时间步的延迟,可以类比时间序列分析中的n阶延迟,接着我们添加上输出层以及不同层之间的连接权信息,便得到下面这张经典RNN的结构示意图:

由上图,在这个将输入序列x映射到输出值o的过程中,层与层之间通过连接权进行映射,并在功能神经元内部进行激活(通常是tanh激活函数),其中在分类任务时,ho的映射由softmax完成,接着与真实的label,即y进行比较计算出损失L,总结一下经典RNN结构的特点;

  1、每个时间步完成后都有输出,且时间步之间有按照时序顺序的循环连接,这也决定了RNN的向后传播过程不同于传统BP算法可以并行,RNN在一个未展开的时间步内部只能按顺序调整参数,即通过时间反向传播算法(back-propagation through time,BPTT);

  2、不同的任务决定了不同的输出方式,如翻译就是序列到序列,分类或时序预测就是在最后一次得到输出;

  3、参数共享

前向传播:

  在输出为离散的情况下,上述经典RNN的前向传播过程如下:

  1、时刻t的隐藏状态h(t)

  2、时刻t的输出o(t)

  3、时刻t的预测类别输出:

  4、损失函数,离散分类任务时通常为对数似然函数,连续预测任务通常是均方误差:

三、关于LSTM

  RNN在实际使用过程中,在处理较长序列输入时,难以传递相隔较远的信息,究其原因,我们先回想一下RNN的基本结构,其真正的输入有两部分——来自序列第t个位置的输入xt,和来自上一个隐层的输出ht-1,考虑隐层的信息往后传导的过程,这里令RNN中隐层连接下一个时刻隐层的权重为Whh,不考虑每一次隐层的非线性激活时,从初始状态h0到第t时刻状态ht,其信息传递的过程如下,其中对Whh的连乘部分做了特征分解:

当特征值小于1时,连续相乘的结果是特征值向0方向衰减;当特征值大于1时,连续相乘的结果是特征值向∞方向增长。这两种情况都会导致较远时刻状态的信息消失(vanish)或爆炸(explode),无法有效地反馈到t时刻;

  上述情况导致的结果是我们的RNN网络难以通过梯度下降进行有效的学习,为了有效地利用梯度下降法来进行学习,我们需要控制传递过程中梯度的积在1左右,目前最有效的方式是gated RNNs,而LSTM就是其中的一个代表;

  再次回想前面的RNN中的t时刻状态计算过程,其中σ为激活函数,通常为tanh

而LSTM就是在RNN的基础上施加了若干个门(gate)来控制,我们先看LSTM的示意图即网络结构中涉及的计算内容,然后在接下来的过程中逐一解释:

且这些门均由Sigmoid型函数激活,具体如下:

  1、遗忘门(forget gate)

  这个gate控制对上一层的cell状态ct-1中的信息保留多少,它流入当前时刻xt与上一时刻传递过来的状态ht-1通过对应的所有事件步共享的权重WxfWhf,偏移bf来进行线性组合,并通过sigmoid函数进行处理后得到当前时刻遗忘门输出ft,即下式:

  2、输入门(input gate)

  输入门控制了有多少信息可以流入cell,即上图中it的部分(所谓at的部分其实就是经典RNN中的输入层)它对应了下式:

  3、输出门(output gate)

  输出门顾名思义,控制了有多少当前时刻的cell中的信息可以流向当前隐藏状态ht,与经tanh处理的ct进行哈达玛相乘得到ht,对应下式:

  4、t时刻ct的更新

  如上图,我们这一个时间步的cell中的ct为遗忘门处理后的上一时刻中的ct-1、输入门控制流入的信息it经典RNN中的输入层信息at等信息的汇总,计算过程对应着:

  5、t时刻ht的更新

  如上图所示,LSTM新加的这些结构的作用就是为了调整ht使其在长时间步的传递过程中减少信息失效的可能,对应的新的ht

  而其他部分的计算内容就同RNN,即LSTM就是一个扩充了数倍调整过滤参数的RNN,以上就是本篇文章的基本内容,如有笔误,望指出。

参考文献:

《深度学习》

《Yjango的循环神经网络》https://zhuanlan.zhihu.com/p/25518711

原文地址:https://www.cnblogs.com/feffery/p/9116997.html

时间: 2024-08-06 20:53:04

(数据科学学习手札39)RNN与LSTM基础内容详解的相关文章

(数据科学学习手札32)Python中re模块的详细介绍

一.简介 关于正则表达式,我在前一篇(数据科学学习手札31)中已经做了详细介绍,本篇将对Python中自带模块re的常用功能进行总结: re作为Python中专为正则表达式相关功能做出支持的模块,提供了一系列方法来完成几乎全部类型的文本信息的处理工作,下面一一介绍: 二.re.compile() 在前一篇文章中我们使用过这个方法,它通过编译正则表达式参数,来返回一个目标对象的匹配模式,进而提高了正则表达式的效率,主要参数如下: pattern:输入的欲编译正则表达式,需将正则表达式包裹在''内传

(数据科学学习手札47)基于Python的网络数据采集实战(2)

一.简介 马上大四了,最近在暑期实习,在数据挖掘的主业之外,也帮助同事做了很多网络数据采集的内容,接下来的数篇文章就将一一罗列出来,来续写几个月前开的这个网络数据采集实战的坑. 二.马蜂窝评论数据采集实战 2.1 数据要求 这次我们需要采集的数据是知名旅游网站马蜂窝下重庆区域内所有景点的用户评论数据,如下图所示: 思路是,先获取所有景点的poi ID,即每一个景点主页url地址中的唯一数字: 这一步和(数据科学学习手札33)基于Python的网络数据采集实战(1)中做法类似,即在下述界面: 翻页

(数据科学学习手札55)利用ggthemr来美化ggplot2图像

一.简介 R中的ggplot2是一个非常强大灵活的数据可视化包,熟悉其绘图规则后便可以自由地生成各种可视化图像,但其默认的色彩和样式在很多时候难免有些过于朴素,本文将要介绍的ggthemr包专门针对原生ggplot2图像进行美化,掌握它之后你就可以创作出更具特色和美感的数据可视化作品. 二.基础内容 2.1 安装 不同于常规的R包,ggthemr并没有在CRAN上发布,因此我们需要使用devtools中的install_github()直接从github上安装它,参照github上ggthemr

(数据科学学习手札81)conda+jupyter玩转数据科学环境搭建

本文示例yaml文件已上传至我的Github仓库https://github.com/CNFeffery/DataScienceStudyNotes 1 简介 我们在使用Python进行数据分析时,很多时候都在解决环境搭建的问题,不同版本.依赖包等问题经常给数据科学工作流的搭建和运转带来各种各样令人头疼的问题,本文就将基于笔者自己摸索出的经验,以geopandas环境的搭建为例,教你使用conda+jupyter轻松搞定环境的搭建.管理与拓展. 图1 2 虚拟环境的搭建与使用 2.1 使用con

(数据科学学习手札70)面向数据科学的Python多进程简介及应用

本文对应脚本已上传至我的Github仓库https://github.com/CNFeffery/DataScienceStudyNotes 一.简介 进程是计算机系统中资源分配的最小单位,也是操作系统可以控制的最小单位,在数据科学中很多涉及大量计算.CPU密集型的任务都可以通过多进程并行运算的方式大幅度提升运算效率从而节省时间开销,而在Python中实现多进程有多种方式,本文就将针对其中较为易用的几种方式进行介绍. 二.利用multiprocessing实现多进程 multiprocessin

(数据科学学习手札23)决策树分类原理详解&Python与R实现

决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法.由于这种决策分支画成图形很像一棵树的枝干,故称决策树.在机器学习中,决策树是一个预测模型,他代表的是对象属性与对象值之间的一种映射关系. 一.初识决策树 决策树是一种树形结构,一般的,一棵决策树包含一个根结点,若干个内部结点和若干个叶结点: 叶结点:树的一个方向的最末端,表示结果的输出: 根结点:初始样

(数据科学学习手札19)R中基本统计分析技巧总结

在获取数据,并且完成数据的清洗之后,首要的事就是对整个数据集进行探索性的研究,这个过程中会利用到各种描述性统计量和推断性统计量来初探变量间和变量内部的基本关系,本篇笔者便基于R,对一些常用的数据探索方法进行总结: 1.描述性统计量部分 1.1 计算描述性统计量的常规方法 summary() summary()函数提供了最小值.最大值.四分位数和数值型变量的均值,以及因子向量和逻辑型向量的频数统计: > #挂载鸢尾花数据 > data(iris) > #计算鸢尾花各变量的基本描述统计量 &

(数据科学学习手札62)详解seaborn中的kdeplot、rugplot、distplot与jointplot

一.简介 seaborn是Python中基于matplotlib的具有更多可视化功能和更优美绘图风格的绘图模块,当我们想要探索单个或一对数据分布上的特征时,可以使用到seaborn中内置的若干函数对数据的分布进行多种多样的可视化,本文以jupyter notebook为编辑工具,针对seaborn中的kdeplot.rugplot.distplot和jointplot,对其参数设置和具体用法进行详细介绍. 二.kdeplot seaborn中的kdeplot可用于对单变量和双变量进行核密度估计并

(数据科学学习手札65)利用Python实现Shp格式向GeoJSON的转换

一.简介 Shp格式是GIS中非常重要的数据格式,主要在Arcgis中使用,但在进行很多基于网页的空间数据可视化时,通常只接受GeoJSON格式的数据,众所周知JSON(JavaScript Object Nonation)是利用键值对+嵌套来表示数据的一种格式,以其轻量.易解析的优点,被广泛使用与各种领域,而GeoJSON就是指在一套规定的语法规则下用JSON格式存储矢量数据,本文就将针对GeoJSON的语法规则,以及如何利用Python完成Shp格式到GeoJSON格式的转换进行介绍. 二.