YOLOX改进之添加ASFF_asff代码_你的陈某某的博客-程序员秘密

技术标签: YOLOX  计算机视觉  目标检测  人工智能  pytorch  

文章内容:如何在YOLOX官网代码中添加ASFF模块

环境:pytorch1.8

修改内容

(1)在PAFPN尾部添加ASFF模块(YOLOX-s等版本)

(2)在FPN尾部添加ASFF模块(YOLOX-Darknet53版本)

参考链接

论文链接https://arxiv.org/pdf/1911.09516v2.pdf

ASFF原理及代码参考https://blog.csdn.net/weixin_44119362/article/details/114289607

示意图如下
在这里插入图片描述

使用方法:直接在PAFPN或FPN尾部添加即可(可自动进行维度匹配,不需要修改)

代码修改过程

1、在YOLOXS版本的PAFPN后添加ASFF模块

(注意:这里是PAFPN该版本用于YOLOv5版的PAFPN中,不能用于YOLOv3的FPN)

步骤一:在YOLOX-main/yolox/models文件夹下创建ASFF.py文件,内容如下:

import torch
import torch.nn as nn
import torch.nn.functional as F


def autopad(k, p=None):  # kernel, padding
    # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        return self.act(self.conv(x))


class ASFF(nn.Module):
    def __init__(self, level, multiplier=1, rfb=False, vis=False, act_cfg=True):
        """
        multiplier should be 1, 0.5
        which means, the channel of ASFF can be 
        512, 256, 128 -> multiplier=0.5
        1024, 512, 256 -> multiplier=1
        For even smaller, you need change code manually.
        """
        super(ASFF, self).__init__()
        self.level = level
        self.dim = [int(1024*multiplier), int(512*multiplier),
                    int(256*multiplier)]
        # print(self.dim)
        
        self.inter_dim = self.dim[self.level]
        if level == 0:
            self.stride_level_1 = Conv(int(512*multiplier), self.inter_dim, 3, 2)
                
            self.stride_level_2 = Conv(int(256*multiplier), self.inter_dim, 3, 2)
                
            self.expand = Conv(self.inter_dim, int(
                1024*multiplier), 3, 1)
        elif level == 1:
            self.compress_level_0 = Conv(
                int(1024*multiplier), self.inter_dim, 1, 1)
            self.stride_level_2 = Conv(
                int(256*multiplier), self.inter_dim, 3, 2)
            self.expand = Conv(self.inter_dim, int(512*multiplier), 3, 1)
        elif level == 2:
            self.compress_level_0 = Conv(
                int(1024*multiplier), self.inter_dim, 1, 1)
            self.compress_level_1 = Conv(
                int(512*multiplier), self.inter_dim, 1, 1)
            self.expand = Conv(self.inter_dim, int(
                256*multiplier), 3, 1)

        # when adding rfb, we use half number of channels to save memory
        compress_c = 8 if rfb else 16
        self.weight_level_0 = Conv(
            self.inter_dim, compress_c, 1, 1)
        self.weight_level_1 = Conv(
            self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = Conv(
            self.inter_dim, compress_c, 1, 1)

        self.weight_levels = Conv(
            compress_c*3, 3, 1, 1)
        self.vis = vis

    def forward(self, x): #l,m,s
        """
        # 
        256, 512, 1024
        from small -> large
        """
        x_level_0=x[2] #最大特征层
        x_level_1=x[1] #中间特征层
        x_level_2=x[0] #最小特征层

        if self.level == 0:
            level_0_resized = x_level_0
            level_1_resized = self.stride_level_1(x_level_1)
            level_2_downsampled_inter = F.max_pool2d(
                x_level_2, 3, stride=2, padding=1)
            level_2_resized = self.stride_level_2(level_2_downsampled_inter)
        elif self.level == 1:
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(
                level_0_compressed, scale_factor=2, mode='nearest')
            level_1_resized = x_level_1
            level_2_resized = self.stride_level_2(x_level_2)
        elif self.level == 2:
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(
                level_0_compressed, scale_factor=4, mode='nearest')
            x_level_1_compressed = self.compress_level_1(x_level_1)
            level_1_resized = F.interpolate(
                x_level_1_compressed, scale_factor=2, mode='nearest')
            level_2_resized = x_level_2

        level_0_weight_v = self.weight_level_0(level_0_resized)
        level_1_weight_v = self.weight_level_1(level_1_resized)
        level_2_weight_v = self.weight_level_2(level_2_resized)

        levels_weight_v = torch.cat(
            (level_0_weight_v, level_1_weight_v, level_2_weight_v), 1)
        levels_weight = self.weight_levels(levels_weight_v)
        levels_weight = F.softmax(levels_weight, dim=1)

        fused_out_reduced = level_0_resized * levels_weight[:, 0:1, :, :] +\
            level_1_resized * levels_weight[:, 1:2, :, :] +\
            level_2_resized * levels_weight[:, 2:, :, :]

        out = self.expand(fused_out_reduced)

        if self.vis:
            return out, levels_weight, fused_out_reduced.sum(dim=1)
        else:
            return out

步骤二:在YOLOX-main/yolox/models/yolo_pafpn.py中调用ASFF模块

(1)导入

from .ASFF import ASFF

(2)在init中实例化

        # ############ 2、实例化ASFF
        self.asff_1 = ASFF(level = 0, multiplier = width)
        self.asff_2 = ASFF(level = 1, multiplier = width)
        self.asff_3 = ASFF(level = 2, multiplier = width)

    def forward(self, input):

(3)直接在PAFPN输出outputs后接上ASFF模块

        outputs = (pan_out2, pan_out1, pan_out0)

        # asff
        pan_out0 = self.asff_1(outputs)
        pan_out1 = self.asff_2(outputs)
        pan_out2 = self.asff_3(outputs)
        outputs = (pan_out2, pan_out1, pan_out0)
        
        return outputs

2、在YOLOX-Darknet53FPN后添加ASFF模块

(注意:这里是用于YOLOv3的FPN)

步骤一:在YOLOX-main/yolox/models文件夹下创建ASFF_darknet.py文件,内容如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
from .network_blocks import BaseConv

# 输入的是从FPN得到的特征
# 输出是要给head的特征,与之前不变,注意单个ASFF只能输出一个特征,level=0对应最底层的特征,这里是512*20*20,尺度大小 level_0 < level_1 < level_2

class ASFF(nn.Module):
    def __init__(self, level, rfb=False, vis=False):
        super(ASFF, self).__init__()
        self.level = level
        self.dim = [512, 256, 128]
        self.inter_dim = self.dim[self.level]
        if level==0:
            self.stride_level_1 = self._make_cbl(256, self.inter_dim, 3, 2)
            self.stride_level_2 = self._make_cbl(128, self.inter_dim, 3, 2)
            self.expand = self._make_cbl(self.inter_dim, 512, 3, 1)  # 输出是要给head的特征,与之前不变512-512
        elif level==1:
            self.compress_level_0 = self._make_cbl(512, self.inter_dim, 1, 1)
            self.stride_level_2 = self._make_cbl(128, self.inter_dim, 3, 2)
            self.expand = self._make_cbl(self.inter_dim, 256, 3, 1)  # 输出是要给head的特征,与之前不变256-256
        elif level==2:
            self.compress_level_0 = self._make_cbl(512, self.inter_dim, 1, 1)
            self.compress_level_1 = self._make_cbl(256, self.inter_dim, 1, 1)
            self.expand = self._make_cbl(self.inter_dim, 128, 3, 1)  # 输出是要给head的特征,与之前不变128-128

        compress_c = 8 if rfb else 16  #when adding rfb, we use half number of channels to save memory

        self.weight_level_0 = self._make_cbl(self.inter_dim, compress_c, 1, 1)
        self.weight_level_1 = self._make_cbl(self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = self._make_cbl(self.inter_dim, compress_c, 1, 1)

        self.weight_levels = nn.Conv2d(compress_c*3, 3, kernel_size=1, stride=1, padding=0)
        self.vis= vis


    def _make_cbl(self, _in, _out, ks, stride):
        return BaseConv(_in, _out, ks, stride, act="lrelu")

    def forward(self, x_level_0, x_level_1, x_level_2):   # 输入3个维度(512*20*20,256*40*40,128*80*80),输出也是
        if self.level==0:
            level_0_resized = x_level_0  # (512*20*20)
            level_1_resized = self.stride_level_1(x_level_1)  # (256*40*40->512*20*20)

            level_2_downsampled_inter =F.max_pool2d(x_level_2, 3, stride=2, padding=1)  # (128*80*80->128*40*40)
            level_2_resized = self.stride_level_2(level_2_downsampled_inter)  # (128*40*40->512*20*20)

        elif self.level==1:
            level_0_compressed = self.compress_level_0(x_level_0)  # (512*20*20->256*20*20)
            level_0_resized =F.interpolate(level_0_compressed, scale_factor=2, mode='nearest')  # (256*20*20->256*40*40)
            level_1_resized =x_level_1  # (256*40*40)
            level_2_resized =self.stride_level_2(x_level_2)  # (128*80*80->256*40*40)
        elif self.level==2:
            level_0_compressed = self.compress_level_0(x_level_0)  # (512*20*20->128*20*20)
            level_0_resized =F.interpolate(level_0_compressed, scale_factor=4, mode='nearest')  # (128*20*20->128*80*80)
            level_1_compressed = self.compress_level_1(x_level_1)  # (256*40*40->128*40*40)
            level_1_resized =F.interpolate(level_1_compressed, scale_factor=2, mode='nearest')  # (128*40*40->128*80*80)
            level_2_resized =x_level_2  # (128*80*80)

        level_0_weight_v = self.weight_level_0(level_0_resized)  # 
        level_1_weight_v = self.weight_level_1(level_1_resized)
        level_2_weight_v = self.weight_level_2(level_2_resized)
        levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v),1)
        levels_weight = self.weight_levels(levels_weight_v)
        levels_weight = F.softmax(levels_weight, dim=1)

        fused_out_reduced = level_0_resized * levels_weight[:,0:1,:,:]+\
                            level_1_resized * levels_weight[:,1:2,:,:]+\
                            level_2_resized * levels_weight[:,2:,:,:]

        out = self.expand(fused_out_reduced)

        if self.vis:
            return out, levels_weight, fused_out_reduced.sum(dim=1)
        else:
            return out

