mmdetection自定义模型_standardroihead-程序员宅基地

技术标签: python  机器学习  深度学习  mmdetection  

我们基本上将模型组件分为 5 种类型。

  1. 骨干:通常是一个 FCN 网络来提取特征图,例如 ResNet、MobileNet。
  2. 颈部:骨干和头部之间的组件,例如FPN、PAFPN。
  3. head:特定任务的组件,例如 bbox 预测和掩码预测。
  4. roiextractor:从特征图中提取RoI特征的部分,例如RoI Align。
  5. loss:用于计算损失的 head 组件,例如FocalLoss、L1Loss 和 GHMLoss。

开发新组件

添加新的主干
这里我们以 MobileNet 为例展示如何开发新组件。

1**.定义一个新的主干(例如MobileNet)**
创建一个新文件mmdet/models/backbones/mobilenet.py。

import torch.nn as nn

from ..builder import BACKBONES


@BACKBONES.register_module()
class MobileNet(nn.Module):

    def __init__(self, arg1, arg2):
        pass

    def forward(self, x):  # should return a tuple
        pass

  1. 导入模块
    您可以将以下行添加到 mmdet/models/backbones/init.py
from .mobilenet import MobileNet

或者添加

custom_imports = dict(
    imports=['mmdet.models.backbones.mobilenet'],
    allow_failed_imports=False)

到配置文件以避免修改原始代码。

3. 在你的配置文件中使用主干
model = dict(

backbone=dict(
type=‘MobileNet’,
arg1=xxx,
arg2=xxx),

添加新的脖子

1. 定义颈部(例如 PAFPN)
创建一个新文件mmdet/models/necks/pafpn.py。

from ..builder import NECKS

@NECKS.register_module()
class PAFPN(nn.Module):

    def __init__(self,
                in_channels,
                out_channels,
                num_outs,
                start_level=0,
                end_level=-1,
                add_extra_convs=False):
        pass

    def forward(self, inputs):
        # implementation is ignored
        pass

  1. 导入模块
    您可以将以下行添加到mmdet/models/necks/init.py,
from .pafpn import PAFPN

或者添加

custom_imports = dict(
    imports=['mmdet.models.necks.pafpn.py'],
    allow_failed_imports=False)

到配置文件,避免修改原始代码。

3.修改配置文件

neck=dict(
    type='PAFPN',
    in_channels=[256, 512, 1024, 2048],
    out_channels=256,
    num_outs=5)

添加新的头

这里我们以双头 R-CNN为例展示如何开发一个新的头部,如下所示。
首先,在mmdet/models/roi_heads/bbox_heads/double_bbox_head.py. 双头 R-CNN 实现了一个新的 bbox 头用于对象检测。要实现一个bbox head,基本上我们需要实现新模块的三个功能如下。

from mmdet.models.builder import HEADS
from .bbox_head import BBoxHead

@HEADS.register_module()
class DoubleConvFCBBoxHead(BBoxHead):
    r"""Bbox head used in Double-Head R-CNN

                                      /-> cls
                  /-> shared convs ->
                                      \-> reg
    roi features
                                      /-> cls
                  \-> shared fc    ->
                                      \-> reg
    """  # noqa: W605

    def __init__(self,
                 num_convs=0,
                 num_fcs=0,
                 conv_out_channels=1024,
                 fc_out_channels=1024,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        kwargs.setdefault('with_avg_pool', True)
        super(DoubleConvFCBBoxHead, self).__init__(**kwargs)


    def forward(self, x_cls, x_reg):

其次,如有必要,实施新的 RoI Head。我们计划DoubleHeadRoIHead从StandardRoIHead. 我们可以发现aStandardRoIHead已经实现了以下功能。

import torch

from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
from ..builder import HEADS, build_head, build_roi_extractor
from .base_roi_head import BaseRoIHead
from .test_mixins import BBoxTestMixin, MaskTestMixin


@HEADS.register_module()
class StandardRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
    """Simplest base roi head including one bbox head and one mask head.
    """

    def init_assigner_sampler(self):

    def init_bbox_head(self, bbox_roi_extractor, bbox_head):

    def init_mask_head(self, mask_roi_extractor, mask_head):


    def forward_dummy(self, x, proposals):


    def forward_train(self,
                      x,
                      img_metas,
                      proposal_list,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None):

    def _bbox_forward(self, x, rois):

    def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
                            img_metas):

    def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
                            img_metas):

    def _mask_forward(self, x, rois=None, pos_inds=None, bbox_feats=None):


    def simple_test(self,
                    x,
                    proposal_list,
                    img_metas,
                    proposals=None,
                    rescale=False):
        """Test without augmentation."""


Double Head 的修改主要是在 bbox_forward 逻辑中,它继承了StandardRoIHead. 在 中mmdet/models/roi_heads/double_roi_head.py,我们实现了新的 RoI Head,如下所示:

