SSD源码解读——损失函数的构建

之前,对SSD的论文进行了解读,可以回顾之前的博客:https://www.cnblogs.com/dengshunge/p/11665929.html

为了加深对SSD的理解,因此对SSD的源码进行了复现,主要参考的github项目是ssd.pytorch

搭建SSD的项目,可以分成以下三个部分:

  1. 数据读取
  2. 网络搭建
  3. 损失函数的构建。

接下来,本篇博客重点分析损失函数的构建



检测任务的损失函数,与分类任务的损失函数具有很大不同。在检测的损失函数中,不仅需要计类别置信度的差异,坐标的差异,还需要使用到各种tricks,例如hard negative mining等。

在train.py中,首先需要对损失函数MultiBoxLoss()进行初始化,需要传入的参数为num_classes类别数,正例的IOU阈值和hard negative mining的正负样本比例。在论文中,VOC的类别总数是21(20个类别加上1个背景);当预测框与GT框的IOU大于0.5时,认为该预测框是正例;hard negative mining的正样本和负样本的比例是1:3。

    # 损失函数
    criterion = MultiBoxLoss(num_classes=voc[‘num_classes‘],
                             overlap_thresh=0.5,
                             neg_pos=3)

在models/multibox_loss中,定义了损失函数MultiBoxLoss()。在函数forward()中,需要传进来两个参数,分别是predictions和targets,其中,predictions是SSD网络得到的结果,分别是预测框坐标,类别置信度和先验锚点框;而targets是则是数据读取中的值,是GT框的坐标和类别label。首先,需要创建坐标和类别置信度的tensor,其shape分别是[batch_size,8732,4]和[batch_size,8732]。然后,使用一个for循环,将网络预测的结果和真实的坐标与label进行match,得到每个锚点框的label和坐标偏差,并将结果保存与loc_t和conf_t中。接下来,取出含目标的锚点框,得到其index,其中,pos的shape为[batch_size,8732],每个元素是true或者false。

