理解二分类交叉熵|可视化的方法解释对数损失

小白学视觉

共 3977字,需浏览 8分钟

 ·

2022-05-28 11:07

点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

 

介绍

 如果你在训练一个二分类分类器,很有可能你在使用二值交叉熵,log损失,作为你的损失函数。

你有没有想过,使用这个损失函数到底意味着什么?事实是,现在的各种库和框架非常的简单易用,导致大家很容易忽视所使用的损失函数的真正意义。


动机


我一直在找一个可以通过可视化到的方法清楚而简单的解释二元交叉熵(log损失)的背后的真正含义,这样我可以在 Data Science Retreat上展示给我的学生,但是我一直没有找到。既然找不到我想要的,那我就自己来:-)

 

一个简单的分类问题


让我们从10个随机数开始:

  x = [-2.2, -1.4, -0.8, 0.2, 0.4, 0.8, 1.2, 2.2, 2.9, 4.6]

这就是我们唯一的特征:x

现在,我们给这些点涂上点颜色:红色和绿色,作为标签。

所以,我们的分类问题就很直观了:给定了特征x,需要我们预测标签:红色或者绿色。

既然是个二分类,我们可以将这个问题描述成:“这个点是绿色的吗?”,或者,“这个点是绿色的概率是多少?”,理想的状态下,绿色点的概率应该为1.0,同时红色点的概率应该为0.0。

在这样的设定下,绿色点属于正样本,红色点属于负样本。

如果我拟合一个模型来进行分类,预测每个点是绿色的概率。给定点的颜色,我们如何来评估这个预测的概率的好坏?这就是损失函数的目的!损失函数对于好的预测将返回一个低的值,对于坏的预测,将返回一个高的值。

对于二分类,比如我们的例子,典型的损失函数就是二值交叉熵(对数损失)。

 

损失函数:二元交叉熵/对数损失


如果你仔细看看这个损失函数,你会发现:

y是标签(1是绿色的,0是红色的),p(y)是所有的N个点预测是绿色的概率。

这个公式告诉你,对于每一个绿色(y=1)的点,加了一个log(p(y))到损失中,这就是绿色的对数概率。相反的,对于每一个红色(y=0)的点添加了log(1-p(y)),这个是红色的对数概率。一点也不难,也很不直观。

另外,熵和这些有个什么关系?为什么我们要首先取概率的对数?这才是有价值的问题,我希望在下面的 “Show me the math” 环节中回答。

但是,在我们开始更多的公式之前,我先给你展示一个上面公式的可视化的表示。

 

计算损失—可视化的方法


首先,我们根据类别将这些点分开,正样本和负样本,就像这样:

现在,我们来训练逻辑回归模型来分类我们的点。这个回归的拟合是一个sigmoid的曲线,表示了给定的x是绿色的概率。就像这样:

对于所有的属于正样本的点(绿色),我们的分类器给出的预测概率是什么?就是sigmoid曲线下面的绿色的条,x的坐标代表了这个点。

到现在为止,一切都好!那么负样本的点呢?记住,sigmoid曲线之下的绿条表示的该点是绿色的概率。那么,给定的点是红色的概率是多少呢?当然就是sigmoid曲线上面红色条啦 :-)

把这些放在一起,我们得到了这样的东西:

条子代表了每个点对应的类别的预测的概率。

好了,我们有了预测的概率,是时候计算一下二值交叉熵/对数损失来评估一下了。

这些概率就是我们需要的东西,所以,我们不需要x的坐标了,我们把竖条一个挨一个排列起来。

现在,这些竖条不再有什么含义了,我们改变一下位置:

既然我们是想计算损失,我们需要惩罚坏的预测,是吗?如果对应类别的相关的概率是1.0,我们需要对应的loss为零。对应的,如果概率很低,比如0.01,我们希望损失很大!

结果就是,将概率值取对数能够很好的满足我们的需求(实际上,使用对数的原因是来自于交叉熵的定义)。

下面的图显示的很清楚,预测为真的概率值越趋向于零,损失指数增加:

很公平!我们取概率的对数——这些就是每个点对应的损失。

最后,我们计算所有损失的均值。

