pytorch查看loss曲线_Pytorch里的CrossEntropyLoss详解_weixin_39932838的博客-程序员秘密

技术标签: pytorch查看loss曲线  

在使用Pytorch时经常碰见这些函数cross_entropy,CrossEntropyLoss, log_softmax, softmax。看得我头大,所以整理本文以备日后查阅。

首先要知道上面提到的这些函数一部分是来自于torch.nn,而另一部分则来自于torch.nn.functional(常缩写为F)。二者函数的区别可参见 知乎:torch.nn和funtional函数区别是什么?

下面是对与cross entropy有关的函数做的总结:

torch.nn

torch.nn.functional (F)

CrossEntropyLoss

cross_entropy

LogSoftmax

log_softmax

NLLLoss

nll_loss

下面将主要介绍torch.nn.functional中的函数为主,torch.nn中对应的函数其实就是对F里的函数进行包装以便管理变量等操作。

在介绍cross_entropy之前先介绍两个基本函数:

log_softmax

这个很好理解,其实就是log和softmax合并在一起执行。

nll_loss

该函数的全程是negative log likelihood loss,函数表达式为

\[f(x,class)=-x[class]

\]

例如假设\(x=[1,2,3], class=2\),那额\(f(x,class)=-x[2]=-3\)

cross_entropy

交叉熵的计算公式为:

\[cross\_entropy=-\sum_{k=1}^{N}\left(p_{k} * \log q_{k}\right)

\]

其中\(p\)表示真实值,在这个公式中是one-hot形式;\(q\)是预测值,在这里假设已经是经过softmax后的结果了。

仔细观察可以知道,因为\(p\)的元素不是0就是1,而且又是乘法,所以很自然地我们如果知道1所对应的index,那么就不用做其他无意义的运算了。所以在pytorch代码中target不是以one-hot形式表示的,而是直接用scalar表示。所以交叉熵的公式(m表示真实类别)可变形为:

\[cross\_entropy=-\sum_{k=1}^{N}\left(p_{k} * \log q_{k}\right)=-log \, q_m

\]

仔细看看,是不是就是等同于log_softmax和nll_loss两个步骤。

所以Pytorch中的F.cross_entropy会自动调用上面介绍的log_softmax和nll_loss来计算交叉熵,其计算方式如下:

\[\operatorname{loss}(x, \text {class})=-\log \left(\frac{\exp (x[\operatorname{class}])}{\sum_{j} \exp (x[j])}\right)

\]

代码示例

>>> input = torch.randn(3, 5, requires_grad=True)

>>> target = torch.randint(5, (3,), dtype=torch.int64)

>>> loss = F.cross_entropy(input, target)

>>> loss.backward()

微信公众号:AutoML机器学习

MARSGGBO原创

如有意合作或学术讨论欢迎私戳联系~

邮箱:[email protected]

2019-2-19

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_39932838/article/details/111758513

智能推荐

datetime库方法总结_datetime.fromtimestamp tz_小笼包xiaolongbao的博客-程序员秘密

python datetime处理时间https://www.cnblogs.com/lhj588/archive/2012/04/23/2466653.htmlPython提供了多个内置模块用于操作日期时间,像calendar,time,datetime。datetime模块定义了两个常量:datetime.MINYEAR和datetime.MAXYEAR,分别表示datetime所能表示...

最适合Java初学者练手的项目【JavaSE项目-图书管理系统】_牛仔码农@的博客-程序员秘密

图书馆管理小项目的主要目的是让学完JavaSE的同学对之前掌握的知识做一些运用,该项目应用的知识点包括下面内容:面向对象的思想 分层的思想 接口 异常 集合 日期处理 Stream流 IO流 反射 Javafx(了解) css(了解)通过学习本项目,可以巩固JavaSE的知识,对于后续的学习来说可以起到很好的衔接。视频观看效果更佳,点击以下链接????1.2环境搭建1.2.1基本信息开发工具:IDEAJDK版本:8项目编码:GBK1.2.2使用技术.

映射 is not mapped的问题_Favoritebook的博客-程序员秘密

报了这个错 映射 is not mapped的问题  大小写也改过了 还是不行 原来是实体类  没有映射注解 这亏吃爆!添上之后 走起

Twitter的分布式自增ID算法snowflake (PHP版本)_LoweMuo的博客-程序员秘密

twitter的snowflake解决了这种需求,最初Twitter把存储系统从MySQL迁移到Cassandra,因为Cassandra没有顺序ID生成机制,所以开发了这样一套全局唯一ID生成服务。snowflake的结构如下(每部分用-分开):0 - 0000000000 0000000000 0000000000 0000000000 0 - 00000 - 00000 - 000000...

LeetCode-Python-977. 有序数组的平方_暴躁老哥在线刷题的博客-程序员秘密

给定一个按非递减顺序排序的整数数组 A,返回每个数字的平方组成的新数组,要求也按非递减顺序排序。 示例 1:输入:[-4,-1,0,3,10]输出:[0,1,9,16,100]示例 2:输入:[-7,-3,2,3,11]输出:[4,9,9,49,121]思路:水题, 按要求弄就完事了。class Solution(object): def sor...

Collada Exporter114 工具导出COLLADA 格式3dmax打不开_还是落叶草的博客-程序员秘密

一个是Autodesk版本的DAE 一个是开源的https://github.com/KhronosGroup/OpenCOLLADA/wiki/OpenCOLLADA-Tools 要下载插件

随便推点

文章_A诺亚方舟A的博客-程序员秘密

   好久没有遇到上瘾的事情了,从开始的钓鱼,玩游戏,看直播,慢慢的都被放下了。我突然感觉自己是不是变老了,对新事物也不报有太多期待,很想培养一种新的爱好,就像当初风雨无阻的去钓鱼,去通宵一样,生活总点有一样自己非常喜欢做的事吧。都说写东西会上瘾,我也来试一试。读了六年小学,三年初中,四年高中,作文只被老师表扬过一次,至今记忆尤新,还是我不喜欢的语文老师。。。。。。。,很想知道写作是如何锻炼一个人...

整数反转_卡布达吃西瓜的博客-程序员秘密

给出一个 32 位的有符号整数,你需要将这个整数中每位上的数字进行反转。示例 1:输入: 123输出: 321示例 2:输入: -123输出: -321示例 3:输入: 120输出: 21注意:假设我们的环境只能存储得下 32 位的有符号整数,则其数值范围为 [−231, 231 − 1]。请根据这个假设,如果反转后整数溢出那么就返回 0。通过次数347,529提交次数1,018,176来源:力扣(LeetCode)链接:https://leetcode-cn.com/prob

数据库设计,数据库性能优化(teched 2008讲义)_weixin_34205076的博客-程序员秘密

讲义内容非常好,还有一些最佳实践,为ms sql oltp系统性能调优指明了方向。oltp_sql_performance.pdf

Java CompletableFuture.allOf() 、thenApplyAsync()、thenRun()等使用_HSJ0170的博客-程序员秘密

package java8;//import com.google.common.util.concurrent.ThreadFactoryBuilder;import java.time.Instant;import java.time.temporal.ChronoUnit;import java.util.ArrayList;import java.util.List;import java.util.Random;import java.util.concurrent.*;im

手把手教你编译、调试TVM-程序员秘密

本文介绍从ubuntu 18.04搭建TVM的编译、调试环境,实现以下目标: - TVM源码编译 - vscode中实现TVM源码中Python/C++接口的定义跳转 - vscode中实现TVM在Python、C++代码中的断点调试

推荐文章

热门文章

相关标签