caffe中HingeLossLayer层原理以及源码分析

输入:

bottom[0]: NxKx1x1维,N为样本个数,K为类别数。是预测值。

bottom[1]: Nx1x1x1维, N为样本个数,类别为K时,每个元素的取值范围为[0,1,2,…,K-1]。是groundTruth。

输出:

top[0]: 1x1x1x1维, 求得是hingeLoss。

关于HingeLoss:

p: 范数,默认是L1范数,可以在配置中设置为L1或者L2范数。

:指示函数,如果第n个样本的真实label为k,则为,否则为-1。

tnk: bottom[0]中第n个样本,第k维的预测值。

前向传播代码分析:

template
void HingeLossLayer::Forward_cpu(const vector*>& bottom,
    const vector*>& top) {
  const Dtype* bottom_data = bottom[0]->cpu_data();   //得到num个样本的dim个预测值
  Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
  const Dtype* label = bottom[1]->cpu_data();                //得到num个样本的groundTruth
  int num = bottom[0]->num();
  int count = bottom[0]->count();
  int dim = count / num;
  caffe_copy(count, bottom_data, bottom_diff);
  for (int i = 0; i < num; ++i) {
    //label[i]中存储了第i个样本的真实class,取值范围[0,1,2,...,K-1]
    //此处将第i个样本的K维预测值的label[i]处乘以-1相当于计算
   //caffe中HingeLossLayer层原理以及源码分析
    bottom_diff[i * dim + static_cast(label[i])] *= -1;
  }
  for (int i = 0; i < num; ++i) {
    for (int j = 0; j < dim; ++j) {
      //计算 caffe中HingeLossLayer层原理以及源码分析,存入 bottom_diff,即bottom[0]->mutable_cpu_diff()中
      bottom_diff[i * dim + j] = std::max( Dtype(0), 1 + bottom_diff[i * dim + j]);
    }
  }
  Dtype* loss = top[0]->mutable_cpu_data();
  switch (this->layer_param_.hinge_loss_param().norm()) {
  case HingeLossParameter_Norm_L1:  //L1范数
    loss[0] = caffe_cpu_asum(count, bottom_diff) / num;
    break;
  case HingeLossParameter_Norm_L2: //L2范数
    loss[0] = caffe_cpu_dot(count, bottom_diff, bottom_diff) / num;
    break;
  default:
    LOG(FATAL) << "Unknown Norm";
  }
}

反向传播原理:

由于bottom[1]是groundtruth,不需要反传,只需要对bottom[0]进行反传,反传是损失E对t的偏导。

以L2范数为例,求偏导为:

其中:

反向传播源码分析:

template
void HingeLossLayer::Backward_cpu(const vector*>& top,
    const vector& propagate_down, const vector*>& bottom) {
  if (propagate_down[1]) {
    LOG(FATAL) << this->type()
               << " Layer cannot backpropagate to label inputs.";
  }
  if (propagate_down[0]) {
    Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); //说明中提到的hinge
    const Dtype* label = bottom[1]->cpu_data();
    int num = bottom[0]->num();
    int count = bottom[0]->count();
    int dim = count / num;
    for (int i = 0; i < num; ++i) {
      //相当于求hinge*偏hinge/偏tnk部分
      bottom_diff[i * dim + static_cast(label[i])] *= -1;
    }
    const Dtype loss_weight = top[0]->cpu_diff()[0];
    switch (this->layer_param_.hinge_loss_param().norm()) {
    case HingeLossParameter_Norm_L1:  //L1部分反传
      caffe_cpu_sign(count, bottom_diff, bottom_diff);  //L1求导的结果: 正返回1 负返回-1 0返回0
      caffe_scal(count, loss_weight / num, bottom_diff); //scale一下
      break;
    case HingeLossParameter_Norm_L2: //L2部分反传,就是scale一下
      caffe_scal(count, loss_weight * 2 / num, bottom_diff);
      break;
    default:
      LOG(FATAL) << "Unknown Norm";
    }
  }
} 

版权声明:本文为博主原创文章,未经博主允许不得转载。

时间: 2024-08-02 06:57:58

caffe中HingeLossLayer层原理以及源码分析的相关文章

【Spring】Spring&amp;WEB整合原理及源码分析

表现层和业务层整合: 1. Jsp/Servlet整合Spring: 2. Spring MVC整合SPring: 3. Struts2整合Spring: 本文主要介绍Jsp/Servlet整合Spring原理及源码分析. 一.整合过程 Spring&WEB整合,主要介绍的是Jsp/Servlet容器和Spring整合的过程,当然,这个过程是Spring MVC或Strugs2整合Spring的基础. Spring和Jsp/Servlet整合操作很简单,使用也很简单,按部就班花不到2分钟就搞定了

【OpenCV】SIFT原理与源码分析:关键点搜索与定位