class MultiBoxLoss(nn.Module):
    ‘‘‘
    SSD损失函数的计算
    ‘‘‘

    def __init__(self, num_classes, overlap_thresh, neg_pos):
        super(MultiBoxLoss, self).__init__()
        self.num_classes = num_classes  # 类别数
        self.threshold = overlap_thresh  # GT框与先验锚点框的阈值
        self.negpos_ratio = neg_pos  # 负例的比例

    def forward(self, predictions, targets):
        ‘‘‘
        对损失函数进行计算:
            1.进行GT框与先验锚点框的匹配,得到loc_t和conf_t,分别表示锚点框需要匹配的坐标和锚点框需要匹配的label
            2.对包含目标的先验锚点框loc_t(即正例)与预测的loc_data计算位置损失函数
            3.对负例(即背景)进行损失计算,选择损失最大的num_neg个负例和正例共同组成训练样本,取出这些训练样本的锚点框targets_weighted
                与置信度预测值conf_p,计算置信度损失:
                a)为Hard Negative Mining计算最大置信度loss_c
                b)将loss_c中正例对应的值置0,即保留了所有负例
                c)对此loss_c进行排序,得到损失最大的idx_rank
                d)计算用于训练的负例的个数num_neg,约为正例的3倍
                e)选择idx_rank中前num_neg个用作训练
                f)将正例的index和负例的index共同组成用于计算损失的index,并从预测置信度conf_data和真实置信度conf_t提出这些样本,形成
                    conf_p和targets_weighted,计算两者的置信度损失.
        :param predictions: 一个元祖,包含位置预测,置信度预测,先验锚点框
                    位置预测:(batch_size,num_priors,4),即[batch_size,8732,4]
                    置信度预测:(batch_size,num_priors,num_classes),即[batch_size, 8732, 21]
                    先验锚点框:(num_priors,4),即[8732, 4]
        :param targets: 真实框的坐标与label,[batch_size,num_objs,5]
                    其中,5代表[xmin,ymin,xmia,ymax,label]
        ‘‘‘
        loc_data, conf_data, priors = predictions
        num = loc_data.shape[0]  # 即batch_size大小
        priors = priors[:loc_data.shape[1], :]  # 取出8732个锚点框,与位置预测的锚点框数量相同
        num_priors = priors.shape[0]  # 8732

        loc_t = torch.Tensor(num, num_priors, 4)  # [batch_size,8732,4],生成随机tensor,后续用于填充
        conf_t = torch.Tensor(num, num_priors)  # [batch_size,8732]
        # 取消梯度更新,貌似默认是False
        loc_t.requires_grad = False
        conf_t.requires_grad = False

        for idx in range(num):
            truths = targets[idx][:, :-1]  # 坐标值,[xmin,ymin,xmia,ymax]
            labels = targets[idx][:, -1]  # label
            defaults = priors.cuda()
            match(self.threshold, truths, defaults, labels, loc_t, conf_t, idx)
        if torch.cuda.is_available():
            loc_t = loc_t.cuda()
            conf_t = conf_t.cuda()  # shape:[batch_size,8732],其元素组成是类别标签号和背景

        pos = conf_t > 0  # 排除label=0,即排除背景,shape[batch_size,8732],其元素组成是true或者false
        # Localization Loss (Smooth L1),定位损失函数
        # Shape: [batch,num_priors,4]
        # pos.dim()表示pos有多少维,应该是一个定值(2)
        # pos由[batch_size,8732]变成[batch_size,8732,1],然后展开成[batch_size,8732,4]
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)  # [num_pos,4],取出带目标的这些框
        loc_t = loc_t[pos_idx].view(-1, 4)  # [num_pos,4]
        # 位置损失函数
        loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction=‘sum‘)  # 这里对损失值是相加,有公式可知,还没到相除的地步

        # 为Hard Negative Mining计算max conf across batch
        batch_conf = conf_data.view(-1, self.num_classes)  # shape[batch_size*8732,21]
        # gather函数的作用是沿着定轴dim(1),按照Index(conf_t.view(-1, 1))取出元素
        # batch_conf.gather(1, conf_t.view(-1, 1))的shape[8732,1],作用是得到每个锚点框在匹配GT框后的label
        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1).long())  # 这个不是最终的置信度损失函数

        # Hard Negative Mining
        # 由于正例与负例的数据不均衡,因此不是所有负例都用于训练
        loss_c[pos.view(-1, 1)] = 0  # pos与loss_c维度不一样,所以需要转换一下,选出负例
        loss_c = loss_c.view(num, -1)  # [batch_size,8732]
        _, loss_idx = loss_c.sort(1, descending=True)  # 得到降序排列的index
        _, idx_rank = loss_idx.sort(1)

        num_pos = pos.sum(1, keepdim=True)  # pos里面是true或者false,因此sum后的结果应该是包含的目标数量
        num_neg = torch.clamp(self.negpos_ratio * num_pos, max=pos.size(1) - 1)  # 生成一个随机数用于表示负例的数量,正例和负例的比例约3:1
        neg = idx_rank < num_neg.expand_as(idx_rank)  # [batch_size,8732] 选择num_neg个负例,其元素组成是true或者false

        # 置信度损失,包括正例和负例
        # [batch_size, 8732, 21],元素组成是true或者false,但true代表着存在目标,其对应的index为label
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        # pos_idx由true和false组成,表示选择出来的正例,neg_idx同理
        # (pos_idx + neg_idx)表示选择出来用于训练的样例,包含正例和反例
        # torch.gt(other)函数的作用是逐个元素与other进行大小比较,大于则为true,否则为false
        # 因此conf_data[(pos_idx + neg_idx).gt(0)]得到了所有用于训练的样例
        conf_p = conf_data[(pos_idx + neg_idx).gt(0)].view(-1, self.num_classes)
        targets_weighted = conf_t[(pos + neg).gt(0)]
        loss_c = F.cross_entropy(conf_p, targets_weighted.long(), reduction=‘sum‘)

        # L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        N = num_pos.sum()  # 一个batch里面所有正例的数量
        loss_l /= N
        loss_c /= N
        return loss_l, loss_c

