Pytorch交叉熵损失(CrossEntropyLoss)函数内部运算解析_crossentropyloss(reduction="mean")-程序员宅基地

技术标签: python  深度学习  pytorch  

  对于交叉熵损失函数的来由有很多资料可以参考,这里就不再赘述。本文主要尝试对交叉熵损失函数的内部运算做深度解析。

1. 函数介绍

  Pytorch官网中对交叉熵损失函数的介绍如下:

CLASS torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=- 100,reduce=None, reduction=‘mean’, label_smoothing=0.0)

  该损失函数计算输入值(input)和目标值(target)之间的交叉熵损失。交叉熵损失函数可用于训练一个 C C C类别的分类问题。参数weight给定时,其为分配给每一个类别的权重的一维张量(Tensor)。当数据集分布不均衡时,这是很有用的。
  函数输入(input)应包含每一个类别的原始、非标准化分数。对于未批量化的输入,输入必须是大小为 ( C ) (C) C的张量, ( m i n i b a t c h , C ) (minibatch,C) minibatchC ( m i n i b a t c h , C , d 1 , d 2 , . . . , d K ) (minibatch,C,d_1 ,d_2 ,... ,d_K) minibatchCd1d2...dK,在K维情况下, K ≥ 1 K \geq1 K1
  函数目标值(target)有两种情况,本文只介绍其中较为有效的一种情况,即target为类索引
   本文以下内容均为target为类索引的情况。

  函数目标值(target)取值为在 [ 0 , C ) [0,C) [0C)之间的类索引, C C C为类别数。参数reduction设为'none'时,交叉熵损失可描述如下:
l ( x , y ) = L = { l 1 , . . . , l N } T , l n = − w y n l o g e x p ( x n , y n ) ∑ c = 1 C e x p ( x n , c ) ⋅ 1 { y n   / = i g n o r e _ i n d e x } (1) \large l(x,y) = L = \left \{ l_1,...,l_N \right \}^T, \\ \large l_n = -w_{yn}log\frac{exp(x_{n,y_n})}{\sum_{c=1}^{C}exp(x_{n,c})}\cdot 1\left \{ y_n\mathrlap{\,/}{=}ignore\_index \right \}\tag{1} l(x,y)=L={ l1,...,lN}T,ln=wynlogc=1Cexp(xn,c)exp(xn,yn)1{ yn/=ignore_index}(1)

  其中, x x x是输入, y y y是目标值, w w w是weight, C C C是类别数, N N N为batch size。在reduction不为'none'时(默认为'mean'),有:
