YOLO 系列损失函数理解_weixin_30633507的博客-程序员秘密

技术标签: 人工智能  

YOLO V1损失函数理解:

                                                                                 (结构图)

首先是理论部分,YOLO网络的实现这里就不赘述,这里主要解析YOLO损失函数这一部分。
在这里插入图片描述

损失函数分为三个部分:

代表cell中含有真实物体的中心。 pr(object) = 1

① 坐标误差

 为什么宽和高要带根号???

对不同大小的bbox预测中,相比于大bbox预测偏一点,小box预测偏一点更不能忍受。作者用了一个比较取巧的办法,就是将box的width和height取平方根代替原本的height和width

(主要为了平衡小目标检测预测的偏移)

② IOU误差(很多人不知道代表什么

其实这里的分别表示 1  和 0       =   

③ 分类误差

这个很容易理解(激活函数的输出)。

 

下面给出TensorFlow的Loss代码:

 1     def loss_layer(self,predicts,labels,scope='loss'):
 2         ''' predicts的shape是[batch,7*7*(20+5*2)]
 3             labels的shape是[batch,7,7,(5+20)]
 4             '''
 5         with tf.variable_scope(scope):
 6             #预测种类,boxes置信度,boxes坐标[x_center,y_center,w,h],坐标都除以image_size归一化,中心点坐标为偏移量,
 7             #w,h归一化后又开方,目的是使变化更平缓
 8             predict_classes=tf.reshape(predicts[:,:self.boundary1],
 9                                       [self.batch_size,self.cell_size,self.cell_size,self.num_classes])
10             predict_scales=tf.reshape(predicts[:,self.boundary1:self.boundary2],
11                                      [self.batch_size,self.cell_size,self.cell_size,self.box_per_cell])
12             predict_boxes=tf.reshape(predicts[:,self.boundary2:],
13                                     [self.batch_size,self.cell_size,self.cell_size,self.box_per_cell,4])
14             #是否有目标的置信度
15             response=tf.reshape(labels[:,:,:,0],
16                                [self.batch_size,self.cell_size,self.cell_size,1])
17             #boxes坐标处理变成[batch,7,7,2,4],两个box最终只选一个最高的,为了使预测更准确
18             boxes=tf.reshape(labels[:,:,:,1:5],
19                             [self.batch_size,self.cell_size,self.cell_size,1,4])
20             boxes=tf.tile(boxes,[1,1,1,self.box_per_cell,1])/self.image_size
21             classes=labels[:,:,:,5:]
22             #offset形如[[[0,0],[1,1]...[6,6]],[[0,0]...[6,6]]...]与偏移量x相加
23             #offset转置形如[[0,0,[0,0]...],[[1,1],[1,1]...],[[6,6]...]]与偏移量y相加
24             #组成中心点坐标shpe[batch,7,7,2]是归一化后的值
25             offset=tf.constant(self.offset,dtype=tf.float32)
26             offset=tf.reshape(offset,[1,self.cell_size,self.cell_size,self.box_per_cell])
27             offset=tf.tile(offset,[self.batch_size,1,1,1])
28             
29             predict_boxes_tran=tf.stack([(predict_boxes[:,:,:,:,0]+offset)/self.cell_size,
30                                        (predict_boxes[:,:,:,:,1]+tf.transpose(offset,(0,2,1,3)))/self.cell_size,
31                                         tf.square(predict_boxes[:,:,:,:,2]),
32                                          tf.square(predict_boxes[:,:,:,:,3])],axis=-1)
33             #iou的shape是[batch,7,7,2]
34             iou_predict_truth=self.cal_iou(predict_boxes_tran,boxes)
35             #两个预选框中iou最大的
36             object_mask=tf.reduce_max(iou_predict_truth,3,keep_dims=True)
37             #真实图中有预选框,并且值在两个预选框中最大的遮罩
38             object_mask=tf.cast((iou_predict_truth>=object_mask),tf.float32)*response
39             #无预选框遮罩
40             noobject_mask=tf.ones_like(object_mask,dtype=tf.float32)-object_mask
41             #真实boxes的偏移量
42             boxes_tran=tf.stack([boxes[:,:,:,:,0]*self.cell_size-offset,
43                                 boxes[:,:,:,:,1]*self.cell_size-tf.transpose(offset,(0,2,1,3)),
44                                 tf.sqrt(boxes[:,:,:,:,2]),
45                                 tf.sqrt(boxes[:,:,:,:,3])],axis=-1)
#=================================================================================================================================
46 #分类损失 47 class_delta=response*(predict_classes-classes) 48 class_loss=tf.reduce_mean(tf.reduce_sum(tf.square(class_delta),axis=[1,2,3]),name='clss_loss')*self.class_scale 49 #有目标损失(IOU) 50 object_delta=object_mask*(predict_scales-iou_predict_truth) #这里iou_predict_truth应该为1 51 object_loss=tf.reduce_mean(tf.reduce_sum(tf.square(object_delta),axis=[1,2,3]),name='object_loss')*self.object_scale 52 #无目标损失(IOU) 53 noobject_delta=noobject_mask*predict_scales #这里减0 54 noobject_loss=tf.reduce_mean(tf.reduce_sum(tf.square(noobject_delta),axis=[1,2,3]),name='noobject_loss')*self.no_object_scale 55 #选框损失(坐标) 56 coord_mask=tf.expand_dims(object_mask,4) 57 boxes_delta=coord_mask*(predict_boxes-boxes_tran) 58 coord_loss=tf.reduce_mean(tf.reduce_sum(tf.square(boxes_delta),axis=[1,2,3,4]),name='coord_loss')*self.coord_scale 59 tf.losses.add_loss(class_loss) 60 tf.losses.add_loss(object_loss) 61 tf.losses.add_loss(noobject_loss) 62 tf.losses.add_loss(coord_loss)

 YOLO V2:

 

YOLO V3:

 

YOLOv3不使用Softmax对每个框进行分类,而使用多个logistic分类器,因为Softmax不适用于多标签分类,用独立的多个logistic分类器准确率也不会下降。

分类损失采用binary cross-entropy loss.

转载于:https://www.cnblogs.com/WSX1994/p/11226012.html

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_30633507/article/details/99817570

智能推荐

嵌入式入门级学习——国产平台T3开发板测试教程(1)_Tronlong创龙的博客-程序员秘密_国产嵌入式开发

本文主要为嵌入式入门开发者的接口、网口等板卡基础快速测试,当初级学习的开发者拿到板卡,如何在最快时间内,测试这个板卡的基础性能、功能是正常的,就让我们下面看看。该篇文章主要提供基于创龙科技TLT3-EVM评估板的硬件资源测试方法。无特殊说明情况下,默认使用USB TO UART0作为调试串口,使用Linux系统启动卡(Micro SD方式)启动系统,通过路由器与PC机进行网络连接。Linux系统启动卡对应的设备节点为mmcblk1,eMMC对应的设备节点为mmcblk0。本指导文档适用开发环境:

使用vue-pdf预览合同的pdf文件时,不显示签章_小小白号的博客-程序员秘密_vue-pdf不显示盖章

预览pdf不显示合同上的盖章和签字:1. 在node_modules文件夹中搜索pdf.worker.js文件2. 注释掉_this.setFlages(_util.AnnotationFlag.HIDDEN); 在31690行

Maven项目Module互相调用找不到Jar包(即使已经打包和添加依赖)的解决方案_米菲尔Miffeel的博客-程序员秘密

1.在被依赖的模块的pom.xml文件中添加以下内容:<build> <plugins> <plugin> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-maven-plugin</artifa...

EXT学习总结_#菜鸟架构师文标#的博客-程序员秘密

ext概述:    ExtJs初期仅是对Yahoo! UI的对话框扩展,后来逐渐有了自己的特色,深受网友的喜爱。 发展至今, Ext除YUI外还支持Jquery、Prototype等的多种JS底层库,让大家自由地选择。该框架完全基于纯Html/CSS+JS技术,提供丰富的跨浏览器UI组件,灵活采用JSON/XML数据源开发,使得服务端表示层的负荷真正减轻,从而达到客户端的MVC应用!E

Opengl鼠标交互函数glutMouseFunc()函数介绍_量化橙同学的博客-程序员秘密_glutmousefunc

检测鼠标单击要想在OpenGL中处理鼠标事件非常的方便,GLUT已经为我们的注册好了函数,只要我们提供一个方法。使用函数glutMouseFunc,就可以帮我们注册我们的函数,这样当发生鼠标事件时就会自动调用我们的方法。函数的原型是:void glutMouseFunc(void(*func)(int button,int state,int x,int y));参数:fu

微信 H5 音乐项目总结_weixin_34221332的博客-程序员秘密

H5 音乐项目总结刚刚完成了一个 H5 项目,途中使用 audio 的时候遇到不少坑,所以写篇项目总结。项目需求要经过微信授权才能进入。所以只能在微信打开。流程:开场有个小的过渡效果,有 bgm接着连续两张图片显示,有各自的 bgm第二张图片,有文字,文字的显示要有打字的效果,附带 bgm主场面拥有各个小物品,...

随便推点

迁移学习笔记3: TCA, Finetune, 与Triplet Network(元学习)_lagoon_lala的博客-程序员秘密

MotivationTCA, Finetune, Triplet NetworkTCAFinetuneTriplet Network方法对比总结迁移学习基于特征的迁移学习方法 (Feature based)基于模型的迁移学习方法 (Model based)元学习(multi-task)元学习(Meta Learning)与迁移学习(Transfer Learning)的区别联系元学习特点元学习种类

基于SkyWalking实现对k8s集群中微服务的链路追踪分析_最美dee时光的博客-程序员秘密_k8s链路追踪

基于SkyWalking实现对k8s集群中微服务的链路追踪分析背景需求:SkyWalking介绍:实现方式:step1:制作SkyWalking Agent镜像1、准备文件:2、编写skywalking agent的dockerfile文件3、构建镜像step2:编写java服务接入skywalking agent的yamlstep3:执行yaml文件step4:效果图背景需求:由于我司之前的服务都是部署在ECS中,对于java微服务的实时链路分析是基于skytwalking agent来做监控的,但是

ant design vue 2.2.8 表格table 序号添加时发现的customRender函数的参数问题_Vinzune的博客-程序员秘密

按照官网的描述写的代码{ title: "序号", dataIndex: "index", key: "index", align: "center", width: 70, customRender: (text: any,records:any,index:number) => `${index + 1}`, },这里官网的描述是:实际通过log打印却发现函数只有一个参数,是一个完整的对象包含了 text、record、,index,c

validation参数检验 - 如何使用_赵丙双的博客-程序员秘密_validation

文章目录Maven 依赖Spring MVC Controller 的输入验证 Path Variables、 Request Parameters、Request Header验证 RequestBody非 Controller 组件的方法自定义 Validator自定义一个验证需要的注解自定义 Constraint 注解注意点自定义一个 Validator定义一个 POJO 进行验证以纯代码方式使用 Validator,不依赖 Spring 的 `@Validate` 注解纯代码方式Spring 的非注

移动端适配(2)——viewport适配_powerx_yc的博客-程序员秘密

通过viewport来适配<script>(function(){  var w=window.screen.width;  console.log(w);//获取屏幕尺寸  var targetW=320;//之后所有的都是按照320来做  var scale=w/targetW;//缩放值:当前尺寸/目标尺寸  var meta=document.cre...

EXT JS的优点_沙子揉碎在眼睛里的博客-程序员秘密_ext js

跨浏览器支持只要你做 web 开发,你一定解决过浏览器兼容问题,这有多么消耗时间和影响工作进度你一定有体会,你可能会花上几小时甚至几天来解决跨浏览器 bug。你为什么不把这些时间用来关注你的业务功能呢?相反,如果使用 Ext JS 这个 JavaScript 框架,这种事情将由它去考虑,而你可以专注于业务功能。丰富的 UI 组件Ext JS 提供了大量丰富的 UI 控件,如 data grid ,...

推荐文章

热门文章

相关标签