步骤二:在YOLOX-main/yolox/models/yolo_fpn.py中调用ASFF模块

(1)导入

from .ASFF_darknet import ASFF  ### 1、导入

(2)实例化ASFF对象

        #######################  2、实例化ASFF
        self.assf_5 = ASFF(level = 0)
        self.assf_4 = ASFF(level = 1)
        self.assf_3 = ASFF(level = 2)
        ########################

    def _make_cbl(self, _in, _out, ks):
        return BaseConv(_in, _out, ks, stride=1, act="lrelu")

(3)在outputs后直接添加asff

        outputs = (out_dark3, out_dark4, x0)  # 特征图尺度逐渐变小(128,256,512)  ### 该行为初始的FPN输出,使用ASFF则注释掉

        ################################################
        # 3、对FPN特征金字塔进行ASFF操作,注释掉原FPN输出outpus
        
        out_assf_5 = self.assf_5(x0, out_dark4, out_dark3)
        out_assf_4 = self.assf_4(x0, out_dark4, out_dark3)
        out_assf_3 = self.assf_3(x0, out_dark4, out_dark3)

        outputs = (out_assf_3, out_assf_4, out_assf_5)
        #################################################
        return outputs

效果:根据个人数据集而定。对我的数据集没变化。

权重大小变化:yoloxs(68.8M->110M)

