pytorch--自定义loss(-log BinaryCrossEntropy FocalLoss)_pytorch focal loss binary cross_entropy-程序员宅基地

技术标签: pytorch loss  深度学习  

  1. 负log loss;
  2. binary crossentropy;
  3. focal loss;

网上找到的loss写的都普遍复杂,我自己稍微写的逻辑简单一点。
注:这里没有考虑不参与计算loss的情况。

if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()

focal loss

focal loss仔细实践起来可以分为两种情况,一种是二分类(sigmoid激活)的时候,还有一种情况就是多分类(softmax激活)的时候。

二分类focal loss

在这里插入图片描述

class FocalLoss(nn.Module):
    """ -[alpha*y*(1-p)^gamma*log(p)+(1-alpha)(1-y)*p^gamma*log(1-p)] loss"""

    def __init__(self, gamma, alpha=None , onehot=False):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.onehot = onehot


    def forward(self, inputs, targets):
        """

        :param input: onehot
        :param target: 默认是onehot以后
        :return:
        """
        N = inputs.size(0)
        C = inputs.size(1)
        inputs = torch.clamp(inputs, min=0.001, max=1.0)  ##将一个张量中的数值限制在一个范围内,如限制在[0.1,1.0]范围内,可以避免一些运算错误,如预测结果q中元素可能为0
        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()    

         if not self.onehot:
            class_mask = inputs.data.new(N, C).fill_(0)
            class_mask = Variable(class_mask)
            ids = targets.view(-1, 1)
            class_mask.scatter_(1, ids.data, 1.)
            targets = class_mask

        pos_sample_loss_matrix = -targets * (torch.pow((1 - inputs), self.gamma)) * inputs.log()  ## 正样本的loss
        # mean_pos_sample_loss = pos_sample_loss_matrix.sum() / targets.sum()

        neg_sample_loss_matrix = -(targets == 0).float() * (torch.pow((inputs), self.gamma)) * (1 - inputs).log()  ## 负样本的loss
        # mean_neg_sample_loss = pos_sample_loss_matrix.sum() / targets.sum()
        if self.alpha:
            return (self.alpha * pos_sample_loss_matrix + (1 - self.alpha) * neg_sample_loss_matrix).sum() / (N * C)
        else:
            return (pos_sample_loss_matrix + neg_sample_loss_matrix).sum() / (N * C)


多分类focal loss

在这里插入图片描述

class FocalLoss(nn.Module):
    """ -[y*(1-p)^gamma*log(p) loss
        softmax激活输入的foacl loss。
    """

    def __init__(self, gamma, onehot=False):
        super(FocalLoss, self).__init__()

        self.gamma = gamma
        self.onehot = onehot


    def forward(self, inputs, targets):
        """

        :param input: onehot
        :param target: 默认是onehot以后
        :return:
        """

        inputs = torch.clamp(inputs, min=0.001, max=1.0)  ##将一个张量中的数值限制在一个范围内,如限制在[0.1,1.0]范围内,可以避免一些运算错误,如预测结果q中元素可能为0

         if not self.onehot:
            N = inputs.size(0)
            C = inputs.size(1)
            class_mask = inputs.data.new(N, C).fill_(0)
            class_mask = Variable(class_mask)
            ids = targets.view(-1, 1)
            class_mask.scatter_(1, ids.data, 1.)
            targets = class_mask

        pos_sample_loss_matrix = -targets * (torch.pow((1 - inputs), self.gamma)) * inputs.log()  ## 正样本的loss
        # mean_pos_sample_loss = pos_sample_loss_matrix.sum() / targets.sum()

        ## 默认输出均值
        ## 这里不能直接求mean,
        # 因为整个矩阵还是原来的输入大小的,
        # 求loss应该是除以label中有目标的总数。
        
        return pos_sample_loss_matrix / targets.sum()
        


代码

  • NegtiveLogLoss
  • BinaryCrossEntropy
import torch
import torch.nn as nn
from torch.autograd import Variable

class NegtiveLogLoss(nn.Module):
   """ -log(p) loss"""

   def __init__(self, onehot=False):
       super(NegtiveLogLoss, self).__init__()
       self.onehot = onehot

   def forward(self, inputs, targets):
       """

       :param input: onehot
       :param target: 默认是onehot以后
       :return:
       """

       inputs = torch.clamp(inputs, min=0.001, max=1.0)  ## 将一个张量中的数值限制在一个范围内,如限制在[0.1,1.0]范围内,可以避免一些运算错误,如预测结果q中元素可能为0

       if not self.onehot:
           N = inputs.size(0)
           C = inputs.size(1)
           class_mask = inputs.data.new(N, C).fill_(0)
           class_mask = Variable(class_mask)
           ids = targets.view(-1, 1)
           class_mask.scatter_(1, ids.data, 1.)
           targets = class_mask

       loss_matrix = -targets * inputs.log()  ## 对预测的矩阵里面的每个元素做log,
       ## 然后乘以one hot的label,也就是说获得1位置的值了。
       ## 这时候还是个矩阵,还没有计算均值
       ## 默认输出均值
       return loss_matrix.sum() / targets.sum()  ## 这里不能直接求mean,
       # 因为整个矩阵还是原来的输入大小的,
       # 求loss应该是除以label中有目标的总数。



