MatConvNet中mnist源码解析

  本文的代码来自MatConvNet

  下面是自己对代码的注释:  

cnn_mnist_init.m

function net = cnn_mnist_init(varargin)
% CNN_MNIST_LENET Initialize a CNN similar for MNIST
opts.useBatchNorm = true ;   #batchNorm是否使用
opts.networkType = ‘simplenn‘ ;  #网络结构使用lenet结构
opts = vl_argparse(opts, varargin) ;
rng(‘default‘);
rng(0) ;
f=1/100 ;
net.layers = {} ;
# 定义各层参数,type是网络的层属性,stride为步长,pad为填充
# method中max为最大池化
net.layers{end+1} = struct(‘type‘, ‘conv‘, ...
                           ‘weights‘, {{f*randn(5,5,1,20, ‘single‘), zeros(1, 20, ‘single‘)}}, ...
                           ‘stride‘, 1, ...
                           ‘pad‘, 0) ;
net.layers{end+1} = struct(‘type‘, ‘pool‘, ...
                           ‘method‘, ‘max‘, ...
                           ‘pool‘, [2 2], ...
                           ‘stride‘, 2, ...
                           ‘pad‘, 0) ;
net.layers{end+1} = struct(‘type‘, ‘conv‘, ...
                           ‘weights‘, {{f*randn(5,5,20,50, ‘single‘),zeros(1,50,‘single‘)}}, ...
                           ‘stride‘, 1, ...
                           ‘pad‘, 0) ;
net.layers{end+1} = struct(‘type‘, ‘pool‘, ...
                           ‘method‘, ‘max‘, ...
                           ‘pool‘, [2 2], ...
                           ‘stride‘, 2, ...
                           ‘pad‘, 0) ;
net.layers{end+1} = struct(‘type‘, ‘conv‘, ...
                           ‘weights‘, {{f*randn(4,4,50,500, ‘single‘),  zeros(1,500,‘single‘)}}, ...
                           ‘stride‘, 1, ...
                           ‘pad‘, 0) ;
net.layers{end+1} = struct(‘type‘, ‘relu‘) ;
net.layers{end+1} = struct(‘type‘, ‘conv‘, ...
                           ‘weights‘, {{f*randn(1,1,500,10, ‘single‘), zeros(1,10,‘single‘)}}, ...
                           ‘stride‘, 1, ...
                           ‘pad‘, 0) ;
net.layers{end+1} = struct(‘type‘, ‘softmaxloss‘) ;

# optionally switch to batch normalization
# BN层一般用在卷积到池化过程中,激活函数前面,这里是在第1,4,7层插入BN
if opts.useBatchNorm
  net = insertBnorm(net, 1) ;
  net = insertBnorm(net, 4) ;
  net = insertBnorm(net, 7) ;
end

# Meta parameters
net.meta.inputSize = [27 27 1] ;  #输入大小[w,h,channel],这里是灰度图片,单通道为1
net.meta.trainOpts.learningRate = 0.001 ; #学习率
net.meta.trainOpts.numEpochs = 20 ; #epoch次数,注意这里不是所谓的迭代次数
net.meta.trainOpts.batchSize = 100 ; #批处理,这里就是mini-batchsize,batchSize大小对训练过程的影响见我另外一篇博客:卷积神经网络四大问题之一

# Fill in defaul values
net = vl_simplenn_tidy(net) ;

# Switch to DagNN if requested
# 选择不同的网络结构,这里就使用的simplenn结构
switch lower(opts.networkType)
  case ‘simplenn‘
    % done
  case ‘dagnn‘
    net = dagnn.DagNN.fromSimpleNN(net, ‘canonicalNames‘, true) ;
    net.addLayer(‘error‘, dagnn.Loss(‘loss‘, ‘classerror‘), ...
             {‘prediction‘,‘label‘}, ‘error‘) ;
  otherwise
    assert(false) ;
end

% --------------------------------------------------------------------
function net = insertBnorm(net, l)   #具体的BN函数
% --------------------------------------------------------------------
assert(isfield(net.layers{l}, ‘weights‘));
ndim = size(net.layers{l}.weights{1}, 4);
layer = struct(‘type‘, ‘bnorm‘, ...
               ‘weights‘, {{ones(ndim, 1, ‘single‘), zeros(ndim, 1, ‘single‘)}}, ...
               ‘learningRate‘, [1 1 0.05], ...
               ‘weightDecay‘, [0 0]) ;
net.layers{l}.biases = [] ;
net.layers = horzcat(net.layers(1:l), layer, net.layers(l+1:end)) ; #horzcat水平方向矩阵连接,这里就是重新构建网络结构,将BN层插入到lennt中

cnn_mnist_experiments.m

%% Experiment with the cnn_mnist_fc_bnorm
[net_bn, info_bn] = cnn_mnist(...
  ‘expDir‘, ‘data/mnist-bnorm‘, ‘useBnorm‘, true);

[net_fc, info_fc] = cnn_mnist(...
  ‘expDir‘, ‘data/mnist-baseline‘, ‘useBnorm‘, false);
# 以下就是画图的代码
figure(1) ; clf ;
subplot(1,2,1) ;  # 第一张图
semilogy(info_fc.val.objective‘, ‘o-‘) ; hold all ;
semilogy(info_bn.val.objective‘, ‘+--‘) ;  #表示y坐标轴是对数坐标系
xlabel(‘Training samples [x 10^3]‘); ylabel(‘energy‘) ;
grid on ; #加入网格
h=legend(‘BSLN‘, ‘BNORM‘) ;  #加入标注
set(h,‘color‘,‘none‘);
title(‘objective‘) ;
subplot(1,2,2) ;
plot(info_fc.val.error‘, ‘o-‘) ; hold all ;
plot(info_bn.val.error‘, ‘+--‘) ;
h=legend(‘BSLN-val‘,‘BSLN-val-5‘,‘BNORM-val‘,‘BNORM-val-5‘) ;
grid on ;
xlabel(‘Training samples [x 10^3]‘); ylabel(‘error‘) ;
set(h,‘color‘,‘none‘) ;
title(‘error‘) ;
drawnow ;

  运行结果得到的图:

    

