[DL] CNN源码分析

在Hinton的教程中, 使用Python的theano库搭建的CNN是其中重要一环, 而其中的所谓的SGD - stochastic gradient descend算法又是如何实现的呢? 看下面源码(篇幅考虑只取测试模型函数, 训练函数只是多了一个updates参数):

 3     classifier = LogisticRegression(input=x, n_in=24 * 48, n_out=32)

 7     cost = classifier.negative_log_likelihood(y)

11     test_model = theano.function(inputs=[index],
12             outputs=classifier.errors(y),
13             givens={
14                 x: test_set_x[index * batch_size: (index + 1) * batch_size],
15                 y: test_set_y[index * batch_size: (index + 1) * batch_size]})

行3声明了一个对象classifer, 它的输入是符号x, 大小为24*48, 输出长度为32.

行11定义了一个theano的函数对象, 接收的是下标index, 使用输入数据的第index*batch_size~第(index+1)*batch_size个数据作为函数的输入, 输出为误差.

我们再来看看行12中的errors函数的定义:

    def errors(self, y):
        # check if y has same dimension of y_pred
        if y.ndim != self.y_pred.ndim:
            raise TypeError(‘y should have the same shape as self.y_pred‘,
                (‘y‘, target.type, ‘y_pred‘, self.y_pred.type))
        # check if y is of the correct datatype
        if y.dtype.startswith(‘int‘):
            # the T.neq operator returns a vector of 0s and 1s, where 1
            # represents a mistake in prediction
            return T.mean(T.neq(self.y_pred, y))
        else:
            raise NotImplementedError()

self.y_pred 是一个大小为batch_size的向量, 每个元素代表batch_size中对应输入的网络判断结果, errors函数接受1个同等大小的期望输出y, 将两者进行比较求差后作均值返回, 这正是误差的定义.

那么问题来了, 这个 self.y_pred 是如何计算的? 这里我们看LogisticRegression的构造函数:

 1     def __init__(self, input, n_in, n_out):
 2
 3         # initialize with 0 the weights W as a matrix of shape (n_in, n_out)
 4         self.W = theano.shared(value=numpy.zeros((n_in, n_out),
 5                                                  dtype=theano.config.floatX),
 6                                 name=‘W‘, borrow=True)
 7         # initialize the baises b as a vector of n_out 0s
 8         self.b = theano.shared(value=numpy.zeros((n_out,),
 9                                                  dtype=theano.config.floatX),
10                                name=‘b‘, borrow=True)
11
12         # compute vector of class-membership probabilities in symbolic form
13         self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b)
14
15         # compute prediction as class whose probability is maximal in
16         # symbolic form
17         self.y_pred = T.argmax(self.p_y_given_x, axis=1)
18
19         # parameters of the model
20         self.params = [self.W, self.b]
时间: 2024-10-11 11:01:55

[DL] CNN源码分析的相关文章

FastText总结,fastText 源码分析

文本分类单层网络就够了.非线性的问题用多层的. fasttext有一个有监督的模式,但是模型等同于cbow,只是target变成了label而不是word. fastText有两个可说的地方:1 在word2vec的基础上, 把Ngrams也当做词训练word2vec模型, 最终每个词的vector将由这个词的Ngrams得出. 这个改进能提升模型对morphology的效果, 即"字面上"相似的词语distance也会小一些. 有人在question-words数据集上跑过fastT

OpenCV学习笔记(27)KAZE 算法原理与源码分析(一)非线性扩散滤波

http://blog.csdn.net/chenyusiyuan/article/details/8710462 OpenCV学习笔记(27)KAZE 算法原理与源码分析(一)非线性扩散滤波 2013-03-23 17:44 16963人阅读 评论(28) 收藏 举报 分类: 机器视觉(34) 版权声明:本文为博主原创文章,未经博主允许不得转载. 目录(?)[+] KAZE系列笔记: OpenCV学习笔记(27)KAZE 算法原理与源码分析(一)非线性扩散滤波 OpenCV学习笔记(28)KA