class BinaryCrossEntropy(nn.Module):
   """ -(ylog(p)+(1-y)log(1-p) loss"""

   def __init__(self, alpha=None, onehot=False):
       super(BinaryCrossEntropy, self).__init__()
       self.alpha = alpha
       self.onehot = onehot

   def forward(self, inputs, targets):
       """

       :param input: onehot
       :param target: 默认是onehot以后
       :return:
       """
       N = inputs.size(0)
       C = inputs.size(1)
       inputs = torch.clamp(inputs, min=0.001, max=1.0)  ##将一个张量中的数值限制在一个范围内,如限制在[0.1,1.0]范围内,可以避免一些运算错误,如预测结果q中元素可能为0
       
       if not self.onehot:

           class_mask = inputs.data.new(N, C).fill_(0)
           class_mask = Variable(class_mask)
           ids = targets.view(-1, 1)
           class_mask.scatter_(1, ids.data, 1.)
           targets = class_mask

       pos_sample_loss_matrix = -targets * inputs.log()  ## 正样本的loss
       # mean_pos_sample_loss = pos_sample_loss_matrix.sum() / targets.sum()
      
       neg_sample_loss_matrix = -(targets == 0).float() * (1 - inputs).log()  ## 负样本的loss
       # mean_neg_sample_loss = pos_sample_loss_matrix.sum() / targets.sum()

       if self.alpha:
           return (self.alpha*pos_sample_loss_matrix + (1-self.alpha)*neg_sample_loss_matrix).sum() / (N * C)
       else:
           return (pos_sample_loss_matrix + neg_sample_loss_matrix).sum() / (N * C)  



引用

  • http://kodgv.xyz/2019/04/22/%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C/FocalLoss%E9%92%88%E5%AF%B9%E4%B8%8D%E5%B9%B3%E8%A1%A1%E6%95%B0%E6%8D%AE/
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/u012925804/article/details/103032408

智能推荐

html自定义的DIV垂直滚动条-程序员宅基地

文章浏览阅读1.8w次,点赞6次,收藏18次。首先说一下自定义滚动条的一个要求:鼠标滚动在它div上滚动时,如果没有滚到顶端或底部则不能影响页面滚动条和系统自带一样让一个div拥有滚动条1、只有垂直滚动条#mydiv1{ position: relative; overflow-x: hidden; overflow-y: scroll; width: 100px;_div垂直滚动条

深入理解Java虚拟机 精华总结(面试)_深入理解java虚拟机 reentrantlock-程序员宅基地

文章浏览阅读208次。一.运行时数据区域  Java虚拟机管理的内存包括几个运行时数据内存:方法区、虚拟机栈、堆、本地方法栈、程序计数器,其中方法区和堆是由线程共享的数据区,其他几个是线程隔离的数据区。  1.1程序计数器  程序计数器是一块较小的内存,他可以看做是当前线程所执行的行号指示器。字节码解释器工作的时候就是通过改变这个计数器的值来选取下一条需要执行的字节码的指令,分支、循环、跳转、异常处理、线程恢复等..._深入理解java虚拟机 reentrantlock

Ubuntu16.04打开和关闭桌面显示_ubuntu如何关闭lightdm-程序员宅基地

文章浏览阅读4.5k次,点赞2次,收藏8次。阿里云Ubuntu16.04打开和关闭桌面显示(201910122)文章目录一、打开桌面图像化显示一、关闭桌面图像化显示在进行阿里云远程连接桌面(浏览中显示的桌面)时,通过点击左上角的的alt+ctrl+f1,然后进入到了terminal界面,之后想再进入桌面显示界面不知道怎么怎么操作啦,下面教你如何在terminal下sao操作,进入图形化界面!!!一、打开桌面图像化显示在命令行中输入..._ubuntu如何关闭lightdm

Java Excel导入导出工具类_excel导入导出工具类java-程序员宅基地

文章浏览阅读649次,点赞2次,收藏2次。Java Excel导入导出工具类第一步导入maven依赖 <dependency> <groupId>org.apache.poi</groupId> <artifactId>poi</artifactId> <version>3.14</version> </dependency> <dependency> <groupId_excel导入导出工具类java