l ( x , y ) = { ∑ n = 1 N 1 ∑ n = 1 N w y n ⋅ 1 { y n   / = i g n o r e _ i n d e x } l n , i f   r e d u c t i o n = ‘ m e a n ’ ; ∑ n = 1 N l n , i f   r e d u c t i o n = ‘ s u m ’ . (2) \large l(x,y) = \left\{\begin{matrix} \sum_{n=1}^{N}\frac{1}{\sum_{n=1}^{N}w_{yn} \cdot1\left \{ y_n\mathrlap{\,/}{=}ignore\_index \right \}}l_n, \quad if \, reduction=‘mean’; \\ \sum_{n=1}^{N}l_n, \quad if \, reduction=‘sum’ . \end{matrix}\right. \tag{2} l(x,y)=n=1Nn=1Nwyn1{ yn/=ignore_index}1ln,ifreduction=mean;n=1Nln,ifreduction=sum.(2)

 需要指出的是,在这种情况下的交叉熵损失等价于LogSoftmaxNLLLoss的组合。1

  因此,我们可以从LogSoftmaxNLLLoss来深度解析交叉熵损失函数的内部运算。

2. LogSoftmax函数

  LogSoftmax()函数2公式如下:
L o g S o f t m a x ( x i ) = l o g ( e x p ( x i ) ∑ j e x p ( x j ) ) (3) LogSoftmax(x_i) = log(\frac{exp(x_i)}{\sum_{j}exp(x_j)}) \tag{3} LogSoftmax(xi)=log(jexp(xj)exp(xi))(3)
  即,先对输入值进行Softmax归一化处理,然后对归一化值取对数。这部分对应公式(1)中的 log ⁡ e x p ( x n , y n ) ∑ c = 1 C e x p ( x n , c ) \textcolor{red}{\log\frac{exp(x_{n,y_n})}{\sum_{c=1}^{C}exp(x_{n,c})}} logc=1Cexp(xn,c)exp(xn,yn)

  代码示例如下:

>>> import torch.nn as nn
>>> SM = nn.Softmax(dim=1) #Softmax函数
>>> x = torch.tensor([[1.0,3.0,4.0],[7.0,3.0,8.0],[9.0,7.0,5.0]])
>>> x
tensor([[1., 3., 4.],
        [7., 3., 8.],
        [9., 7., 5.]])
 
>>> output_SM = SM(x) #第一步,对x进行Softmax归一化处理
>>> output_SM
#每一行元素相加之和等于1
tensor([[0.0351, 0.2595, 0.7054],
        [0.2676, 0.0049, 0.7275],
        [0.8668, 0.1173, 0.0159]]) 
>>> out_L_SM = torch.log(output_SM) #第二步,对输出取log
>>> out_L_SM
tensor([[-3.3490, -1.3490, -0.3490],
        [-1.3182, -5.3182, -0.3182],
        [-0.1429, -2.1429, -4.1429]])
        
#直接使用LogSoftmax函数,一步到位
>>> L_SM = nn.LogSoftmax(dim=1)
>>> out_L_SM_ = L_SM(x)
>>> out_L_SM_
tensor([[-3.3490, -1.3490, -0.3490],
        [-1.3182, -5.3182, -0.3182],
        [-0.1429, -2.1429, -4.1429]])

3. NLLLoss函数

  Pytorch中的NLLLoss函数3“名不副实”,虽然名为负对数似然函数,但其内部并没有进行对数计算,而只是对输入值求平均后取负(函数参数reduction为默认值'mean',参数weight为默认值'none'时)。

  官网介绍如下:

CLASS torch.nn.NLLLoss(weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction=‘mean’)

  参数reduction值为'none'时:
l ( x , y ) = L = { l 1 , . . . , l N } T ,   l n = − w y n x n , y n , w c = w e i g h t [ c ] ⋅ 1 { c   / = i g n o r e _ i n d e x } , (4) \large l(x,y) = L = \left \{ l_1,...,l_N \right \}^T,\ l_n = -w_{yn}x_{n,yn}, w_c = weight[c]\cdot1\left \{ c\mathrlap{\,/}{=}ignore\_index\right \},\tag{4} l(x,y)=L={ l1,...,lN}T, ln=wynxn,yn,wc=weight[c]1{ c/=ignore_index},(4)
  其中, x x x为输入, y y y为目标值, w w w为weight, N N N为batch size。
  参数reduction值不为'none'时(默认为'mean'),有:
l ( x , y ) = { ∑ n = 1 N 1 ∑ n = 1 N w y n l n , i f   r e d u c t i o n = ‘ m e a n ’ ; ∑ n = 1 N l n , i f   r e d u c t i o n = ‘ s u m ’ . (5) \large l(x,y) = \left\{\begin{matrix} \sum_{n=1}^{N}\frac{1}{\sum_{n=1}^{N}w_{yn}}l_n, \quad if \, reduction=‘mean’; \\ \sum_{n=1}^{N}l_n, \quad if \, reduction=‘sum’ . \end{matrix}\right. \tag{5} l(x,y)=n=1Nn=1Nwyn1ln,ifreduction=mean;n=1Nln,ifreduction=sum.(5)
  可以看出,当reduction'mean'时,即是对 l n l_n ln求加权平均值。weight参数默认为1,因此默认情况下,即是对 l n l_n ln求平均值。又 l n = − w y n x n , y n l_n = -w_{yn}x_{n,yn} ln=wynxn,yn,所以weight为默认值1时, l n = − x n , y n l_n=-x_{n,yn} ln=xn,yn。故此时,即是 x x x求平均后取负。 这部分对于公式(2)中的 ∑ n = 1 N 1 ∑ n = 1 N w y n ⋅ 1 { y n   / = i g n o r e _ i n d e x } l n \textcolor{red}{\sum_{n=1}^{N}\frac{1}{\sum_{n=1}^{N}w_{yn} \cdot1\left \{ y_n\mathrlap{\,/}{=}ignore\_index \right \}}l_n} n=1Nn=1Nwyn1{ yn/=ignore_index}1ln

  实例代码验证如下:

>>> import torch
>>> NLLLoss = torch.nn.NLLLoss() #Pytorch负对数似然损失函数
>>> input = torch.randn(3,3)
>>>input
tensor([[1.4550, 2.3858, 1.1724],
        [0.4952, 1.5870, 0.9594],
        [1.4170, 0.4525, 0.2519]])
        
>>>target = torch.tensor([1,0,2]) #类索引目标值
>>> loss = NLLLoss(input, target)
>>> loss
tensor(-1.0443)

  平均取负有: V a l u e = − 1 3 ( 2.3858 + 0.4952 + 0.2519 ) = − 1.0443 Value = -\frac{1}{3}\left ( 2.3858+0.4952+0.2519 \right ) =-1.0443 Value=31(2.3858+0.4952+0.2519)=1.0443
  显然,平均取负结果和NLLLoss运算结果相同。

注:笔者窃以为,公式(5)中上式可写为 ∑ n = 1 N l n ∑ n = 1 N w y n \frac{\sum_{n=1}^{N}l_n}{\sum_{n=1}^{N}w_{yn}} n=1Nwynn=1Nln,如此则更容易理解。公式(2)同理。

4. 小结

  本文通过将CrossEntropyLoss拆解为LogSoftmaxNLLLoss两步,对交叉熵损失内部计算做了深度的解析,以更清晰地理解交叉熵损失函数。需要指出的是,本文所介绍的内容,只是对于CrossEntropyLoss的target为类索引的情况,CrossEntropyLoss的target还可以是每个类别的概率(Probabilities for each class),这种情况有所不同。


  学习总结,以作分享,如有不妥,敬请指出。


Reference


  1. https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss

  2. https://pytorch.org/docs/stable/generated/torch.nn.LogSoftmax.html?highlight=logsoftmax#torch.nn.LogSoftmax

  3. https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html?highlight=nllloss#torch.nn.NLLLoss

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/XIAOSHUCONG/article/details/123527257

智能推荐

BAT批处理创建文件桌面快捷方式_批处理创建桌面快捷方式-程序员宅基地

文章浏览阅读1.5w次,点赞9次,收藏26次。简介一个创建某个文件到桌面快捷方式的BAT批处理.代码@echooff::设置程序或文件的完整路径(必选)setProgram=D:\Program Files (x86)\格式工厂.4.2.0\FormatFactory.exe::设置快捷方式名称(必选)setLnkName=格式工厂v4.2.0::设置程序的工作路径,一般为程序主目录,此项若留空,脚本将..._批处理创建桌面快捷方式

射频识别技术漫谈(6-10)_芯片 ttf模式-程序员宅基地

文章浏览阅读2k次。射频识别技术漫谈(6-10),概述RFID的通讯协议;射频ID卡的原理与实现,数据的传输与解码;介绍动物标签属性与数据传输;RFID识别号的变化等_芯片 ttf模式

Python 项目实战 —— 手把手教你使用 Django 框架实现支付宝付款_django 对接支付宝接口流程-程序员宅基地

文章浏览阅读1.1k次。今天小编心血来潮,为大家带来一个很有趣的项目,那就是使用 Python web 框架 Django 来实现支付宝支付,废话不多说,一起来看看如何实现吧。_django 对接支付宝接口流程

Zabbix 5.0 LTS在清理历史数据后最新数据不更新_zabbix问题没有更新-程序员宅基地

文章浏览阅读842次。Zabbix 5.0 LTS,跑了一年多了一直很稳定,前两天空间显示快满了,于是手贱清理了一下history_uint表(使用mysql truncate),结果折腾了一周。大概故障如下:然后zabbix论坛、各种群问了好久都没解决,最后自己一番折腾似乎搞定了。初步怀疑,应该是由于历史数据被清空后,zabbix需要去处理数据,但是数据量太大,跑不过来,所以来不及更新了(?)..._zabbix问题没有更新

python学习历程_基础知识(2day)-程序员宅基地

文章浏览阅读296次。一、数据结构之字典 key-value

mybatis-plus字段策略注解strategy_mybatisplus strategy-程序员宅基地

文章浏览阅读9.7k次,点赞3次,收藏13次。最近项目中遇到一个问题,是关于mybatis-plus的字段注解策略,记录一下。1问题调用了A组件(基础组件),来更新自身组件的数据,发现自己组件有个字段总是被清空。2原因分析调用的A组件的字段,属于基础字段,自己业务组件,对这个基础字段做了扩展,增加了业务字段。但是在自己的组件中的实体注解上,有一个注解使用错误。mybatis-plus封装的updateById方法,如果..._mybatisplus strategy

随便推点

信息检索笔记-索引构建_为某一文档及集构件词项索引时,可使用哪些索引构建方法-程序员宅基地

文章浏览阅读3.8k次。如何构建倒排索引,我们将这个过程叫做“索引构建”。如果我们的文档很多,这样索引就一次性装不下内存,该如何构建。硬件的限制 我们知道ram读写是随机的操作,只要输入相应的地址单元就能瞬间将数据读出来或者写进去。但是磁盘不行,磁盘必须有一个寻道的过程,外加一个旋转时间。那么只有涉及到磁盘,我们就可以考虑怎么节省I/O操作时间。【注】操作系统往往以数据块为单位进行读写。因为读一_为某一文档及集构件词项索引时,可使用哪些索引构建方法

IT巨头英特尔看好中国市场前景-程序员宅基地

文章浏览阅读836次。英特尔技术与制造事业部副总裁卞成刚7日在财富论坛间隙接受中新社记者采访时表示,该公司看好中国市场前景,扎根中国并以此走向世界是目前最重要的战略之一。卞成刚说,目前该公司正面临战略转型,即从传统PC服务领域扩展至所有智能设施领域,特别是移动终端。而中国目前正引领全球手机市场,预计未来手机、平板电脑等方面的发明创新将大量在中国市场涌现,并推向全球。持相同态度的还有英特尔中国区执行董事戈峻。戈峻

ceph中的radosgw相关总结_radosgw -c-程序员宅基地

文章浏览阅读627次。https://blog.csdn.net/zrs19800702/article/details/53101213http://blog.csdn.net/lzw06061139/article/details/51445311https://my.oschina.net/linuxhunter/blog/654080rgw 概述Ceph 通过radosgw提供RES..._radosgw -c

前端数据可视化ECharts使用指南——制作时间序列数据的可视化曲线_echarts 时间序列-程序员宅基地

文章浏览阅读3.7k次,点赞6次,收藏9次。我为什么选择ECharts ? 本周学校课程设计,原本随机佛系选了一个51单片机来做音乐播放器,结果在粗略玩了CN-DBpedia两天后才回过神,课设还没有开始整。于是懒癌发作,碍于身上还有比赛的作品没交,本菜鸡对硬件也没啥天赋,所以就直接把题目切换成软件方面的题目。写python的同学选择了一个时间序列数据的可视化曲线程序设计题目,果真python在数据可视化这一点性能很优秀。..._echarts 时间序列

ApplicationEventPublisherAware事件发布-程序员宅基地

文章浏览阅读1.6k次。事件类:/** * *   * @className: EarlyWarnPublishEvent *   * @description:数据风险预警发布事件 *   * @param: *   * @return: *   * @throws: *   * @author: lizz *   * @date: 2020/05/06 15:31 * */public cl..._applicationeventpublisheraware

自定义View实现仿朋友圈的图片查看器,缩放、双击、移动、回弹、下滑退出及动画等_imageview图片边界回弹-程序员宅基地

文章浏览阅读1.2k次。如需转载请注明出处!点击小图片转到图片查看的页面在Android开发中很常用到,抱着学习和分享的心态,在这里写下自己自定义的一个ImageView,可以实现类似微信朋友圈中查看图片的功能和效果。主要功能需求:1.缩放限制:自由缩放,有最大和最小的缩放限制 2居中显示:.若图片没充满整个ImageView,则缩放过程将图片居中 3.双击缩放:根据当前缩放的状态,双击放大两倍或缩小到原来 4.单指_imageview图片边界回弹