速度变化:有所下降

上述代码链接
链接:https://pan.baidu.com/s/1ykfb-YHpJaLj4sQpMsCIKw
提取码:qrvg

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_45679938/article/details/122354725

智能推荐

Hit Refresh读书摘要_hit fresh_yushulx的博客-程序员秘密

Hit Refresh - 刷新:重新发现商业与未来从海得拉巴到雷德蒙德在之前的一个世代,微软可能只有一个竞争对手,即IBM(国际商业机器公司)。但在遥遥领先所有对手多年之后,情况发生了变化,然而并不是朝着更好的方向发展——创新被官僚主义所取代,团队协作被内部政治所取代。我们落后了。在2014年2月被任命为微软第三任首席执行官时,我对公司员工表示,重塑企业文化将是我的首要任务。我告诉他们,我将不遗余

东南大学信号与系统matlab实践,信号与系统MATLAB实践_同步齿科-方的博客-程序员秘密

本书以信号与系统、数字信号处理等课程相关内容作为应用背景,结合MATLAB工具,介绍如何应用计算机技术解决工程实践中遇到的问题。全书共分八章。第一章为MATLAB简介,包括基本计算、作图语句及其系统帮助的使用。第二、三、四、五章则分别以信号分析、系统分析以及系统对信号响应分析等信号与系统课程相关内容为主线,介绍了如何用MATLAB解决相关的问题。第六章结合通信中的调制解调以及取样定理,应用计算机...

计算机网络常用通讯方式,通信方式_weixin_39822673的博客-程序员秘密

对于点对点之间的通信,按照消息传送的方向与时间关系,通信方式可分为单工通信、半双工通信及全双工通信三种。中文名通信方式外文名Communication学科领域工程技术通信方式单工通信编辑语音单工通信(Simplex Communication)是指消息只能单方向传输的工作方式。在单工通信中,通信的信道是单向的,发送端与接收端也是固定的,即发送端只能发送信息,不能接收信息;接收端只能...

PAT甲级真题-1007 Maximum Subsequence Sum详解优化_pat甲级1007测试点2,6_高冷小伙的博客-程序员秘密