在计算损失函数时,提及了函数match(),这个函数位于models/box_utils.py中,是一个非常关键的函数,对应论文的匹配策略那一章节,其作用是为每个锚点框指定GT框和为每个GT框指定锚点框。需要传进来几个参数,truths是GT框的坐标,priors是先验锚点框的坐标[中心点x,中心点y,W,H],labels是GT框对应的类别(不包含背景),loc_t和conf_t是用来保存结果的,idx是第i张图片。

为了方便表述,num_objects表示一张图中,GT框的数量;num_priors表示先验锚点框的数量,即8732。

第一步,由于先验锚点框priors的坐标形式是[中心点x,中心点y,W,H],需要使用函数point_from()来将其转化成[x_min,y_min,x_max,y_max]。然后计算每个GT框与所有先验锚点框的jaccard值,即IOU的值,使用了numpy风格的计算方式,返回的变量overlaps的shape为[GT框数量,8732]。

第二步,根据论文,为每个GT框匹配一个最大IOU的先验锚点框,确保每个GT框至少有一个锚点框进行预测。

第三步,为每个锚点框匹配上一个最大IOU的GT框来进行预测。

第四步,变量best_truth_overlap保存着每个框与GT框的最大IOU值(第三步的结果),使用index_fill()函数,将第二步的结果同步到这个变量中。在index_fill()函数中,使用数值2来进行填充,是为了确保第二步中得到的锚点框肯定会被选到。对变量best_truth_idx也进行同样的处理。

第五步,由于传入进来的labels的类别是从0开始的,SSD中认为0应该是背景,所以,需要对labels进行加一。这里需要注意一下,best_truth_idx的shape是[8732],每个元素的范围为[0,num_objects],所以conf的shape为[num_priors],每个元素表示先验锚点框的label(0是背景)。同时,需要将变量best_truth_overlap中IOU小于阈值(0.5)的锚点框的label设置为0。并将结果保存与conf_t,返回给外面的函数用于计算。

第六步,同样需要将GT框的坐标进行扩展,形成shape为[num_priors,4]的matches,这样每个锚点框都有对应的坐标进行预测,但最终并不是每个锚点框都用于训练中。

第七步,使用GT框与锚点框进行编码,对应论文中的公式2,得到shape为[num_priors,4]的值,即偏差,将此结果返回出去。

