Understanding Focal Loss

Paper: "Focal Loss for Dense Object Detection"

Focal Loss is the loss function proposed by Kaiming He and his colleagues to deal with the extreme foreground/background class imbalance (e.g., 1:1000) that one-stage object detectors face during training. It is a modification of the binary cross-entropy loss.

Standard cross-entropy

The standard binary cross-entropy is

$$\mathrm{CE}(p, y) = \begin{cases} -\log(p) & \text{if } y = 1 \\ -\log(1 - p) & \text{otherwise,} \end{cases}$$

where $p$ is the model's predicted probability for the class $y = 1$. For notational convenience, define

$$p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{otherwise,} \end{cases}$$

so the cross-entropy CE can be rewritten as

$$\mathrm{CE}(p, y) = \mathrm{CE}(p_t) = -\log(p_t).$$
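
As a quick numeric check of the definitions above (my own sketch, not part of the original post), the following PyTorch snippet evaluates $-\log(p_t)$ for one positive and one negative sample and compares it with PyTorch's built-in binary cross-entropy:

import torch
import torch.nn.functional as F

p = torch.tensor([0.9, 0.3])   # predicted probability of the class y = 1
y = torch.tensor([1.0, 0.0])   # ground-truth labels

p_t = torch.where(y == 1, p, 1 - p)   # p_t as defined above
ce = -torch.log(p_t)                  # CE(p, y) = -log(p_t)

print(ce)                                              # tensor([0.1054, 0.3567])
print(F.binary_cross_entropy(p, y, reduction='none'))  # same values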

α-balanced cross-entropy

One way to deal with class imbalance is to introduce a weighting factor $\alpha \in [0, 1]$: use $\alpha$ when $y = 1$ and $1 - \alpha$ when $y = 0$,

$$\mathrm{CE}(p, y) = \begin{cases} -\alpha \log(p) & \text{if } y = 1 \\ -(1 - \alpha) \log(1 - p) & \text{otherwise.} \end{cases}$$

With this weighting, the background class ($y = 0$) is scaled by $1 - \alpha$, so increasing $\alpha$ down-weights the background loss and reduces the influence of the overwhelming number of background samples on training.

Analogously to $p_t$, define $\alpha_t = \alpha$ when $y = 1$ and $1 - \alpha$ otherwise; the α-balanced CE can then be rewritten as

$$\mathrm{CE}(p_t) = -\alpha_t \log(p_t).$$
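
A minimal sketch of the α-balanced CE under the same toy setup (again my own illustration): positives are weighted by $\alpha$ and negatives by $1 - \alpha$.

import torch

alpha = 0.25
p = torch.tensor([0.9, 0.3])
y = torch.tensor([1.0, 0.0])

p_t = torch.where(y == 1, p, 1 - p)
alpha_t = torch.where(y == 1, torch.full_like(p, alpha), torch.full_like(p, 1 - alpha))

alpha_ce = -alpha_t * torch.log(p_t)   # -alpha_t * log(p_t)
print(alpha_ce)                        # the negative sample's loss is scaled by 1 - alpha = 0.75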

Definition of Focal Loss

Although α-CE balances the contributions of positive and negative samples to the loss value, it cannot distinguish hard samples from easy ones. Focal Loss was introduced for exactly this purpose; it is defined as

$$\mathrm{FL}(y') = -\alpha\, y\, (1 - y')^{\gamma} \log(y') - (1 - \alpha)(1 - y)\, y'^{\gamma} \log(1 - y'),$$

where $\alpha$ and $\gamma$ are constant hyperparameters and $y'$ is the model's prediction, with values in $(0, 1)$.

When $y = 1$ and $y' \to 1$ (an easy positive), the modulating factor $(1 - y')^{\gamma} \to 0$, so its contribution to the loss vanishes;

when $y = 0$ and $y' \to 0$ (an easy negative), the modulating factor $y'^{\gamma} \to 0$, and its contribution likewise vanishes.

Focal Loss therefore not only down-weights the background class but also down-weights easy positives and easy negatives.

$\gamma$ controls how strongly the loss is modulated: when $\gamma = 0$, Focal Loss reduces to α-CE, and larger $\gamma$ suppresses easy examples more aggressively (the paper plots the loss curve for several values of $\gamma$).
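
Before the full RetinaNet version below, here is a minimal standalone sketch of the binary Focal Loss formula above (my own illustration; binary_focal_loss and its defaults are not from the repo cited below). Setting gamma=0 recovers the α-balanced CE.

import torch

def binary_focal_loss(y_pred, y_true, alpha=0.25, gamma=2.0, eps=1e-7):
    # y_pred: predicted probabilities y' in (0, 1); y_true: labels y in {0, 1}
    y_pred = torch.clamp(y_pred, eps, 1.0 - eps)
    pos_term = -alpha * (1 - y_pred) ** gamma * y_true * torch.log(y_pred)
    neg_term = -(1 - alpha) * y_pred ** gamma * (1 - y_true) * torch.log(1 - y_pred)
    return (pos_term + neg_term).mean()

y_pred = torch.tensor([0.95, 0.05, 0.60])  # easy positive, easy negative, harder positive
y_true = torch.tensor([1.0, 0.0, 1.0])
print(binary_focal_loss(y_pred, y_true))             # easy examples contribute very little
print(binary_focal_loss(y_pred, y_true, gamma=0.0))  # reduces to the alpha-balanced CE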

PyTorch implementation of Focal Loss

Note that the FocalLoss module below returns two terms: the focal classification loss and the box regression (smooth L1) loss.

Code from: https://github.com/yhenon/pytorch-retinanet

import numpy as np
import torch
import torch.nn as nn

def calc_iou(a, b):
    # a: anchors (N, 4), b: ground-truth boxes (M, 4), both as (x1, y1, x2, y2)
    area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])

    iw = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 0])
    ih = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 1])

    iw = torch.clamp(iw, min=0)
    ih = torch.clamp(ih, min=0)

    ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih

    ua = torch.clamp(ua, min=1e-8)

    intersection = iw * ih

    IoU = intersection / ua  # pairwise IoU matrix, shape (N, M)

    return IoU

class FocalLoss(nn.Module):

    def forward(self, classifications, regressions, anchors, annotations):
        alpha = 0.25
        gamma = 2.0
        batch_size = classifications.shape[0]
        classification_losses = []
        regression_losses = []

        anchor = anchors[0, :, :]

        anchor_widths  = anchor[:, 2] - anchor[:, 0]
        anchor_heights = anchor[:, 3] - anchor[:, 1]
        anchor_ctr_x   = anchor[:, 0] + 0.5 * anchor_widths
        anchor_ctr_y   = anchor[:, 1] + 0.5 * anchor_heights

        for j in range(batch_size):

            classification = classifications[j, :, :]
            regression = regressions[j, :, :]

            bbox_annotation = annotations[j, :, :]
            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]  # drop padding rows

            if bbox_annotation.shape[0] == 0:
                # no ground-truth boxes in this image
                regression_losses.append(torch.tensor(0).float().cuda())
                classification_losses.append(torch.tensor(0).float().cuda())

                continue

            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)

            IoU = calc_iou(anchors[0, :, :], bbox_annotation[:, :4])  # num_anchors x num_annotations

            IoU_max, IoU_argmax = torch.max(IoU, dim=1)  # best-matching annotation for each anchor

            # compute the loss for classification
            targets = torch.ones(classification.shape) * -1  # -1 marks anchors that are ignored
            targets = targets.cuda()

            targets[torch.lt(IoU_max, 0.4), :] = 0  # IoU < 0.4: background

            positive_indices = torch.ge(IoU_max, 0.5)  # IoU >= 0.5: foreground

            num_positive_anchors = positive_indices.sum()

            assigned_annotations = bbox_annotation[IoU_argmax, :]

            targets[positive_indices, :] = 0
            targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1  # one-hot class targets

            # focal weight: alpha_t * (1 - p_t)^gamma
            alpha_factor = torch.ones(targets.shape).cuda() * alpha

            alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)
            focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
            focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

            bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))

            cls_loss = focal_weight * bce

            # ignored anchors (0.4 <= IoU < 0.5) contribute nothing
            cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, torch.zeros(cls_loss.shape).cuda())

            classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.float(), min=1.0))

            # compute the loss for regression

            if positive_indices.sum() > 0:
                assigned_annotations = assigned_annotations[positive_indices, :]

                anchor_widths_pi = anchor_widths[positive_indices]
                anchor_heights_pi = anchor_heights[positive_indices]
                anchor_ctr_x_pi = anchor_ctr_x[positive_indices]
                anchor_ctr_y_pi = anchor_ctr_y[positive_indices]

                gt_widths  = assigned_annotations[:, 2] - assigned_annotations[:, 0]
                gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
                gt_ctr_x   = assigned_annotations[:, 0] + 0.5 * gt_widths
                gt_ctr_y   = assigned_annotations[:, 1] + 0.5 * gt_heights

                # clip widths to 1
                gt_widths  = torch.clamp(gt_widths, min=1)
                gt_heights = torch.clamp(gt_heights, min=1)

                # box regression targets relative to the positive anchors
                targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
                targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
                targets_dw = torch.log(gt_widths / anchor_widths_pi)
                targets_dh = torch.log(gt_heights / anchor_heights_pi)

                targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh))
                targets = targets.t()

                targets = targets / torch.Tensor([[0.1, 0.1, 0.2, 0.2]]).cuda()

                # smooth L1 loss on the positive anchors
                regression_diff = torch.abs(targets - regression[positive_indices, :])

                regression_loss = torch.where(
                    torch.le(regression_diff, 1.0 / 9.0),
                    0.5 * 9.0 * torch.pow(regression_diff, 2),
                    regression_diff - 0.5 / 9.0
                )
                regression_losses.append(regression_loss.mean())
            else:
                regression_losses.append(torch.tensor(0).float().cuda())

        return torch.stack(classification_losses).mean(dim=0, keepdim=True), torch.stack(regression_losses).mean(dim=0, keepdim=True)
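
As a rough usage sketch continuing from the code above (the shapes and values are my own assumptions about the repo's RetinaNet conventions: annotation rows are (x1, y1, x2, y2, class_id), padded with -1 for images with fewer boxes; the hard-coded .cuda() calls mean a GPU is required):

# hypothetical batch: 2 images, 100 anchors, 80 classes
classifications = torch.sigmoid(torch.randn(2, 100, 80)).cuda()   # per-anchor class probabilities
regressions = torch.randn(2, 100, 4).cuda()                       # per-anchor box deltas
anchors = torch.rand(1, 100, 4).cuda() * 300                      # (x1, y1, x2, y2)
anchors[:, :, 2:] += anchors[:, :, :2]                            # ensure x2 > x1 and y2 > y1
annotations = torch.tensor([[[30., 40., 200., 220., 5.]],         # one box in the first image
                            [[-1., -1., -1., -1., -1.]]]).cuda()  # no boxes in the second

cls_loss, reg_loss = FocalLoss()(classifications, regressions, anchors, annotations)
print(cls_loss, reg_loss)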

Original post: https://www.cnblogs.com/houjun/p/10220485.html