好了!我们成功的计算了二元交叉熵/对数损失的值,是0.3329!


给我看代码


如果你需要重复确认一些我们的发现,运行下面的代码,自己看!

  from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
import numpy as np

x = np.array([-2.2, -1.4, -.8, .2, .4, .8, 1.2, 2.2, 2.9, 4.6])
y = np.array([0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])

logr = LogisticRegression(solver='lbfgs')
logr.fit(x.reshape(-1, 1), y)

y_pred = logr.predict_proba(x.reshape(-1, 1))[:, 1].ravel()
loss = log_loss(y, y_pred)

print('x = {}'.format(x))
print('y = {}'.format(y))
print('p(y) = {}'.format(np.round(y_pred, 2)))
print('Log Loss / Cross Entropy = {:.4f}'.format(loss))

 

给我看数学(你是认真的?!)


开个玩笑,上面的东西不是那么数学,如果你想理解熵,对数在这个里面扮演的角色,我们开始:-)

如果你想深入了解信息论,包括所有的概念——熵,交叉熵等等,可以看看Chris Olah’s写的的东西http://colah.github.io/posts/2015-09-Visual-Information/,非常的详细。

 

分布


我们从我们的数据分布开始。y代表了我们的点的类别(有3个红色点,7个绿色点),这就是分布,我们叫做q(y),看起来是这样的:

 


熵是一个给定的分布的不确定性的度量。

如果所有的点都是绿色的会怎么样?分布的不确定性是什么样的?零,对吗?毕竟,点的颜色是毫无疑问的,永远是绿色!所以,熵为零!

另外,如果我们知道正好一半的点是绿色而另外一半是红色呢?这就是最差的情况对吗?我们猜颜色的时候就没有任何的优势了:完全的随机!这种情况下,熵的值由下面的公式给出,我们的类别数是2:

对于任何一个之间的情况,我们可以计算熵的分布,就像我们的q(y),再使用下面的公式,C是类别的数量:

所以,如果我们知道了一个随机变量的真实的分布,我们就可以计算它的熵。但是,为什么一开始要训练个分类器呢?毕竟,我们知道真实的分布了啊

但是,如果我们不知道呢?我们是不是可以通过另外的分布比如说p(y)来估计真实的分布呢?当然可以!

 

交叉熵


我们假设我们的点服从另外的分布p(y),但是我们知道这个分布是来自于真实(未知)的分布q(y),是吗?

如果我们计算了熵,我们实际上计算的是这两个分布的交叉熵:

如果我们可以神奇的将p(y)和q( y)匹配的很好,那么交叉熵的计算值和熵的计算值也会匹配的很好。

既然这个是不太可能发生的,在真实的分布上,交叉熵永远会比熵要大那么一点。

原来,交叉熵和熵的差值是有个名字的...

 

KL散度


KL散度,衡量的是两个分布之间的差异性:

这个的意思是, p(y)和q(y)越接近,散度的值越小,交叉熵也是这样。

所以,我们需要找到一个好的p(y)来用,这就是我们的分类器做的事情,是吗?确实也是这样!寻找最近的p(y),就是最小化交叉熵。

 

损失函数


在训练中,分类器使用了N个点找那个的每一个来计算交叉熵的损失,有效的拟合出分布p(y)!既然每个点的概率都是1/N,交叉熵是这样的:

还记得上面的图6吗?我们需要在每个点对应的真实类别的概率上计算交叉熵。意思就是正样本使用绿色条,负样本使用红色条,数学上可以这样写:

最后一步是计算所有的点在两个类别上的平均值,正样本和负样本:

最后,再加上一点操作,我们使用任何一个点,不管是正样本还是负样本,都用同样的公式:

好了!我们回到了二元交叉熵/对数损失最初的公式:-)

 

最后的一点想法


我真的希望上面的内容可以给一些理所当然的概念一些不同的东西。我当然也希望可以展示给你关于机器学习和信息论是联系在一起的。

好消息! 

小白学视觉知识星球

开始面向外开放啦👇👇👇




下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲
小白学视觉公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲
小白学视觉公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群


欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~


浏览 52
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报