yolov5损失函数的几点理解_yolov5 为什么sigmod()*2-程序员宅基地

技术标签: 重要  

yolov5损失函数的几点理解

所用代码:https://github.com/ultralytics/yolov5
参考文献:https://www.cnblogs.com/pprp/p/12590801.html
感谢知乎网友:Ancy贝贝

重要的代码块在build_targets内。

def build_targets(p, targets, model):
    # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
    det = model.module.model[-1] if is_parallel(model) else model.model[-1]  # Detect() module
    na, nt = det.na, targets.shape[0]  # number of anchors, targets
    tcls, tbox, indices, anch = [], [], [], []
    gain = torch.ones(7, device=targets.device)  # normalized to gridspace gain
    ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)
    #将targets复制3份,每份分配一个anchor编号,如0,1,2. 也就是每个anchor分配一份targets。
    targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)  # append anchor indices

    g = 0.5  # bias
    # 这里off表示了5个偏移,原点不动,往右、往下、往左、往上。
    # 其中坐标原点在图像的左上角,x轴往右(列),y轴往下(行)。
    off = torch.tensor([[0, 0],
                        [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
                        # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
                        ], device=targets.device).float() * g  # offsets

    for i in range(det.nl):
        #det.anchors在导入model的时候就除以了步长,因此此时anchor大小不是相对于原图,而是相对于对应特征层的尺寸
        anchors = det.anchors[i]
        gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]  # xyxy gain

        # Match targets to anchors
        #这里主要是将gt的cx,cy,w,h换算到当前特征层对应的尺寸,以便和该层的anchor大小相对应
        t = targets * gain
        if nt:
            # Matches
            #这个部分是计算gt和anchor的匹配程度
            #即w_gt/w_anchor  h_gt/h_anchor
            r = t[:, :, 4:6] / anchors[:, None]  # wh ratio
            #这里判断了r和1/r与model.hyp['anchor_t']的大小关系,即只有不大于这个数,也就是说gt与anchor的宽高差距不过大的时候,才认为匹配。代码中 model.hyp['anchor_t']=4
            j = torch.max(r, 1. / r).max(2)[0] < model.hyp['anchor_t']  # compare       
            # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
			#将满足条件的targets筛选出来。          
            t = t[j]  # filter
            
            # Offsets
            #这个部分就是扩充targets的数量,将比较targets附近的4个点,选取最近的2个点作为新targets中心,新targets的w、h使用与原targets一致,只是中心点坐标的不同。
            gxy = t[:, 2:4]  # grid xy
            gxi = gain[[2, 3]] - gxy  # inverse
            j, k = ((gxy % 1. < g) & (gxy > 1.)).T
            l, m = ((gxi % 1. < g) & (gxi > 1.)).T
            j = torch.stack((torch.ones_like(j), j, k, l, m))
            t = t.repeat((5, 1, 1))[j] #筛选后t的数量是原来t的3倍。
            offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
        else:
            t = targets[0]
            offsets = 0

        # Define
        b, c = t[:, :2].long().T  # image, class
        gxy = t[:, 2:4]  # grid xy
        gwh = t[:, 4:6]  # grid wh
        gij = (gxy - offsets) #自己加的代码,方便查看gij的分布。
         plot_gxy(gxy=gij, scale_i=i, size=gain, flag='gij') #自己编的代码,用于查看gij的分布。
        gij = (gxy - offsets).long() #将所有targets中心点坐标进行偏移。

        gi, gj = gij.T  # grid xy indices

        # Append
        a = t[:, 6].long()  # anchor indices
        indices.append((b, a, gj, gi))  # image, anchor, grid indices
        tbox.append(torch.cat((gxy - gij, gwh), 1))  # box
        anch.append(anchors[a])  # anchors
        tcls.append(c)  # class

    return tcls, tbox, indices, anch

下图是20x20的特征图上的gij的分布示意图,从图中可以看出每个targets都扩充了2个临近的targets。关于为什么扩充,我还没理解,有知道的网友欢迎留言。另外,知乎网友Ancy贝贝的理解是:之前通过筛选,去掉了一些匹配不上anchor的gt,本来正样本就比负样本少很多,经过筛选,少得更多了,所以每个gt扩充2个出来,增加正样本比例。
在这里插入图片描述

    # Regression
    pxy = ps[:, :2].sigmoid() * 2. - 0.5
    pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
    pbox = torch.cat((pxy, pwh), 1).to(device)  # predicted box
    giou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)  # giou(prediction, target)
    lbox += (1.0 - giou).mean()  # giou loss

