centerloss损失函数的理解与实现

import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as TF
import torchvision.utils as vutils
import torch.nn.functional as F
from torch.autograd import Function

class CenterLoss(nn.Module):
    """
    paper: http://ydwen.github.io/papers/WenECCV16.pdf
    code:  https://github.com/pangyupo/mxnet_center_loss
    pytorch code: https://blog.csdn.net/sinat_37787331/article/details/80296964
    """

    def __init__(self, features_dim, num_class=10, alpha=0.01, scale=1.0, batch_size=64):
        """
        初始化
        :param features_dim: 特征维度 = c*h*w
        :param num_class: 类别数量
        :param alpha:   centerloss的权重系数 [0,1]
        """
        assert 0 <= alpha <= 1
        super(CenterLoss, self).__init__()
        self.alpha = alpha
        self.num_class = num_class
        self.scale = scale
        self.batch_size = batch_size
        self.feat_dim = features_dim
        # store the center of each class , should be ( num_class, features_dim)
        self.feature_centers = nn.Parameter(torch.randn([num_class, features_dim]))

        self.lossfunc = CenterLossFunc.apply
        init_weight(self, ‘normal‘)

    def forward(self, output_features, y_truth):
        """
        损失计算
        :param output_features: conv层输出的特征,  [b,c,h,w]
        :param y_truth:  标签值  [b,]
        :return:
        """
        batch_size = y_truth.size(0)
        output_features = output_features.view(batch_size, -1)
        assert output_features.size(-1) == self.feat_dim
        loss = self.lossfunc(output_features, y_truth, self.feature_centers)
        loss /= batch_size

        # centers_pred = self.feature_centers.index_select(0, y_truth.long())  # [b,features_dim]
        # diff = output_features - centers_pred
        # loss = self.alpha * 1 / 2.0 * (diff.pow(2).sum()) / self.batch_size
        return loss

class CenterLossFunc(Function):
    # https://blog.csdn.net/xiewenbo/article/details/89286462
    @staticmethod
    def forward(ctx, feat, labels, centers):
        ctx.save_for_backward(feat, labels, centers)
        centers_batch = centers.index_select(0, labels.long())
        return (feat - centers_batch).pow(2).sum() / 2.0

    @staticmethod
    def backward(ctx, grad_output):
        feature, label, centers = ctx.saved_tensors
        centers_batch = centers.index_select(0, label.long())
        diff = centers_batch - feature
        # init every iteration
        counts = centers.new(centers.size(0)).fill_(1)
        ones = centers.new(label.size(0)).fill_(1)
        grad_centers = centers.new(centers.size()).fill_(0)

        counts = counts.scatter_add_(0, label.long(), ones)
        grad_centers.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff)
        grad_centers = grad_centers / counts.view(-1, 1)
        return - grad_output * diff, None, grad_centers

if __name__ == ‘__main__‘:
    ct = CenterLoss(2, 10, 0.1).cuda()
    y = torch.Tensor([0, 0, 2, 1]).cuda()
    feat = torch.zeros(4, 2).cuda().requires_grad_()
    print(list(ct.parameters()))
    print(ct.feature_centers.grad)
    out = ct(feat, y)
    print(out.item())
    out.backward()
    print(ct.feature_centers.grad)
    print(feat.grad)

  

原文地址:https://www.cnblogs.com/dxscode/p/12059548.html

时间: 2024-10-13 21:48:36

centerloss损失函数的理解与实现的相关文章

GloVe损失函数的理解

简介 GloVe是一种非常简单快速的训练词向量的算法.与复杂的word2vec相比,其是一个log双线性模型,仅通过一个简单的损失函数就能够得到很好的结果. (1)J=∑i,jNf(Xi,j)(viTvj+bi+bj−log(Xi,j))2 其中,vi和vj是i和j的词向量,bi和bj是两个偏差项,f是一个权重函数,N为词汇表大小 但是这个损失函数的意义却不是很直观,这里参照一篇博客写了一下对于这个损失函数的分析 思路 Glove首先会通过设置定义的窗口大小,进行统计得到词的共现矩阵.如Xi,j

14-立刻、马上数据挖掘,生活就是这么刺激

