【转】Caffe源码阅读(1) 全连接层

原作者:在路上 原文链接:http://zhangliliang.com/2014/09/15/about-caffe-code-full-connected-layer/

今天看全连接层的实现。
主要看的是https://github.com/BVLC/caffe/blob/master/src/caffe/layers/inner_product_layer.cpp

主要是三个方法,setup,forward,backward

  • setup 初始化网络参数,包括了w和b
  • forward 前向传播的实现
  • backward 后向传播的实现

setup

主体的思路,作者的注释给的很清晰。
主要是要弄清楚一些变量对应的含义

123
M_ 表示的样本数K_ 表示单个样本的特征长度N_ 表示输出神经元的个数

为了打字方便,以下省略下划线,缩写为M,K,N

forward

实现的功能就是 y=wx+b

1234
x为输入,维度 MxKy为输出,维度 Nx1w为权重,维度 NxKb为偏置,维度 Nx1

具体到代码实现,用的是这个函数caffe_cpu_gemm,具体的函数头为

1234
void caffe_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA,    const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,    const float alpha, const float* A, const float* B, const float beta,    float* C)

略长,整理它的功能其实很直观,即C←αA×B+βC

1234567891011121314
const CBLAS_TRANSPOSE TransA  # A是否转置const CBLAS_TRANSPOSE TransB  # B是否转置

# 这部分都比较直观不用解释了const int M          const int Nconst int Kconst float alphaconst float* Aconst float* Bconst float beta,float* C

# 其中A维度是MxK,B维度是KxN,C维度为MxN

从实际代码来算,全连接层的forward包括了两步:

12345678
# 这一步表示 y←wx,或者说是y←xw‘caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, N_, K_, (Dtype)1.,      bottom_data, weight, (Dtype)0., top_data);# 这一步表示 y←y+bcaffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,        bias_multiplier_.cpu_data(),        this->blobs_[1]->cpu_data(), (Dtype)1., top_data);# 所以两步连起来就等价于y=wx+b

backward

分成三步:

  • 更新w
  • 更新b
  • 计算delta

用公式来说是下面三条:

一步步来,先来第一步,更新w,对应代码是:

12
caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, N_, K_, M_, (Dtype)1.,        top_diff, bottom_data, (Dtype)0., this->blobs_[0]->mutable_cpu_diff());

对照公式,有

123
需要更新的w的梯度的维度是NxK公式中的a^(l)_j对应的是bottom_data,维度是KxM公式中的\delta_(l+1)_i对应的是top_diff,维度是NxM

然后是第二步,更新b,对应代码是:

123
caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,        bias_multiplier_.cpu_data(), (Dtype)0.,        this->blobs_[1]->mutable_cpu_diff());

这里用到了caffe_cpu_gemv,简单来说跟上面的caffe_cpu_gemm类似,不过前者是计算矩阵和向量之间的乘法的(从英文命名可以分辨,v for vector, m for matrix)。函数头:

12345678910111213141516171819
void caffe_cpu_gemv<float>(const CBLAS_TRANSPOSE TransA, const int M,    const int N, const float alpha, const float* A, const float* x,    const float beta, float* y) 

# 实现的功能类似 Y←αAX + βY# 其中A的维度为 MxN# X是一个向量,维度为 Nx1# Y是结果 ,也是一个向量,维度为Mx1

const CBLAS_TRANSPOSE TransA  # 是否对A进行转置

# 下面的参数很直观,不描述了const int Mconst int Nconst float alphaconst float* Aconst float* xconst float betafloat* y

绕回到具体的代码实现。。如何更新b?根据公式b的梯度直接就是delta

1234
# 所以对应的代码其实就是将top_diff转置后就可以了(忽略乘上bias_multiplier这步)caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,        bias_multiplier_.cpu_data(), (Dtype)0.,        this->blobs_[1]->mutable_cpu_diff());

第三步是计算delta,对应公式

这里面可以忽略掉最后一项f’,因为在caffe实现中,这是由Relu layer来实现的,这里只需要实现括号里面的累加就好了,这个累加其实可以等价于矩阵乘法

1234567
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, K_, N_, (Dtype)1.,        top_diff, this->blobs_[0]->cpu_data(), (Dtype)0.,        (*bottom)[0]->mutable_cpu_diff());

# top_diff为\delta^(l+1)_j 维度 MxN# this->blobs_[0]->cpu_data()为W^(l)_ji 维度 NxK# (*bottom)[0]->mutable_cpu_diff()是要计算的结果,也就是\delta^(l)_i 维度是MxK
时间: 2024-10-07 22:54:31

【转】Caffe源码阅读(1) 全连接层的相关文章

caffe源码阅读4-layer.hpp

An interface for the units of computation which can be composed into a Net. Layer&s must implement a Forward function, in which they take their input (bottom) Blob&s (if any) and compute their output Blob&s (if any). They may also implement a

Caffe 源码阅读(二) 卷积层