def match(threshold, truths, priors, labels, loc_t, conf_t, idx):
    ‘‘‘
    这个函数对应论文中的matching strategy匹配策略.SSD需要为每一个先验锚点框都指定一个label,
    这个label或者指向背景,或者指向每个类别.
    论文中的匹配策略是:
        1.首先,每个GT框选择与其IOU最大的一个锚点框,并令这个锚点框的label等于这个GT框的label
        2.然后,当锚点框与GT框的IOU大于阈值(0.5)时,同样令这个锚点框的label等于这个GT框的label
    因此,代码上的逻辑为:
        1.计算每个GT框与每个锚点框的IOU,得到一个shape为[num_object,num_priors]的矩阵overlaps
        2.选择与GT框的IOU最大的锚点框,锚点框的index为best_prior_idx,对应的IOU值为best_prior_overlap
        3.为每一个锚点框选择一个IOU最大的GT框,可能会出现多个锚点框匹配一个GT框的情况,此时,每个锚点框对应GT框的index为best_truth_idx,
            对应的IOU为best_truth_overlap.注意,此时IOU值可能会存在小于阈值的情况.
        4.第3步可能到导致存在GT框没有与锚点框匹配上的情况,所以要和第2步进行结合.在第3步的基础上,对best_truth_overlap进行选择,选择出
            best_prior_idx这些锚点框,让其对其的IOU等于一个大于1的定值;并且让best_truth_idx中index为best_prior_idx的锚点框的label
            与GT框对应上.最终,best_truth_overlap表示每个锚点框与GT框的最大IOU值,而best_truth_idx表示每个锚点框用于与相应的GT框进行
            匹配.
        5.第4步中,会存在IOU小于阈值的情况,要将这些小于IOU阈值的锚点框的label指向背景,完成第二条匹配策略.
            labels表示GT框对应的标签号,"conf=labels[best_truth_idx]+1"得到每个锚点框对应的标签号,其中label=0是背景.
            "conf[best_truth_overlap < threshold] = 0"则将小于IOU阈值的锚点框的label指向背景
        6.得到的conf表示每个锚点框对应的label,还需要一个矩阵,来表示每个锚点框需要匹配GT框的坐标.
            truths表示GT框的坐标,"matches = truths[best_truth_idx]"得到每个锚点框需要匹配GT框的坐标.
    :param threshold:IOU的阈值
    :param truths:GT框的坐标,shape:[num_obj,4]
    :param priors:先验锚点框的坐标,shape:[num_priors,4],num_priors=8732
    :param labels:这些GT框对应的label,shape:[num_obj],此时label=0还不是背景
    :param loc_t:坐标结果会保存在这个tensor
    :param conf_t:置信度结果会保存在这个tensor
    :param idx:结果保存的idx
    ‘‘‘
    # 第1步,计算IOU
    overlaps = jaccard(truths, point_from(priors))  # shape:[num_object,num_priors]

    # 第2步,为每个真实框匹配一个IOU最大的锚点框,GT框->锚点框
    # best_prior_overlap为每个真实框的最大IOU值,shape[num_objects,1]
    # best_prior_idx为对应的最大IOU的先验锚点框的Index,其元素值的范围为[0,num_priors]
    best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)

    # 第3步,若先验锚点框与GT框的IOU>阈值,也将这些锚点框匹配上,锚点框->GT框
    # best_truth_overlap为每个先验锚点框对应其中一个真实框的最大IOU,shape[1,num_priors]
    # best_truth_idx为每个先验锚点框对应的真实框的index,其元素值的范围为[0,num_objects]
    best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)

    best_prior_idx.squeeze_(1)  # [num_objects]
    best_prior_overlap.squeeze_(1)  # [num_objects]
    best_truth_idx.squeeze_(0)  # [num_priors],8732
    best_truth_overlap.squeeze_(0)  # [num_priors],8732

    # 第4步
    # index_fill_(self, dim: _int, index: Tensor, value: Number)对第dim行的index使用value进行填充
    # best_truth_overlap为第一步匹配的结果,需要使用到,使用best_prior_idx是第二步的结果,也是需要使用上的
    # 所以在best_truth_overlap上进行填充,表明选出来的正例
    # 使用2进行填充,是因为,IOU值的范围是[0,1],只要使用大于1的值填充,就表明肯定能被选出来
    best_truth_overlap.index_fill_(0, best_prior_idx, 2)  # 确定最佳先验锚点框
    # 确保每个GT框都能匹配上最大IOU的先验锚点框
    # 得到每个先验锚点框都能有一个匹配上的数字
    # best_prior_idx的元素值的范围是[0,num_priors],长度为num_objects
    for j in range(best_prior_idx.size(0)):
        best_truth_idx[best_prior_idx[j]] = j

    # 第5步
    conf = labels[best_truth_idx] + 1  # Shape: [num_priors],0为背景,所以其余编号+1
    conf[best_truth_overlap < threshold] = 0  # 置信度小于阈值的label设置为0

    # 第6步
    matches = truths[best_truth_idx]  # 取出最佳匹配的GT框,Shape: [num_priors,4]

    # 进行位置编码
    loc = encode(matches, priors,voc[‘variance‘])
    loc_t[idx] = loc  # [num_priors,4],应该学习的编码偏差
    conf_t[idx] = conf  # [num_priors],每个锚点框的label

原文地址:https://www.cnblogs.com/dengshunge/p/11965444.html

时间: 2024-10-19 21:02:42

SSD源码解读——损失函数的构建的相关文章

ssd源码解读(caffe)

ssd是经典的one-stage目标检测算法,作者是基于caffe来实现的,这需要加入新的层来完成功能,caffe自定义层可以使用python和c++,faster rcnn既使用了c++定义如smoothl1layer,又使用了python定义,如proposaltargetlayer.roidatalayer等.而ssd完全使用c++来定义层,包括: 1)annotateddatalayer数据读取层,用于读取图像和标签数据,并且支持数据增强 2)permutelayer用于改变blob的读