<SIFT原理与源码分析>系列文章索引:http://www.cnblogs.com/tianyalu/p/5467813.html 由前一步<DoG尺度空间构造>,我们得到了DoG高斯差分金字塔: 如上图的金字塔,高斯尺度空间金字塔中每组有五层不同尺度图像,相邻两层相减得到四层DoG结果.关键点搜索就在这四层DoG图像上寻找局部极值点. DoG局部极值点 寻找DoG极值点时,每一个像素点和它所有的相邻点比较,当其大于(或小于)它的图像域和尺度域的所有相邻点时,即为极值点.如下图所

【OpenCV】SIFT原理与源码分析:DoG尺度空间构造

<SIFT原理与源码分析>系列文章索引:http://www.cnblogs.com/tianyalu/p/5467813.html 尺度空间理论 自然界中的物体随着观测尺度不同有不同的表现形态.例如我们形容建筑物用“米”,观测分子.原子等用“纳米”.更形象的例子比如Google地图,滑动鼠标轮可以改变观测地图的尺度,看到的地图绘制也不同:还有电影中的拉伸镜头等等…… 尺度空间中各尺度图像的模糊程度逐渐变大,能够模拟人在距离目标由近到远时目标在视网膜上的形成过程.尺度越大图像越模糊. 为什么要

ConcurrentHashMap实现原理及源码分析

ConcurrentHashMap实现原理 ConcurrentHashMap源码分析 总结 ConcurrentHashMap是Java并发包中提供的一个线程安全且高效的HashMap实现(若对HashMap的实现原理还不甚了解,可参考我的另一篇文章HashMap实现原理及源码分析),ConcurrentHashMap在并发编程的场景中使用频率非常之高,本文就来分析下ConcurrentHashMap的实现原理,并对其实现原理进行分析(JDK1.7). ConcurrentHashMap实现原

OpenCV学习笔记(27)KAZE 算法原理与源码分析(一)非线性扩散滤波

http://blog.csdn.net/chenyusiyuan/article/details/8710462 OpenCV学习笔记(27)KAZE 算法原理与源码分析(一)非线性扩散滤波 2013-03-23 17:44 16963人阅读 评论(28) 收藏 举报 分类: 机器视觉(34) 版权声明:本文为博主原创文章,未经博主允许不得转载. 目录(?)[+] KAZE系列笔记: OpenCV学习笔记(27)KAZE 算法原理与源码分析(一)非线性扩散滤波 OpenCV学习笔记(28)KA

Cocos2d-X3.0 刨根问底(八)----- 场景(Scene)、层(Layer)相关源码分析

本章节我们重点分析Cocos2d-x3.0与 场景.层相关的源码.这部分源码集中在 libcocos2d –> layers_scenes_transitions_nodes目录下面 我先发个截图大家了解一下都有哪些文件.红色框里面的就是我们今天要分析的文件. 从命名上可以了解,这个文件夹里的文件主要包含了  场景,层,变换这三种类型的文件. 下面我们先分析Scene类 打开CCScene.h文件 /** @brief Scene is a subclass of Node that is us

【Spring】Spring&amp;WEB整合原理及源码分析(二)

一.整合过程 Spring&WEB整合,主要介绍的是Jsp/Servlet容器和Spring整合的过程,当然,这个过程是Spring MVC或Strugs2整合Spring的基础. Spring和Jsp/Servlet整合操作很简单,使用也很简单,按部就班花不到2分钟就搞定了,本节只讲操作不讲原理,更多细节.原理及源码分析后续过程陆续涉及. 1. 导入必须的jar包,本例spring-web-x.x.x.RELEASE.jar: 2. 配置web.xml,本例示例如下: <?xml vers

深度理解Android InstantRun原理以及源码分析

深度理解Android InstantRun原理以及源码分析 @Author 莫川 Instant Run官方介绍 简单介绍一下Instant Run,它是Android Studio2.0以后新增的一个运行机制,能够显著减少你第二次及以后的构建和部署时间.简单通俗的解释就是,当你在Android Studio中改了你的代码,Instant Run可以很快的让你看到你修改的效果.而在没有Instant Run之前,你的一个小小的修改,都肯能需要几十秒甚至更长的等待才能看到修改后的效果. 传统的代

【OpenCV】SIFT原理与源码分析:方向赋值

<SIFT原理与源码分析>系列文章索引:http://www.cnblogs.com/tianyalu/p/5467813.html 由前一篇<关键点搜索与定位>,我们已经找到了关键点.为了实现图像旋转不变性,需要根据检测到的关键点局部图像结构为特征点方向赋值.也就是在findScaleSpaceExtrema()函数里看到的alcOrientationHist()语句: // 计算梯度直方图 float omax = calcOrientationHist(gauss_pyr[o