分类问题损失函数的信息论解释

分类问题损失函数的信息论解释

分类问题的优化过程是一个损失函数最小化的过程,对应的损失函数一般称为logloss,对于一个多分类问题,其在N个样本上的logloss损失函数具有以下形式:

其中,yi(n)代表第n个样本是否属于第i个类别,取值为0或1,f(x(n))i代表分类模型对于第n个样本属于第i个类别的预测概率。将上面的式子稍作简化就可以得到我们常见的二分类问题下的损失函数,在这里不做展开,我们下面的讨论也都对于更为一般的多分类问题展开,而这些讨论对于二分类问题显然也同样适用。

上面的损失函数,很容易从最大似然的角度来做理解,也就是说等号右边的部分,去掉负号以后,对应着模型的一个估计f在N个样本上(取了log)的似然函数,而似然函数的最大化就对应着损失函数的最小化。

但是这个损失函数还有另外一个名字,叫做cross-entropy loss,从名字可以看出,这是一个信息论相关的名字,我们这篇文章就从信息论的角度,来理解一下分类问题的损失函数。

重新认识熵(entropy)

说起熵,大家都能知道衡量的是“数据的混乱程度”,但是它具体是如何衡量的呢?让我们首先来重新认识一下熵。

现在是周五的下班高峰期,你站在北京东三环的一座天桥上面,望着一辆辆汽车穿梭而过。你现在肩负着一个任务:你需要告诉我你看到的每一辆车的品牌型号,而我们的通讯工具,是一个二进制的通信管道,里面只能传输0或者1,这个管道的收费是1¥/bit。

显然你需要设计一组二进制串,每个串对应一个车型,例如1001对应的是一辆大众桑塔纳。那么你要如何设计这一组二进制串呢?具体来说,你会为丰田凯美瑞和特斯拉ModelS设计同样长度的串吗?

即使你不精通概率论,你可能也不会这么做,因为你知道大街上跑着的凯美瑞肯定比ModelS多得多,用同样长度的bit来传输肯定是不经济的。你肯定会为凯美瑞设计一个比较短的串,而为ModelS设计一个长一些的串。你为什么会这么做?本质上来讲,你是在利用你对分布的先验知识,来减少所需的bit数量。那具体我们应该如何利用分布知识来对信息进行编码呢?

幸运的是,香农(Shannon)老先生证明了,如果你知道一个变量的真实分布,那么为了使得你使用的平均bit最少,那么你应该给这个变量的第i个取值分配log1/yi个bit,其中yi是变量的第i个取值的概率。如果我们按照这样的方式来分配bit,那么我们就可以得到最优的数据传输方案,在这个方案下,我们为了传输这个分布产生的数据,平均使用的bit数量为:

有没有很眼熟?没错,这就是我们天天挂在嘴边的熵(entropy)。

交叉熵(cross entropy)

在上面的例子中,我们利用我们对数据分布y的了解,来设计数据传输方案,在这个方案中,数据的真实分布y充当了一个“工具”的角色,这个工具可以让我们的平均bit长度达到最小。

但是在大部分真实场景中,我们往往不知道真实y的分布,但是我们可以通过一些统计的方法得到y的一个估计。如果我们用来设计传输方案,也就是说,我们给分布的第i个取值分配log1/yi个bit,结果会是怎样?

套用之前的式子,将log中的y替换成为y^,我们可以得到,如果使用y^作为“工具”来对数据进行编码传输,能够使用的最小平均bit数为:

这个量,就是所谓的交叉熵(cross entropy),代表的就是使用y^来对y进行编码的话,需要使用的最短平均bit串长度。

交叉熵永远大于或等于熵,因为交叉熵是在用存在错误的信息来编码数据,所以一定会比使用正确的信息要使用更多的bit。只有当y^和y完全相等时,交叉熵才等于熵。

用交叉熵衡量分类模型质量

现在回到分类问题上来。假设我们通过训练得到了某模型,我们希望评估这个模型的好坏。从上面信道传输的角度来看,这个模型实际上提供了对真实分布y的一个估计y^。我们说要评估这个模型的好坏,实际是是想知道我们给出的估计y^和真实的分布y相差多大,那么我们可以使用交叉熵来度量这个差异。

由于交叉熵的物理意义是用y^作为工具来传输y平均需要多少个bit,那我们可以计算一下如果用y^来传输整个训练数据集需要多少个bit,首先我们看一下传输第n个样本需要多少个bit。由于估计出来的模型对于第n个样本属于第i个类的预测概率是y^i(n),而第n个样本的真实概率分布是yi(n),所以这一个样本也可以看做是一个概率分布,那么根据交叉熵的定义:

那么传输整个数据集需要的bit就是:

进一步变换可得:

其中。对比文章最开始的logloss损失函数可知:

也就是说,分类问题的损失函数,从信息论的角度来看,等价于训练出来的模型(分布)与真实模型(分布)之间的交叉熵(两者相差一个只和样本数据量有关的倍数N),而这个交叉熵的大小,衡量了训练模型与真实模型之间的差距,交叉熵越小,两者越接近,从而说明模型越准确。

总结

通过上面的讲解,我们从信息论的角度重新认识了分类问题的损失函数。信息论和机器学习是紧密相关的两个学科,从信息论的角度来说,模型优化的本质就是在减少数据中的信息,也就是不确定性。希望这个理解角度,能够让大家对分类问题有一个更全面的认识。希望对熵有进一步了解的同学,可以读一下香农老先生的著名文章《AMathematical Theory of Communication》,有时间的同学更可以研读一下《Elements ofInformation Theory》这本巨著,一定会让你的ML内功发生质的提升。