from ..builder import HEADS
from .standard_roi_head import StandardRoIHead


@HEADS.register_module()
class DoubleHeadRoIHead(StandardRoIHead):
    """RoI head for Double Head RCNN

    https://arxiv.org/abs/1904.06493
    """

    def __init__(self, reg_roi_scale_factor, **kwargs):
        super(DoubleHeadRoIHead, self).__init__(**kwargs)
        self.reg_roi_scale_factor = reg_roi_scale_factor

    def _bbox_forward(self, x, rois):
        bbox_cls_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)
        bbox_reg_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs],
            rois,
            roi_scale_factor=self.reg_roi_scale_factor)
        if self.with_shared_head:
            bbox_cls_feats = self.shared_head(bbox_cls_feats)
            bbox_reg_feats = self.shared_head(bbox_reg_feats)
        cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats)

        bbox_results = dict(
            cls_score=cls_score,
            bbox_pred=bbox_pred,
            bbox_feats=bbox_cls_feats)
        return bbox_results

最后,用户需要添加模块 mmdet/models/bbox_heads/init.py,mmdet/models/roi_heads/init.py从而相应的注册表可以找到并加载它们。

或者,用户可以添加

custom_imports=dict(
    imports=['mmdet.models.roi_heads.double_roi_head', 'mmdet.models.bbox_heads.double_bbox_head'])

到配置文件并实现相同的目标。

双头R-CNN的配置文件如下

_base_ = '../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
model = dict(
    roi_head=dict(
        type='DoubleHeadRoIHead',
        reg_roi_scale_factor=1.3,
        bbox_head=dict(
            _delete_=True,
            type='DoubleConvFCBBoxHead',
            num_convs=4,
            num_fcs=2,
            in_channels=256,
            conv_out_channels=1024,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=2.0),
            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=2.0))))

从MMDetection 2.0开始,配置系统支持继承配置,让用户可以专注于修改。双头R-CNN主要使用了一个新的DoubleHeadRoIHead和一个新的 DoubleConvFCBBoxHead,参数根据__init__每个模块的功能设置。

添加新的损失

假设您想MyLoss为边界框回归添加一个新的损失为。要添加新的损失函数,用户需要在mmdet/models/losses/my_loss.py. 装饰weighted_loss器使损失能够为每个元素加权。

import torch
import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss

@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss

@LOSSES.register_module()
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox

然后用户需要将它添加到mmdet/models/losses/init.py.

from .my_loss import MyLoss, my_loss

或者,您可以添加

custom_imports=dict(
    imports=['mmdet.models.losses.my_loss'])

到配置文件并实现相同的目标。
要使用它,请修改该loss_xxx字段。由于 MyLoss 是用于回归的,所以需要修改loss_bboxhead 中的字段。

loss_bbox=dict(type='MyLoss', loss_weight=1.0))
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/w1520039381/article/details/121810120

智能推荐

Linux查看登录用户日志_怎么记录linux设备 发声的登录和登出-程序员宅基地

文章浏览阅读8.6k次。一、Linux记录用户登录信息文件1  /var/run/utmp----记录当前正在登录系统的用户信息;2  /var/log/wtmp----记录当前正在登录和历史登录系统的用户信息;3  /var/log/btmp:记录失败的登录尝试信息。二、命令用法1.命令last,lastb---show a listing of la_怎么记录linux设备 发声的登录和登出

第四章笔记:遍历--算法学中的万能钥匙-程序员宅基地