Gson 源码解读

开源库地址:https://github.com/google/gson 解读版本:2.7 Gson是一个可以用来将Java对象转换为JSON字符串的Java库.当然,它也可以把JSON字符串转换为等价的Java对象.网上已经有了不少可将Java对象转换成JSON的开源项目.但是,大多数都要求你在Java类中加入注解,如果你无法修改源码的话就比较坑爹了,此外大多数开源库并没有对泛型提供完全的支持.于是,Gson在这两个重要的设计目标下诞生了.Gson可以作用于任意的Java对象(包括接触不到源码

【Spark】SparkContext源码解读

SparkContext的初始化 SparkContext是应用启动时创建的Spark上下文对象,是进行Spark应用开发的主要接口,是Spark上层应用与底层实现的中转站(SparkContext负责给executors发送task). SparkContext在初始化过程中,主要涉及一下内容: SparkEnv DAGScheduler TaskScheduler SchedulerBackend SparkUI 生成SparkConf SparkContext的构造函数中最重要的入参是Sp

Retrofit2 源码解读

开源库地址:https://github.com/square/retrofit 解读版本:2.1.0 基本概念 Retrofit 是一个针对Java/Android类型安全的Http请求客户端. 基本使用如下: 首先定义一个接口,抽象方法的返回值必须为Call<XX>. public interface GitHubService { @GET("users/{user}/repos") Call<List<Repo>> listRepos(@Pa

源码解读Mybatis List列表In查询实现的注意事项

转自:http://www.blogjava.net/xmatthew/archive/2011/08/31/355879.html 源码解读Mybatis List列表In查询实现的注意事项 在SQL开发过程中,动态构建In集合条件查询是比较常见的用法,在Mybatis中提供了foreach功能,该功能比较强大,它允许你指定一个集合,声明集合项和索引变量,它们可以用在元素体内.它也允许你指定开放和关闭的字符串,在迭代之间放置分隔符.这个元素是很智能的,它不会偶然地附加多余的分隔符.下面是一个演

OpenCV2马拉松第27圈——SIFT论文,原理及源码解读

计算机视觉讨论群162501053 转载请注明:http://blog.csdn.net/abcd1992719g/article/details/28913101 简介 SIFT特征描述子是David G. Lowe 在2004年的ijcv会议上发表的论文中提出来的,论文名为<<Distinctive Image Featuresfrom Scale-Invariant Keypoints>>.这是一个很强大的算法,主要用于图像配准和物体识别等领域,但是其计算量相比也比较大,性价

Android-Universal-Image-Loader 源码解读

Universal-Image-Loader是一个强大而又灵活的用于加载.缓存.显示图片的Android库.它提供了大量的配置选项,使用起来非常方便. 基本概念 基本使用 首次配置 在第一次使用ImageLoader时,必须初始化一个全局配置,一般会选择在Application中配置. public class MyApplication extends Application { @Override public void onCreate() { super.onCreate(); //为I

jdk1.8.0_45源码解读——Map接口和AbstractMap抽象类的实现

jdk1.8.0_45源码解读——Map接口和AbstractMap抽象类的实现 一. Map架构 如上图:(01) Map 是映射接口,Map中存储的内容是键值对(key-value).(02) AbstractMap 是继承于Map的抽象类,它实现了Map中的大部分API.其它Map的实现类可以通过继承AbstractMap来减少重复编码.(03) SortedMap 是继承于Map的接口.SortedMap中的内容是排序的键值对,排序的方法是通过比较器(Comparator).(04) N

jdk1.8.0_45源码解读——LinkedList的实现

jdk1.8.0_45源码解读——LinkedList的实现 一.LinkedList概述 LinkedList是List和Deque接口的双向链表的实现.实现了所有可选列表操作,并允许包括null值.    LinkedList既然是通过双向链表去实现的,那么它可以被当作堆栈.队列或双端队列进行操作.并且其顺序访问非常高效,而随机访问效率比较低. 注意,此实现不是同步的. 如果多个线程同时访问一个LinkedList实例,而其中至少一个线程从结构上修改了列表,那么它必须保持外部同步.这通常是通