Pytorch使用交叉熵损失函数CrossEntrophy一些需要注意的细节_torch topk crossentrophy_Geronimo620的博客-程序员宅基地

技术标签: Pytorch  python  深度学习  pytorch  

Pytorch使用交叉熵损失函数CrossEntrophy一些需要注意的细节

CrossEntrophy()

交叉熵损失函数,是一种在多分类任务,多标签学习中效果较好的损失函数。

criterion = nn.CrossEntropyLoss()
...
# train
...
        for i, (features, length, label) in enumerate(train_loader):
            ...
            loss = criterion(prediction, target)
            loss.backward()
            optimizer.step()
            L = L + loss.item()

其中prediction和target分别代表网络的输出和数据自身的标签

我在进行分类网络搭建时,碰到两个相关问题。

1.RuntimeError: multi-target not supported at …

这是编译CrossEntropyLoss()时的报错信息,原因如下。

我们一般会将多任务网络输出设为一组以为概率分布,如下所示:

prediction = tensor([0.01.0.2,0.5,0.2,0.01,0.98])

自然而然,我们计算损失时会想到用相同的概率分布形式,即one-hot型编码

target = tensor([1,0,0,0,0,0])

但其实这种traget是错的,CrossEntropyLoss中的target输入将自动转化为one-hot形式。我们要输入的其实是

第几种类类型,即

target = tensor([0]) 
#  CrossEntropy()将自动转化为tensor([1,0,0,0,0,0])
target = tensor([5]) 
#  CrossEntropy()将自动转化为tensor([0,0,0,0,0,1])

也可以更改 dataloader 中 dataset 中 def getitem(``self, index) 返回的 target 的内容(将 one hot 格式改成 数字格式 就行)。

2.使用squeeze降维注意事项

这是我在使用squeeze方法为tensor降维时遇到的。

报错信息为RuntimeError: dimension specified as 0 but tensor has no dimensions

在训练的valid和test中,不像train中有batchsize概念,当每一个训练样本只对应一个label时。

如果train中batchsize为16,那么不同环节中CrossEntropyLoss的target的size如下:

#train
>>print(label.size())
>>[161]
#valid,test
>>print(label.size())
>>[11]
#squeeze降维之后
#train
>>print(label.squeeze.size())
>>[16]
#valid,test
>>print(label.squeeze.size())
>>[]#Loss计算时会报错

shape为[]类型的tensor进行Loss计算时会出现shape不匹配的错误。

这是因为squeeze()会默认去除所有大小为1 的维度

解决方法为:

label = label.squeeze(label,1)

这样可以指定去除第二维的维度(其大小为1)。具体可以查看squeeze()的使用方法。

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_43175007/article/details/113763548

智能推荐

三种方法修改MySQL数据库中一个用户的密码_数据库中用户密码转码-程序员宅基地

三种方法修改MySQL数据库中一个用户的密码,在MySQL中修改一个用户(比如叫"hunte")的密码,可以用如下3个办法:  在MySQL中修改一个用户(比如叫"xxx")的密码,可以用如下3个办法:  1. 在控制台上输入    bash$ mysql -u root mysql    #用mysql客户程序    mysql> UPDATE user SET pass_数据库中用户密码转码

无名飞控姿态解算和控制(三)-程序员宅基地

继续码代码上一篇主要写了自稳模式下的代码流程,这次主要是飞控的定高和定点控制流程首先是定高控制模式在Main_Leading_Control里选择定高模式代码:else if(Controler_Mode==2)//定高模式 {/**************************定高模式,水平姿态期望角来源于遥控器****************************

vue学习-属性监听,样式绑定v-bind-程序员宅基地

1.监听属性例子1<div id = "app"> <p style = "font-size:25px;">计数器: {{ counter }}</p> <button @click = "counter++" styl

使用微软官方工具制作U盘系统重装盘_用微软官网制作的系统u盘是guid-程序员宅基地

准备工作官方给出的要求是需要一个至少 8GB 空间的空白U 盘或空白 DVD,这边使用U盘即可这边使用的金士顿U盘给它格式化了一次,记得格式化成FAT32格式的微软官网下载制作工具:点我进入点击立即下载工具1:推荐下载到自己能找到的目录,比如桌面2:下载完后右击选择以管理员方式运行3:点击接受微软协议进入下一步4:选择:为另一台电脑创建安装介质(U盘)当然你..._用微软官网制作的系统u盘是guid

android 解析apk包名,PC上查看/解析APK包名_钟聚湃的博客-程序员宅基地