文章浏览阅读167次。摘要:1. 简介 2. 公园迷宫漫步 3. 无线迷宫与最短(不加权)路径问题 4. 强连通分量1. 简介在计算机科学裡,树的遍历(也称为树的搜索)是圖的遍歷的一种,指的是按照某种规则,不重复地访问某种樹的所有节点的过程。具体的访问操作可能是检查节点的值、更新节点的值等。不同的遍历方式,其访问节点的顺序是不一样的。两种著名的基本遍历策略:深度优先搜索(DFS) 和 广度优先搜索(B...

【案例分享】使用ActiveReports报表工具,在.NET MVC模式下动态创建报表_activereports.net 实现查询报表功能-程序员宅基地

文章浏览阅读591次。提起报表,大家会觉得即熟悉又陌生,好像常常在工作中使用,又似乎无法准确描述报表。今天我们来一起了解一下什么是报表,报表的结构、构成元素,以及为什么需要报表。什么是报表简单的说:报表就是通过表格、图表等形式来动态显示数据,并为使用者提供浏览、打印、导出和分析的功能,可以用公式表示为:报表 = 多样的布局 + 动态的数据 + 丰富的输出报表通常包含以下组成部分:报表首页:在报表的开..._activereports.net 实现查询报表功能

Ubuntu18.04 + GNOME xrdp + Docker + GUI_docker xrdp ubuntu-程序员宅基地

文章浏览阅读6.6k次。最近实验室需要用Cadence,这个软件的安装非常麻烦,每一次配置都要几个小时,因此打算把Cadence装进Docker。但是Cadence运行时需要GUI,要对Docker进行一些配置。我们实验室的服务器运行的是Ubuntu18.04,默认桌面GNOME,Cadence装进Centos的Docker。安装Ubuntu18.04服务器上安装Ubuntu18.04的教程非常多,在此不赘述了安装..._docker xrdp ubuntu

iOS AVFoundation实现相机功能_ios avcapturestillimageoutput 兼容性 ios17 崩溃-程序员宅基地

文章浏览阅读1.8k次,点赞2次,收藏2次。首先导入头文件#import 导入头文件后创建几个相机必须实现的对象 /** * AVCaptureSession对象来执行输入设备和输出设备之间的数据传递 */ @property (nonatomic, strong) AVCaptureSession* session; /** * 输入设备 */_ios avcapturestillimageoutput 兼容性 ios17 崩溃

Oracle动态性能视图--v$sysstat_oracle v$sysstat视图-程序员宅基地

文章浏览阅读982次。按照OracleDocument中的描述,v$sysstat存储自数据库实例运行那刻起就开始累计全实例(instance-wide)的资源使用情况。 类似于v$sesstat,该视图存储下列的统计信息:1>.事件发生次数的统计(如:user commits)2>._oracle v$sysstat视图

随便推点

Vue router报错:NavigationDuplicated {_name: "NavigationDuplicated", name: "NavigationDuplicated"}的解决方法_navigationduplicated {_name: 'navigationduplicated-程序员宅基地

文章浏览阅读7.6k次,点赞2次,收藏9次。我最近做SPA项目开发动态树的时候一直遇到以下错误:当我点击文章管理需要跳转路径时一直报NavigationDuplicated {_name: “NavigationDuplicated”, name: “NavigationDuplicated”}这个错误但是当我点击文章管理后,路径跳转却是成功的<template> <div> 文章管理页面 <..._navigationduplicated {_name: 'navigationduplicated', name: 'navigationduplic

Webrtc回声消除模式(Aecm)屏蔽舒适噪音(CNG)_webrtc aecm 杂音-程序员宅基地

文章浏览阅读3.9k次。版本VoiceEngine 4.1.0舒适噪音生成(comfort noise generator,CNG)是一个在通话过程中出现短暂静音时用来为电话通信产生背景噪声的程序。#if defined(WEBRTC_ANDROID) || defined(WEBRTC_IOS)static const EcModes kDefaultEcMode = kEcAecm;#elsestati..._webrtc aecm 杂音

医学成像原理与图像处理一:概论_医学成像与图像处理技术知识点总结-程序员宅基地

文章浏览阅读6.3k次,点赞9次,收藏19次。医学成像原理与图像处理一:概论引言:本系列博客为医学成像原理与图像处理重要笔记,由于是手写,在此通过扫描录入以图片的形式和电子版增补内容将其进行组织和共享。前半部分内容为图像处理基础内容,包括图像的灰度级处理、空间域滤波、频率域滤波、图像增强和分割等;后半部分内容为医学影象技术,包括常规胶片X光机、CR、DR、CT、DSA等X射线摄影技术、超声成像技术、磁共振成像(MRI)技术等。本篇主要内容是概论。_医学成像与图像处理技术知识点总结

notepad++ v8.5.3 安装插件,安装失败怎么处理?下载进度为0怎么处理?_nodepa++-程序员宅基地

文章浏览阅读591次,点赞13次,收藏10次。notepad++ v8.5.3 安装插件,下载进度为0_nodepa++

hive某个字段中包括\n(和换行符冲突)_hive sql \n-程序员宅基地

文章浏览阅读2.1w次。用spark执行SQL保存到Hive中: hiveContext.sql(&quot;insert overwrite table test select * from aaa&quot;)执行完成,没报错,但是核对结果的时候,发现有几笔数据超出指定范围(实际只包含100/200)最终排查到是ret_pay_remark 字段包含换行符,解决方案:执行SQL中把特殊字符替换掉regexp_replace(..._hive sql \n

印象笔记05:如何打造更美的印象笔记超级笔记_好的印象笔记怎么做的-程序员宅基地

文章浏览阅读520次,点赞10次,收藏8次。印象笔记05:如何打造更美的印象笔记超级笔记本文介绍印象笔记的具体使用,如何打造更美更实用的笔记。首先想要笔记更加好看和实用,我认为要使用超级笔记。所谓超级笔记就是具有很多便捷功能的笔记。_好的印象笔记怎么做的

推荐文章

热门文章

相关标签