在这里插入图片描述
代码中的pxy对应bxy,ps[:, :2]对应txy。由此可知bxy的取值范围是[-0.5,1.5]。因此有可能偏移到临近的单元格内,但偏移不多,不知道作者是什么考虑的。

在这里插入图片描述
代码中的pwh对应bwh,anchors[i]对应Pwh。因此可知bwh的范围是[0,4]*Pwh。这和前面
j = torch.max(r, 1. / r).max(2)[0] < model.hyp[‘anchor_t’] # model.hyp[‘anchor_t’]=4 是一致的。

Objectness

tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * giou.detach().clamp(0).type(tobj.dtype) # giou ratio
此处 tobj[b, a, gj, gi]用giou(真实的是ciou)取代1,代表该点对应置信度。为什么要用giou来代替,我也没想明白,有知道的网友欢迎留言。

其余的部分比较好理解,在此不再赘述。

附:
plot_gxy的代码:

def plot_gxy(gxy, scale_i, size, flag):
    s = int(size[2].cpu().numpy())
    ax = plt.subplot(111)
    ax.axis([0, s, 0, s])
    lxx = np.arange(0, s + 1, 1)
    lxx = np.repeat(lxx, s + 1, axis=0)
    lxx = lxx.reshape(s + 1, s + 1)
    lyy = np.arange(0, s + 1, 1)
    lyy = np.repeat(lyy, s + 1, axis=0)
    lyy = lyy.reshape(s + 1, s + 1)
    lyy = lyy.T


    for i in range(len(lxx)):
        plt.plot(lxx[i], lyy[i], color='k', linewidth=0.05, linestyle='-')
        plt.plot(lyy[i], lxx[i], color='k', linewidth=0.05, linestyle='-')

    for i in range(len(gxy)):
        x1, y1 = gxy.cpu().numpy().T
        plt.scatter(x1, y1, s=0.02, color='k')

    ax = plt.gca()  # 获取到当前坐标轴信息
    ax.xaxis.set_ticks_position('top')  # 将X坐标轴移到上面
    ax.invert_yaxis()
    plt.savefig("gxy_{}_{}.png".format(scale_i, flag))
    plt.close()
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/tpz789/article/details/108844004

智能推荐

Spring Boot 获取 bean 的 3 种方式!还有谁不会?,Java面试官_springboot2.7获取bean-程序员宅基地

文章浏览阅读1.2k次,点赞35次,收藏18次。AutowiredPostConstruct 注释用于在依赖关系注入完成之后需要执行的方法上,以执行任何初始化。此方法必须在将类放入服务之前调用。支持依赖关系注入的所有类都必须支持此注释。即使类没有请求注入任何资源,用 PostConstruct 注释的方法也必须被调用。只有一个方法可以用此注释进行注释。_springboot2.7获取bean

Logistic Regression Java程序_logisticregression java-程序员宅基地