背景: 项目中需要在 caffe 中增加 binary convolution layer, 所以在单步调试了 minist 的训练,大致看了一下流程,就详细看 convolution layer 了. 1.数据结构 caffe 的基本数据结构是 Blob,也就是数据流的基本结构. 2.网络结构 Net 是 Layer 构造出来的,Layer 包括了数据和运算(Blob input, Blob output, operation). 3.卷积 3.1 卷积意义 卷积实际上是一种线性运算,只不过用

读caffe源码(未完待续)

caffe源码阅读杂记 准备 一些参考网页 Neural Networks and Deep Learning TUTORIAL ON DEEP LEARNING FOR VISION Deep Learning Tutorial 知乎-深度学习caffe的代码怎么读 Caffe源码解析 caffe源码结构 官方代码结构doxygen 官方Caffe Tutorial 以C++源码形式配置debug&CPU版的caffe,便于阅读源码与单步调试[参考] 参考官方的文档,先了解某个模块的作用 为了

源码阅读经验谈-slim,darknet,labelimg,caffe(1)

本文首先谈自己的源码阅读体验,然后给几个案例解读,选的例子都是比较简单.重在说明我琢磨的点线面源码阅读方法.我不是专业架构师,是从一个深度学习算法工程师的角度来谈的,不专业的地方请大家轻拍. 经常看别人写的代码,然后改别人的代码,然后实现自己的想法,我想这是我们coder常干的事情.看人看代码,代码如人.他代码写的有多清爽简洁,说明他思维是清晰的:代码的结构有多合理,模块化内聚如何,是否低耦合,反应他的宏观把控能力.一个软件系统你可以把他看成是一个简单的企业,各个职能部门如何发挥自己的作用,相当

淘宝数据库OceanBase SQL编译器部分 源码阅读--生成物理查询计划

SQL编译解析三部曲分为:构建语法树,制定逻辑计划,生成物理执行计划.前两个步骤请参见我的博客<<淘宝数据库OceanBase SQL编译器部分 源码阅读--解析SQL语法树>>和<<淘宝数据库OceanBase SQL编译器部分 源码阅读--生成逻辑计划>>.这篇博客主要研究第三步,生成物理查询计划. 一. 什么是物理查询计划 与之前的阅读方法一致,这篇博客的两个主要问题是what 和how.那么什么是物理查询计划?物理查询计划能够直接执行并返回数据结果数

Netty源码阅读(一) ServerBootstrap启动

Netty源码阅读(一) ServerBootstrap启动 转自我的Github Netty是由JBOSS提供的一个java开源框架.Netty提供异步的.事件驱动的网络应用程序框架和工具,用以快速开发高性能.高可靠性的网络服务器和客户端程序.本文讲会对Netty服务启动的过程进行分析,主要关注启动的调用过程,从这里面进一步理解Netty的线程模型,以及Reactor模式. 这是我画的一个Netty启动过程中使用到的主要的类的概要类图,当然是用到的类比这个多得多,而且我也忽略了各个类的继承关系

【原】FMDB源码阅读(二)

[原]FMDB源码阅读(二) 本文转载请注明出处 -- polobymulberry-博客园 1. 前言 上一篇只是简单地过了一下FMDB一个简单例子的基本流程,并没有涉及到FMDB的所有方方面面,比如FMDB的executeUpdate:系列方法.数据库的加解密等等.这次写的就是对FMDatabase和FMResultSet这两个文件的补全内容.每次写这种补全的内容最头疼,内容会很分散,感觉没啥条理. 2. executeUpdate:系列函数 注意除了"SELECT"语句外,其他的

Spark2.1内部原理剖析与源码阅读、程序设计与企业级应用案例视频教程

38套大数据,云计算,架构,数据分析师,Hadoop,Spark,Storm,Kafka,人工智能,机器学习,深度学习,项目实战视频教程 视频课程包含: 38套大数据和人工智能精品高级课包含:大数据,云计算,架构,数据挖掘实战,实时推荐系统实战,电视收视率项目实战,实时流统计项目实战,离线电商分析项目实战,Spark大型项目实战用户分析,智能客户系统项目实战,Linux基础,Hadoop,Spark,Storm,Docker,Mapreduce,Kafka,Flume,OpenStack,Hiv

Redis源码阅读(一)事件机制

Redis源码阅读(一)事件机制 Redis作为一款NoSQL非关系内存数据库,具有很高的读写性能,且原生支持的数据类型丰富,被广泛的作为缓存.分布式数据库.消息队列等应用.此外Redis还有许多高可用特性,包括数据持久化,主从模式备份等等,可以满足对数据完整有一定要求的场景. 而且Redis的源码结构简单清晰,有大量材料可以参阅:通过阅读Redis源码,掌握一些常用技术在Redis中的实现,相信会对个人编程水平有很大帮助.这里记录下我阅读Redis源码的心得.从我自己比较关心的几个技术点出发,