时间: 2024-11-29 11:32:26

MatConvNet中mnist源码解析的相关文章

DeepLearnToolBox中CNN源码解析

DeepLearnToolbox是一个简单理解CNN过程的工具箱,可以在github下载.为了理解卷积神经网络的过程,我特此对CNN部分源码进行了注释.公式的计算可以由上一篇blog推导得出. 注意:代码中没有的subsampling进行设置参数,将subsampling层的参数w就设置为了0.25,而偏置参数b设置为0.卷积层计算过程为上一层所有feature map的卷积的结果和,后再加一个偏置,再取sigmoid函数.而subsampling的计算过程为上一层对应的2*2的feature

Mybatis 中sqlsession源码解析

一.sqlsession获取过程 1.基础配置 在mybatis框架下进行的数据库操作都需要首先获取sqlsession,在mybatis与spring集成后获取sqlsession需要用到sqlsessionTemplate这个类. 首先在spring对sqlsessionTemplate进行配置,使用到的是 org.mybatis.spring.SqlSessionTemplate 这个类. <!-- SqlSession实例 --> <bean id="sessionTe

Android 热修复Nuwa的原理及Gradle插件源码解析

现在,热修复的具体实现方案开源的也有很多,原理也大同小异,本篇文章以Nuwa为例,深入剖析. Nuwa的github地址 https://github.com/jasonross/Nuwa 以及用于hotpatch生成的gradle插件地址 https://github.com/jasonross/NuwaGradle 而Nuwa的具体实现是根据QQ空间的热修复方案来实现的.安卓App热补丁动态修复技术介绍.在阅读本篇文章之前,请先阅读该文章. 从QQ空间终端开发团队的文章中可以总结出要进行热更

Spring 源码解析之DispatcherServlet源码解析(五)

Spring 源码解析之DispatcherServlet源码解析(五) 前言 本文需要有前四篇文章的基础,才能够清晰易懂,有兴趣可以先看看详细的流程,这篇文章可以说是第一篇文章,也可以说是前四篇文章的的汇总,Spring的整个请求流程都是围绕着DispatcherServlet进行的 类结构图 根据类的结构来说DispatcherServlet本身也是继承了HttpServlet的,所有的请求都是根据这一个Servlet来进行转发的,同时解释了为什么需要在web.xml进行如下配置,因为Spr

eclipse中导入jdk源码、SpringMVC注解@RequestParam、SpringMVC文件上传源码解析、ajax上传excel文件

eclipse中导入jdk源码:http://blog.csdn.net/evolly/article/details/18403321, http://www.codingwhy.com/view/799.html. ------------------------------- SpringMVC注解@RequestParam:http://825635381.iteye.com/blog/2196911. --------------------------- SpringMVC文件上传源

Scala 深入浅出实战经典 第48讲:Scala类型约束代码实战及其在Spark中的应用源码解析

王家林亲授<DT大数据梦工厂>大数据实战视频 Scala 深入浅出实战经典(1-64讲)完整视频.PPT.代码下载:百度云盘:http://pan.baidu.com/s/1c0noOt6 腾讯微云:http://url.cn/TnGbdC 360云盘:http://yunpan.cn/cQ4c2UALDjSKy 访问密码 45e2 技术爱好者尤其是大数据爱好者 可以加DT大数据梦工厂的qq群 DT大数据梦工厂① :462923555 DT大数据梦工厂②:437123764 DT大数据梦工厂③

68:Scala并发编程原生线程Actor、Cass Class下的消息传递和偏函数实战解析及其在Spark中的应用源码解析

今天给大家带来的是王家林老师的scala编程讲座的第68讲:Scala并发编程原生线程Actor.Cass Class下的消息传递和偏函数实战解析 昨天讲了Actor的匿名Actor及消息传递,那么我们今天来看一下原生线程Actor及CassClass下的消息传递,让我们从代码出发: case class Person(name:String,age:Int)//定义cass Class class HelloActor extends Actor{//预定义一个Actor  def act()

安卓中的事件分发机制源码解析

安卓中的事件分发机制主要涉及到两类控件,一类是容器类控件ViewGroup,如常用的布局控件,另一类是显示类控件,即该控件中不能用来容纳其它控件,它只能用来显示一些资源内容,如Button,ImageView等控件.暂且称前一类控件为ViewGroup类控件(尽管ViewGroup本身也是一个View),后者为View类控件. 安卓中的事件分发机制主要涉及到dispatchTouchEvent(MotionEvent ev).onInterceptTouchEvent(MotionEvent e

Scala 深入浅出实战经典 第60讲:Scala中隐式参数实战详解以及在Spark中的应用源码解析

王家林亲授<DT大数据梦工厂>大数据实战视频 Scala 深入浅出实战经典(1-87讲)完整视频.PPT.代码下载:百度云盘:http://pan.baidu.com/s/1c0noOt6 腾讯微云:http://url.cn/TnGbdC 360云盘:http://yunpan.cn/cQ4c2UALDjSKy 访问密码 45e2土豆:http://www.tudou.com/programs/view/IVN4EuFlmKk/优酷:http://v.youku.com/v_show/id_