E:\apktools>aapt d badging jj.apk | grep 'package:'package: name='com.roguerocketgames.m3s' versionCode='6' versionName='0.7.3'appt还有更多的功能,对apk的解析非常容易,可以查看它帮助Android Asset Packaging ToolUsage:aapt ..._解析apk的包名

foxmail总是删除服务器已接收文件,如何设置FOXMAIL收到邮件后,服务器上邮件自动删除...-程序员宅基地

如今的Foxmail 6可以帮我们定期清理垃圾邮件:单击菜单“工具/反垃圾邮件功能设置”,在打开的窗口中切换到“贝叶斯过滤”标签,勾选“自动删除垃圾邮件箱中以下天数之前的旧邮件”,并在下面设置好天数(如图1)。以后清除垃圾邮件就由Foxmail替你代劳吧。DreamMail单击菜单的“工具/选项”命令,在打开的窗口中再单击“常规选项”下的“自动删除”项,勾选“自动删除以下邮件夹列表中超过多少天的邮..._foxmail收取后删除服务器邮件

随便推点

利用XSS盗取cookies-程序员宅基地

利用XSS盗取cookies 1.把下里代码保存为cookies.asp,然后上传到能被正常访问的空间 <% testfile=Server.MapPath("cookies.txt") msg=Request("cookies") set fs=server.Cre...

Python输入三角形3条边长,a,b,c,判断是否构成三角形,如果构成三角形输出:三角形面积的面积为x;如果无法构成三角形,输出:无法构成三角形_python输入三角形的三条边长,判断能否构成三角形-程序员宅基地

a=float( input("a="))b=float( input("b="))c=float( input("c="))if a+b>c and b+c>a and a+c>b: p=(a+b+c)/2 s=(p*(p-a)*(p-b)*(p-c))**(1/2) print("三角形的面积为",s)else: print("无法构成三角形")_python输入三角形的三条边长,判断能否构成三角形

计算机应用基基本设置,《windows7操作系统基本操作(计算机应用基础)》.ppt_高鸣 蜡烛人挚友的博客-程序员宅基地

Windows7操作系统基本操作;主要内容;一、Windows家族的简史;一、Windows家族的简史;一、Windows家族的简史;一、Windows家族的简史;一、Windows家族的简史;一、Windows家族的简史;一、Windows家族的简史;一、Windows家族的简史;一、Windows家族的简史;一、Windows家族的简史;二、 Windows7操作系统基本操作;二、 Windo...

Go 专栏|复合数据类型:字典 map 和 结构体 struct_go map struct_yongxinz的博客-程序员宅基地

楼下新开了一家重庆砂锅肥肠,扩音喇叭一直在放:正宗的老重庆砂锅肥肠,麻辣可口,老巴适了。正不正宗不知道,反正听口音,我以为我回东北了。本篇介绍复合数据类型的最后一篇:字典和结构体。内容很重要,编程时用的也多,需要熟练掌握才行。本文所有代码基于 go1.16.6 编写。字典字典是一种非常常用的数据结构,Go 中用关键词 map 表示,类型是 map[K]V。K 和 V 分别是字典的键和值的数据类型,其中键必须支持相等运算符,比如数字,字符串等。创建字典有两种方式可以创建字典,第一种是直接使用_go map struct

shell sort 最后一列排序_Linux/Shell:排名第四的计算机关键技能-程序员宅基地

除了编程语言之外,要想找一份计算机相关的工作,还需要很多其他方面的技能。最近,来自美国求职公司 Indeed 的一份报告显示:在全美工作技能需求中,Linux/Shell技能仅次于SQL、Java、Python,是排名第四的计算机关键技能,力压JavaScript。相对于SQL、Java、Python而言,Linux/Shell可谓是即简单又复杂。记住一个命令就能上手操作,看起来Shel..._linux sort 最后一列

主题:C/C++编译器的选用 _www.949-程序员宅基地

特别说明:鉴于时不时的有人问关于用什么编译器的问题,我翻译了Bjarne Stroustrup主页上compilers栏目的一篇文章,希望对大家有点指导意义。这个翻译稿的粘贴过程中失去了所有超级连接(pfan提供的编辑功能不够好),我只好另外在旁边附加上,抱歉。另外,为了避免重复发贴,我把本站sarrow原来的相关文章也复制过来,供彷徨中的朋友参考。一个C++编译器的不完全列表 _www.949

推荐文章

热门文章

相关标签