BIO/NIO/AIO-程序员宅基地

文章浏览阅读127次。面试官:聊聊BIO、NIO、AIO我:emmm......我只知道IO虽然IO流是Java基础部分学习的内容,而且用起来也是比较简单的;但是,如果让你系统的说一下IO,还是比较困难的。这篇博客通过各方面对比,来聊一聊这几个IO。一些概念在学习Java的IO流之前,我们必须要知道一些关键词。同步与异步:(关注的是消息通信机制) 同步(Synchronous):代码按顺序执行,执行到同步方法时,不管方法有没有返回值都要执行完,才能往下执行。 异步(Asynchronou._bio/nio/aio

探索 `dot-prop`: 简洁而强大的JavaScript对象路径处理库-程序员宅基地

文章浏览阅读378次,点赞4次,收藏6次。探索 dot-prop: 简洁而强大的JavaScript对象路径处理库项目地址:https://gitcode.com/sindresorhus/dot-prop在开发JavaScript应用时,我们经常会遇到需要操作嵌套对象属性的情况。dot-prop 是一个轻量级、高效的库,专门用于处理这样的任务。它使得我们可以以简洁的方式访问、设置或删除对象路径中的属性,从而简化了代码逻辑。项目简介...

随便推点

easyuefi只能在基于uefi启动的_取代传统BIOS的EFI和UEFI究竟是什么?要如何设置?...-程序员宅基地

文章浏览阅读129次。前段时间写了一篇安装win10的详细教程。没想到还引起了不少网友的关注。如今电脑也算是人人都离不开的生产力工具。自己有个重装系统的手艺,不但可以“防身”。还可以坐等女神找你修电脑,岂不乐哉?系统安装的步骤在之前的文章中已经详细介绍了。但有一点没有在文中说清楚:就是BIOS、UEFI、MBR、GPT这四者的关系。对于小白来说不明白这些概念,你有可能对老一点电脑就束手无策了。本文就是来补足之前缺憾的。..._easy uefi

stm32学习—CAN_can nart-程序员宅基地

文章浏览阅读314次。CAN总线调试心得端口复用P113//USART1_TX PA.9 复用推挽输出GPIO_InitStructure.GPIO_Pin = GPIO_Pin_9; //PA.9GPIO_InitStructure.GPIO_Speed = GPIO_Speed_50MHz;GPIO_InitStructure.GPIO_Mode = GPIO_Mode_AF_PP; //复用推挽输出GPIO_Init(GPIOA, &GPIO_InitStructure);//USART1_RX _can nart

influxdb_influxdb可以存多少条数据-程序员宅基地

文章浏览阅读349次。MAC安装influxdbbrew updatebrew install influxdbln -sfv /usr/local/opt/influxdb/*.plist ~/Library/LaunchAgents#配置文件在/etc/influxdb/influxdb.conf ,如果没有就将/usr/local/etc/influxdb.conf 拷一个过去配置缓存:cache-max-memory-size#启动服务launchctl load ~/Library/LaunchAgent_influxdb可以存多少条数据

FPGA编程神器,亲测好用,非常舒适!_fpga助手-程序员宅基地

文章浏览阅读1.8k次。语言高亮、文件标志、定义跳转、悬停提示、工程结构、语法诊断、自动对齐、自动补全、语言翻译、状态预览、自动生成tb、快速例化、vivado快速仿真、iverilog快速仿真、支持常见功能库、vivado开发辅助、ZYNQ开发辅助_fpga助手

【opencv图像处理】 05图像操作_dst.create(src.size()src.type());-程序员宅基地

文章浏览阅读732次,点赞2次,收藏4次。代码实现 功能:显示一张图片 #include &lt;opencv2/opencv.hpp&gt; #include &lt;iostream&gt; using namespace std; using namespace cv; /***************************************..._dst.create(src.size()src.type());

【计算机毕业设计】051网上医院预约挂号系统_医院预约挂号系统毕设-程序员宅基地

文章浏览阅读708次,点赞3次,收藏8次。如今的信息时代,对信息的共享性,信息的流通性有着较高要求,因此传统管理方式就不适合。为了让医院预约挂号信息的管理模式进行升级,也为了更好的维护医院预约挂号信息,网上医院预约挂号系统的开发运用就显得很有必要。并且通过开发网上医院预约挂号系统,不仅可以让所学的SSM框架得到实际运用,也可以掌握MySQL的使用方法,对自身编程能力也有一个检验和提升的过程。尤其是通过实践,可以对系统的开发流程加深印象,无论是前期的分析与设计,还是后期的编码测试等环节,都可以有一个深刻的了解。网上医院预约挂号系统根据调研,确定管_医院预约挂号系统毕设

推荐文章

热门文章

相关标签