记得群主在青葱的大学岁月,经常从图书馆贪婪地借书.我不喜欢在冬天或夏天去图书馆蹭空调自习,觉得太舒服了(事实是不喜欢扎堆排队),而喜欢在破旧又有年代感的自习室里蒸着桑拿或瑟瑟发抖学着习.没错,就是这么自虐.说到图书馆,暑假是可以借十本书的.我经常为这十本书斟酌一下午.记得一次我拿了<居里夫人自传>,还有<C++ programming>之类的英文原版装逼书.走过计算机类书架,我无法不注意到其中竟默默地躺着一些<数据仓库>.<数据挖掘>之类不知所云的书.我的第

[笔记]Logistic Regression理论总结

简述: 1. LR 本质上是对正例负例的对数几率做线性回归,因为对数几率叫做logit,做的操作是线性回归,所以该模型叫做Logistic Regression. 2. LR 的输出可以看做是一种可能性,输出越大则为正例的可能性越大,但是这个概率不是正例的概率,是正例负例的对数几率. 3. LR的label并不一定要是0和1,也可以是-1和1,或者其他,只是一个标识,标识负例和正例. 4. Linear Regression和Logistic Regression的区别: 这主要是由于线性回归在

CS231N-线性回归+svm多分类+softmax多分类

CS231N-线性回归+svm多分类+softmax多分类 计算机视觉 这一部分比较基础,没有太多视觉相关的.. 1.线性回归 假定在著名的 CIFAR10数据集上,包含10类数据.每类数据有10000条? 目标是输入一个图片,通过模型给出一个label.线性回归的思想就是 得到到F(x)作为某个类别的分数.那么针对每个可能的label都经过一个线性函数输出一个分值,那么我们选最大的其实就是最有可能的分数. 为什么这么做是合理的? 角度1: 每个种类一个 template,每个线性函数的W的训练

目标检测Anchor-free分支:基于关键点的目标检测

目标检测Anchor-free分支:基于关键点的目标检测(最新网络全面超越YOLOv3) https://blog.csdn.net/qiu931110/article/details/89430747 目标检测领域最近有个较新的方向:基于关键点进行目标物体检测.该策略的代表算法为:CornerNet和CenterNet.由于本人工作特性,对网络的实时性要求比较高,因此多用YoLov3及其变体.而就在今天下午得知,基于CornerNet改进的CornerNet-Squeeze网络居然在实时性和精

2.2 logistic回归损失函数(非常重要,深入理解)

上一节当中,为了能够训练logistic回归模型的参数w和b,需要定义一个成本函数 使用logistic回归训练的成本函数 为了让模型通过学习来调整参数,要给出一个含有m和训练样本的训练集 很自然的,希望通过训练集找到参数w和b,来得到自己得输出 对训练集当中的值进行预测,将他写成y^(I)我们希望他会接近于训练集当中的y^(i)的数值 现在来看一下损失函数或者叫做误差函数 他们可以用来衡量算法的运行情况 可以定义损失函数为y^和y的差,或者他们差的平方的一半,结果表明你可能这样做,但是实际当中

[转] 理解交叉熵在损失函数中的意义

转自:https://blog.csdn.net/tsyccnh/article/details/79163834 关于交叉熵在loss函数中使用的理解交叉熵(cross entropy)是深度学习中常用的一个概念,一般用来求目标与预测值之间的差距.以前做一些分类问题的时候,没有过多的注意,直接调用现成的库,用起来也比较方便.最近开始研究起对抗生成网络(GANs),用到了交叉熵,发现自己对交叉熵的理解有些模糊,不够深入.遂花了几天的时间从头梳理了一下相关知识点,才算透彻的理解了,特地记录下来,以

用Keras搞一个阅读理解机器人

catalogue 1. 训练集 2. 数据预处理 3. 神经网络模型设计(对话集 <-> 问题集) 4. 神经网络模型设计(问题集 <-> 回答集) 5. RNN神经网络 6. 训练 7. 效果验证 1. 训练集 1 Mary moved to the bathroom. 2 John went to the hallway. 3 Where is Mary? bathroom 1 4 Daniel went back to the hallway. 5 Sandra moved

GBRT 要点理解

1. 首先要理解Boost和Gradient Boost. 前者是在算法开始时候,,为每一个样本赋上一个相等的权重值,也就是说,最开始的时候,大家都是一样重要的.在每一次训练中得到的模型,会使得数据点的估计有所差异,所以在每一步结束后,我们需要对权重值进行处理,而处理的方式就是通过增加错分类点的权重,这样使得某些点如果老是被分错,那么就会被"严重关注",也就被赋上一个很高的权重.然后等进行了N次迭代(由用户指定),将会得到N个简单的基分类器(basic learner),最后将它们组合