1007 Maximum Subsequence Sum题目链接https://pintia.cn/problem-sets/994805342720868352/problems/994805514284679168解题思路1.一开始的思路好像是叫“在线处理法”;但这种方法能过大部分是因为题目规定了全为负值时取0的特殊情况;而且,无法过值为非正数的情况;for(int i=1;i&lt;=n;i++){ if(sum+Num[i]&lt;0){ sum=0; lnum=0;

php工程师如何进行区块链以太坊的开发_weixin_33979745的博客-程序员秘密

以太坊是备受关注的区块链,它基于密码学技术和P2P通信技术 构建了一个去中心化的平台,所有的交易同步保存在每个节点中, 通过将区块单向级联成链,以太坊有效的保证了交易的不可篡改:智能合约平台以太坊是第一个实现了虚拟机的区块链,因此为智能合约 - Smart Contract - 的运行提供了良好的支持环境。也正因为这个原因,以太坊被称为区块链 2.0,以区别于比特币代表的以数字加密货币为核心...

剑指Offer系列编程题详解全集_从流域到海域的博客-程序员秘密

剑指Offer系列是一本国内互联网公司计算机、软件、测试、运维等方向招聘笔试及面试经常会考的编程题合集,一共67道题,其中部分题目与LeetCode上的题目相一致,题目的难易度比较适中,有同名图书出版。该博客Github链接指向的是牛客网的剑指Offer系列编程题解法的解法repository。牛客网的剑指Offer和原书相比,只是题目顺序不一样,其余都一样。博主放在Github的代码也全都加了...

随便推点

Linux那些事儿之我是U盘(46)迷雾重重的Bulk传输(四)_iteye_18509的博客-程序员秘密

在讲数据传输阶段之前,先解决刚才的历史遗留问题. usb_stor_bulk_transfer_buf()中,429行,有一个很有趣的函数interpret_urb_result()被调用.这个函数同样来自drivers/usb/storage/transport.c:277 /* 278 * Interpret the results of a URB transfer 279 * 280...

vs2012中的sdf文件出错问题_vs2012 sdf数据库_smallmebigdream的博客-程序员秘密

不知道是什么时候出现的这个错误,一直没有去管它,打开一个以前建立的工程都会出现这个错误。而且很多的查找代码功能没有了,使得vs用起来十分的不方便。这里在网上找了一种解决方法解决了。情景再现:1.创建工程时创建c++浏览数据库文件xxx.sdf时发生错误Intellisense和浏览信息将不能用于c++项目2.打开工程时创建或打

【算法导论-36】并查集(Disjoint Set)具体解释_aoe41606的博客-程序员秘密

WiKiDisjoint是“不相交”的意思。Disjoint Set高效地支持集合的合并(Union)和集合内元素的查找(Find)两种操作,所以Disjoint Set中文翻译为并查集。 就《算法导论》21章来讲,主要设计这几个知识点:  用并查集计算图的连通区域;...

android dimensions.xml,Android的布局XML文件应该从styles.xml_weixin_39867142的博客-程序员秘密

派生视图的所有属性这是我layout.xml派生从styles.xml所有视图参数Android的布局XML文件应该从styles.xmlandroid:layout_width="match_parent"style="@style/DataContainer"&gt;android:id="@id/df_flight_logo"style="@style/FlightLOGO"/&gt;and...

Memory内存种类大全与简介_memory的分类_Ryan_瑞安的博客-程序员秘密

(原文:https://tech.hqew.com/news_1538989)根据组成元件的不同,RAM内存又分为以下十八种:  01.DRAM(Dynamic RAM,动态随机存取存储器):  这是最普通的RAM,一个电子管与一个电容器组成一个位存储单元,DRAM将每个内存位作为一个电荷保存在位存储单元中,用电容的充放电来做储存动作,但因电容本身有漏电问题,因此必须每几微秒就要刷新一...

记不住的坑(二)Element UI中upload组件上传docx文件缩略图不显示的问题_element 文件缩略图_画龍丶的博客-程序员秘密

   前言:最近写项目的时候碰到了upload上传以及回显的问题,这个项目比较复杂一个破upload组件都拆成三个组件来写,上传回显的时候带着一堆数据三层组件到处跑写起来属实头大。(个人认为套两层不就ok了嘛- -)图片上传完毕有图片回显,如图。可以显示的原因为他的url就是可以显示的。反过来我们看docx文件上传完之后,因为docx的url是一个下载地址,所以不支持回显。  解决方法:既然他url不支持回显,那我们就换一个固定的url进行绑定。首先找个图,把url引入到文件里。利用监听的方法改变

推荐文章

热门文章

相关标签