文章浏览阅读2.1k次。理论介绍 节点定义package logistic;public class Instance { public int label; public double[] x; public Instance(){} public Instance(int label,double[] x){ this.label = label; th_logisticregression java

linux文件误删除该如何恢复?,2024年最新Linux运维开发知识点-程序员宅基地

文章浏览阅读981次,点赞21次,收藏18次。本书是获得了很多读者好评的Linux经典畅销书**《Linux从入门到精通》的第2版**。下面我们来进行文件的恢复,执行下文中的lsof命令,在其返回结果中我们可以看到test-recovery.txt (deleted)被删除了,但是其存在一个进程tail使用它,tail进程的进程编号是1535。我们看到文件名为3的文件,就是我们刚刚“误删除”的文件,所以我们使用下面的cp命令把它恢复回去。命令进入该进程的文件目录下,1535是tail进程的进程id,这个文件目录里包含了若干该进程正在打开使用的文件。

流媒体协议之RTMP详解-程序员宅基地

文章浏览阅读10w+次,点赞12次,收藏72次。RTMP(Real Time Messaging Protocol)实时消息传输协议是Adobe公司提出得一种媒体流传输协议,其提供了一个双向得通道消息服务,意图在通信端之间传递带有时间信息得视频、音频和数据消息流,其通过对不同类型得消息分配不同得优先级,进而在网传能力限制下确定各种消息得传输次序。_rtmp

微型计算机2017年12月下,2017年12月计算机一级MSOffice考试习题(二)-程序员宅基地

文章浏览阅读64次。2017年12月的计算机等级考试将要来临!出国留学网为考生们整理了2017年12月计算机一级MSOffice考试习题,希望能帮到大家,想了解更多计算机等级考试消息,请关注我们,我们会第一时间更新。2017年12月计算机一级MSOffice考试习题(二)一、单选题1). 计算机最主要的工作特点是( )。A.存储程序与自动控制B.高速度与高精度C.可靠性与可用性D.有记忆能力正确答案:A答案解析:计算...

20210415web渗透学习之Mysqludf提权(二)(胃肠炎住院期间转)_the provided input file '/usr/share/metasploit-fra-程序员宅基地

文章浏览阅读356次。在学MYSQL的时候刚刚好看到了这个提权,很久之前用过别人现成的,但是一直时间没去细想, 这次就自己复现学习下。 0x00 UDF 什么是UDF? UDF (user defined function),即用户自定义函数。是通过添加新函数,对MySQL的功能进行扩充,就像使..._the provided input file '/usr/share/metasploit-framework/data/exploits/mysql

随便推点

webService详细-程序员宅基地

文章浏览阅读3.1w次,点赞71次,收藏485次。webService一 WebService概述1.1 WebService是什么WebService是一种跨编程语言和跨操作系统平台的远程调用技术。Web service是一个平台独立的,低耦合的,自包含的、基于可编程的web的应用程序,可使用开放的XML(标准通用标记语言下的一个子集)标准...

Retrofit(2.0)入门小错误 -- Could not locate ResponseBody xxx Tried: * retrofit.BuiltInConverters_已添加addconverterfactory 但是 could not locate respons-程序员宅基地

文章浏览阅读1w次。前言照例给出官网:Retrofit官网其实大家学习的时候,完全可以按照官网Introduction,自己写一个例子来运行。但是百密一疏,官网可能忘记添加了一句非常重要的话,导致你可能出现如下错误:Could not locate ResponseBody converter错误信息:Caused by: java.lang.IllegalArgumentException: Could not l_已添加addconverterfactory 但是 could not locate responsebody converter

一套键鼠控制Windows+Linux——Synergy在Windows10和Ubuntu18.04共控的实践_linux 18.04 synergy-程序员宅基地

文章浏览阅读1k次。一套键鼠控制Windows+Linux——Synergy在Windows10和Ubuntu18.04共控的实践Synergy简介准备工作(重要)Windows服务端配置Ubuntu客户端配置配置开机启动Synergy简介Synergy能够通过IP地址实现一套键鼠对多系统、多终端进行控制,免去了对不同终端操作时频繁切换键鼠的麻烦,可跨平台使用,拥有Linux、MacOS、Windows多个版本。Synergy应用分服务端和客户端,服务端即主控端,Synergy会共享连接服务端的键鼠给客户端终端使用。本文_linux 18.04 synergy

nacos集成seata1.4.0注意事项_seata1.4.0 +nacos 集成-程序员宅基地

文章浏览阅读374次。写demo的时候遇到了很多问题,记录一下。安装nacos1.4.0配置mysql数据库,新建nacos_config数据库,并根据初始化脚本新建表,使配置从数据库读取,可单机模式启动也可以集群模式启动,启动时 ./start.sh -m standaloneapplication.properties 主要是db部分配置## Copyright 1999-2018 Alibaba Group Holding Ltd.## Licensed under the Apache License,_seata1.4.0 +nacos 集成

iperf3常用_iperf客户端指定ip地址-程序员宅基地

文章浏览阅读833次。iperf使用方法详解 iperf3是一款带宽测试工具,它支持调节各种参数,比如通信协议,数据包个数,发送持续时间,测试完会报告网络带宽,丢包率和其他参数。 安装 sudo apt-get install iperf3 iPerf3常用的参数: -c :指定客户端模式。例如:iperf3 -c 192.168.1.100。这将使用客户端模式连接到IP地址为192.16..._iperf客户端指定ip地址

浮点性(float)转化为字符串类型 自定义实现和深入探讨C++内部实现方法_c++浮点数 转 字符串 精度损失最小-程序员宅基地

文章浏览阅读7.4k次。 写这个函数目的不是为了和C/C++库中的函数在性能和安全性上一比高低,只是为了给那些喜欢探讨函数内部实现的网友,提供一种从浮点性到字符串转换的一种途径。 浮点数是有精度限制的,所以即使我们在使用C/C++中的sprintf或者cout 限制,当然这个精度限制是可以修改的。比方在C++中,我们可以cout.precision(10),不过这样设置的整个输出字符长度为10,而不是特定的小数点后1_c++浮点数 转 字符串 精度损失最小

推荐文章

热门文章

相关标签