本文的灵感和部分内容来自这篇文章:http://rdipietro.github.io/friendly-intro-to-cross-entropy-loss/

本文作者:

张相於([email protected]),现任当当网推荐系统开发经理,负责当当网的推荐系统、NLP算法等工作。多年来主要从事推荐系统以及机器学习相关工作,也做过反垃圾、反作弊相关工作,并热衷于探索大数据技术&机器学习技术在其他领域的应用实践。

时间: 2024-10-12 21:16:57

分类问题损失函数的信息论解释的相关文章

David MacKay:用信息论解释 '快速排序'、'堆排序' 本质与差异

这篇文章是David MacKay利用信息论,来对快排.堆排的本质差异导致的性能差异进行的比较. 信息论是非常强大的,它并不只是一个用来分析理论最优决策的工具. 从信息论的角度来分析算法效率是一件很有趣的事,它给我们分析排序算法带来了一种新的思路. 运用了信息论的概念,我们很容易理解为什么快排的速度那么快,以及它的缺陷在哪里. 由于个人能力不足,对于本文的理解可能还是有点偏差. 而且因为翻译的困难,这篇译文有很多地方并没有翻译出来,还是使用了原文的句子. 所以建议大家还是阅读原文Heapsort

多分类SVM损失函数: Multiclass SVM loss

1. SVM 损失:在一个样本中,对于真实分类与其他每各个分类,如果真实分类所得的分数与其他各分类所得的分数差距大于或等于安全距离,则真实标签分类与该分类没有损失值:反之则需要计算真实分类与该分类的损失值: 真实分类与其他各分类的损失值的总和即为一个样本的损失值 ①即真实标签分类所得分数大于等于该分类的分数+安全距离,S_yi >=S_j + △,那么损失值=0 ②否则,损失值等于其他分类的分数 + 安全距离(阈值)- 真实标签分类所得的分数,即损失值=S_j + △ - S_yi S_yi:真

为什么如果待分类类别是10个,类别范围可以设置成(9,11)

为什么如果待分类类别是10个,类别范围可以设置成(9,11) 11的解释:有一个背景分类的过程 9的解释是:? 这个主要看损失函数设计对于类别的解释性 原文地址:https://www.cnblogs.com/miaozhijuan/p/12578515.html

Tensorflow 损失函数及学习率的四种改变形式

Reference: https://blog.csdn.net/marsjhao/article/details/72630147 分类问题损失函数-交叉熵(crossentropy) 交叉熵描述的是两个概率分布之间的距离,分类中广泛使用的损失函数,公式如下 在网络中可以通过Softmax回归将前向传播得到的结果变为交叉熵要求的概率分数值.Tensorflow中,Softmax回归的参数被去掉,通过一层将神经网络的输出变为一个概率分布. 代码实现 import tensorflow as tf

pytorch常用损失函数

损失函数的基本用法: criterion = LossCriterion() #构造函数有自己的参数 loss = criterion(x, y) #调用标准时也有参数 得到的loss结果已经对mini-batch数量取了平均值 1.BCELoss(二分类) CLASS torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction='mean') 创建一个衡量目标和输出之间二进制交叉熵的criterion unre

做一个logitic分类之鸢尾花数据集的分类

做一个logitic分类之鸢尾花数据集的分类 Iris 鸢尾花数据集是一个经典数据集,在统计学习和机器学习领域都经常被用作示例.数据集内包含 3 类共 150 条记录,每类各 50 个数据,每条记录都有 4 项特征:花萼长度.花萼宽度.花瓣长度.花瓣宽度,可以通过这4个特征预测鸢尾花卉属于(iris-setosa, iris-versicolour, iris-virginica)中的哪一品种. 首先我们来加载一下数据集.同时大概的展示下数据结构和数据摘要. import numpy as np

python分类预测模型的特点

模型 模型特点 位于 SVM 强大的模型,可以用来回归,预测,分类等,而根据选取不同的和函数,模型可以是线性的/非线性的 sklearn.svm 决策树 基于"分类讨论,逐步细化"思想的分类模型,模型直观,易解释 sklearn.tree 朴素贝叶斯 基于概率思想的简单有效的分类模型,能够给出容易理解的概率解释 sklearn.naive_bayes 神经网络 具有强大的拟合能力,可疑用于拟合,分类等,它有多个增强版本,如递神经网络,卷积神经网络,自编吗器等,这些是深度学习的模型基础

图论基础知识(1)-什么是图和图的分类

emmm......蒟蒻的第一篇博客,先讲一个比较简单的东西来熟悉以下操作吧(还是怕自己翻车) 由于本人知识水平有限,暂时不会涉及相关数学知识,这篇博客主要还是提供个人对图论的比较感性的认识 这篇文章将要介绍: 图的基本定义 图的简单分类 一些简单术语的解释 因为本人比较蒻,所以这篇博客会讲的非常慢,dalao们可以绕步了... Part1 什么是图: 这是百度百科里面给出的解释,很重要的一条就是:“这种图形通常是用来描述某些事物之间的某种特定关系”,也就是说图实际上存储的是一些关系,所以说以下

cs231n Convolutional Neural Networks for Visual Recognition 2 SVM softmax

linear classification 上节中简单介绍了图像分类的概念,并且学习了费时费内存但是精度不高的knn法,本节我们将会进一步学习一种更好的方法,以后的章节中会慢慢引入神经网络和convolutional neural network.这种新的算法有两部分组成: 1. 评价函数score function,用于将原始数据映射到分类结果(预测值): 2. 损失函数loss function, 用于定量分析预测值与真实值的相似程度,损失函数越小,预测值越接近真实值. 我们将两者结合,损失