GNU GRUB 2.00 源码分析笔记,持续更新

前言 很多运维类书籍或文章仅从系统管理者的角度讲解了 grub 的安装以及使用, 本篇博文则从 gnu grub 2.00 的源码入手,从开发者,以及系统底层运行机制的角度,分析 grub 是如何作为跨平台的"全面统一的引导加载程序",来引导操作系统,加载 Linux 内核的过程等等, 部分内容参考了<深度探索 Linux 操作系统>一书中相关的内容(ISBN 978-7-11143901-1 )以及 gnu grub 项目官方站点的文档,并且加入自己分析源码时的笔记. (

TeamTalk源码分析之login_server

login_server是TeamTalk的登录服务器,负责分配一个负载较小的MsgServer给客户端使用,按照新版TeamTalk完整部署教程来配置的话,login_server的服务端口就是8080,客户端登录服务器地址配置如下(这里是win版本客户端): 1.login_server启动流程 login_server的启动是从login_server.cpp中的main函数开始的,login_server.cpp所在工程路径为server\src\login_server.下表是logi

Android触摸屏事件派发机制详解与源码分析二(ViewGroup篇)

1 背景 还记得前一篇<Android触摸屏事件派发机制详解与源码分析一(View篇)>中关于透过源码继续进阶实例验证模块中存在的点击Button却触发了LinearLayout的事件疑惑吗?当时说了,在那一篇咱们只讨论View的触摸事件派发机制,这个疑惑留在了这一篇解释,也就是ViewGroup的事件派发机制. PS:阅读本篇前建议先查看前一篇<Android触摸屏事件派发机制详解与源码分析一(View篇)>,这一篇承接上一篇. 关于View与ViewGroup的区别在前一篇的A

HashMap与TreeMap源码分析

1. 引言     在红黑树--算法导论(15)中学习了红黑树的原理.本来打算自己来试着实现一下,然而在看了JDK(1.8.0)TreeMap的源码后恍然发现原来它就是利用红黑树实现的(很惭愧学了Java这么久,也写过一些小项目,也使用过TreeMap无数次,但到现在才明白它的实现原理).因此本着"不要重复造轮子"的思想,就用这篇博客来记录分析TreeMap源码的过程,也顺便瞅一瞅HashMap. 2. 继承结构 (1) 继承结构 下面是HashMap与TreeMap的继承结构: pu

Linux内核源码分析--内核启动之(5)Image内核启动(rest_init函数)(Linux-3.0 ARMv7)【转】

原文地址:Linux内核源码分析--内核启动之(5)Image内核启动(rest_init函数)(Linux-3.0 ARMv7) 作者:tekkamanninja 转自:http://blog.chinaunix.net/uid-25909619-id-4938395.html 前面粗略分析start_kernel函数,此函数中基本上是对内存管理和各子系统的数据结构初始化.在内核初始化函数start_kernel执行到最后,就是调用rest_init函数,这个函数的主要使命就是创建并启动内核线

Spark的Master和Worker集群启动的源码分析

基于spark1.3.1的源码进行分析 spark master启动源码分析 1.在start-master.sh调用master的main方法,main方法调用 def main(argStrings: Array[String]) { SignalLogger.register(log) val conf = new SparkConf val args = new MasterArguments(argStrings, conf) val (actorSystem, _, _, _) =

Solr4.8.0源码分析(22)之 SolrCloud的Recovery策略(三)

Solr4.8.0源码分析(22)之 SolrCloud的Recovery策略(三) 本文是SolrCloud的Recovery策略系列的第三篇文章,前面两篇主要介绍了Recovery的总体流程,以及PeerSync策略.本文以及后续的文章将重点介绍Replication策略.Replication策略不但可以在SolrCloud中起到leader到replica的数据同步,也可以在用多个单独的Solr来实现主从同步.本文先介绍在SolrCloud的leader到replica的数据同步,下一篇