windows swin transformer训练自己的目标检测数据集_# do not use mmdet version fp16 fp16 = none-程序员宅基地

技术标签: 算法  python  深度学习  模型训练  pytorch  

主要是有几个地方的文件要修改一下

config/swin下的配置文件,我用的是mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py

img_scale主要是内存不够要改小

_base_ = [
    '../_base_/models/mask_rcnn_swin_fpn.py',
    # '../_base_/datasets/coco_instance.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

model = dict(
    backbone=dict(
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        ape=False,
        drop_path_rate=0.1,
        patch_norm=True,
        use_checkpoint=False
    ),
    neck=dict(in_channels=[96, 192, 384, 768]))

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=False),
    # dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='AutoAugment',
         policies=[
             [
                 dict(type='Resize',
                      # img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
                      #            (608, 1333), (640, 1333), (672, 1333), (704, 1333),
                      #            (736, 1333), (768, 1333), (800, 1333)],
                      img_scale=[(224,224)],
                      multiscale_mode='value',
                      keep_ratio=True)
             ],
             [
                 dict(type='Resize',
                      # img_scale=[(400, 1333), (500, 1333), (600, 1333)],
                      img_scale=[(224, 224)],
                      multiscale_mode='value',
                      keep_ratio=True),
                 dict(type='RandomCrop',
                      crop_type='absolute_range',
                      crop_size=(384, 600),
                      allow_negative_crop=True),
                 dict(type='Resize',
                      # img_scale=[(480, 1333), (512, 1333), (544, 1333),
                      #            (576, 1333), (608, 1333), (640, 1333),
                      #            (672, 1333), (704, 1333), (736, 1333),
                      #            (768, 1333), (800, 1333)],
                      img_scale=[(224, 224)],
                      multiscale_mode='value',
                      override=True,
                      keep_ratio=True)
             ]
         ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
    # dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
data = dict(train=dict(pipeline=train_pipeline))

optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
                                                 'relative_position_bias_table': dict(decay_mult=0.),
                                                 'norm': dict(decay_mult=0.)}))
lr_config = dict(step=[8, 11])
# runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
runner = dict(type='EpochBasedRunner', max_epochs=12)
# do not use mmdet version fp16
fp16 = None
# optimizer_config = dict(
#     type="DistOptimizerHook",
#     update_interval=1,
#     grad_clip=None,
#     coalesce=True,
#     bucket_size_mb=-1,
#     use_fp16=True,
# )

configs\_base_\default_runtime.py (如果下载了官方预训练模型)

checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
custom_hooks = [dict(type='NumClassCheckHook')]

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
# 预训练模型的路径按需修改,不用预训练模型训练的结果会很差
#load_from = 'D:/Users/Downloads/Swin-Transformer-Object-Detection/model/mask_rcnn_swin_tiny_patch4_window7_1x.pth'
resume_from = None
workflow = [('train', 1)]

configs\_base_\datasets\coco_detection.py 数据集路径,格式和coco保持一致

dataset_type = 'CocoDataset'
# data_root = 'data/coco/'
data_root = 'F:/myproject/dataset/COCO_/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    # dict(type='LoadAnnotations', with_bbox=True),
    dict(type='LoadAnnotations', with_bbox=True,with_mask=False ,with_seg=False,poly2mask=False),
    dict(type='Resize', img_scale=(224, 224), keep_ratio=True),#img_scale=(1333, 800),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(224, 224), #img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='bbox')

configs\_base_\models\mask_rcnn_swin_fpn.py 注意修改类别和mask

# model settings
model = dict(
    type='MaskRCNN',
    pretrained=None,
    backbone=dict(
        type='SwinTransformer',
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.2,
        ape=False,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        use_checkpoint=False),
    neck=dict(
        type='FPN',
        in_channels=[96, 192, 384, 768],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            # 修改为自己的数据类数
            # num_classes=80,
            num_classes=20,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
        # mask_roi_extractor=dict(
        #     type='SingleRoIExtractor',
        #     roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
        #     out_channels=256,
        #     featmap_strides=[4, 8, 16, 32]),
        # mask_head=dict(
        #     type='FCNMaskHead',
        #     num_convs=4,
        #     in_channels=256,
        #     conv_out_channels=256,
        # 修改为自己的数据类数
        #     num_classes=80,
        #     loss_mask=dict(
        #         type='CrossEntropyLoss', use_mask=True, loss_weight=1.0
        #     )
        # )
    ),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            # mask_size=28,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            # mask_thr_binary=0.5
        )))

mmdet\datasets\coco.py 替换自定义数据集的类别

class CocoDataset(CustomDataset):

    # CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    #            'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
    #            'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
    #            'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
    #            'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    #            'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    #            'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    #            'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    #            'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
    #            'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    #            'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
    #            'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
    #            'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
    #            'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')

    CLASSES = (
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
        'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
    )

mmdet\core\evaluation\class_names.py中的67行也要替换自定义数据集类别信息

# def coco_classes():
#     return [
#         'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
#         'truck', 'boat', 'traffic_light', 'fire_hydrant', 'stop_sign',
#         'parking_meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
#         'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
#         'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
#         'sports_ball', 'kite', 'baseball_bat', 'baseball_glove', 'skateboard',
#         'surfboard', 'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork',
#         'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
#         'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake', 'chair',
#         'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv',
#         'laptop', 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
#         'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
#         'scissors', 'teddy_bear', 'hair_drier', 'toothbrush'
#     ]
def coco_classes():
    return [
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
        'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
    ]

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/athrunsunny/article/details/124470524

智能推荐

使用JDBC连接数据库出现 The server time zone value ‘�й���׼ʱ��‘ is unrecognized or represents more than one解决方案_jdbc.properties timezone-程序员宅基地

文章浏览阅读553次。在 jdbc.properties 文件中的 url 后面加上 ?serverTimezone=UTC加入之前的jdbc.properties文件:user=rootpassword=12345678url=jdbc:mysql://localhost:3306/testdriverClass=com.mysql.cj.jdbc.Driver加入之后:user=rootpassword=12345678url=jdbc:mysql://localhost:3306/test?serv_jdbc.properties timezone

计算机图形学孔令德基础知识,计算机图形学基础教程孔令德答案-程序员宅基地

文章浏览阅读1.4k次。计算机图形学基础教程孔令德答案【篇一:大学计算机图形学课程设】息科学与工程学院课程设计任务书题目:小组成员:巴春华、焦国栋成员学号:专业班级:计算机科学与技术、2009级本2班课程:计算机图形学指导教师:燕孝飞职称:讲师完成时间: 2011年12 月----2011年 12 月枣庄学院信息科学与工程学院制2011年12 月20日课程设计任务书及成绩评定12【篇二:计算机动画】第一篇《计算机图形学》..._计算机图形学基础教程 孔令德 答案

python xlwings追加数据_大数据分析Python库xlwings提升Excel工作效率教程-程序员宅基地

文章浏览阅读1k次。原标题:大数据分析Python库xlwings提升Excel工作效率教程Excel在当今的企业中非常非常普遍。在AAA教育,我们通常建议出于很多原因使用代码,并且我们的许多数据科学课程旨在教授数据分析和数据科学的有效编码。但是,无论您偏爱使用大数据分析Python的程度如何,最终,有时都需要使用Excel来展示您的发现或共享数据。但这并不意味着仍然无法享受大数据分析Python的某些效率!实际上,..._xlwings通过索引添加数据

java8u211_jre864位u211-程序员宅基地

文章浏览阅读911次。iefans为用户提供的jre8 64位是针对64位windows平台而开发的java运行环境软件,全称为java se runtime environment 8,包括Java虚拟机、Java核心类库和支持文件,不包含开发工具--编译器、调试器和其它工具。jre需要辅助软件--JavaPlug-in--以便在浏览器中运行applet。本次小编带来的是jre8 64位官方版下载,版本小号u211版..._jre8是什么

kasp技术原理_KASP基因分型-程序员宅基地

文章浏览阅读5k次。KASP基因分型介绍KASP(Kompetitive Allele-Specific PCR),即竞争性等位基因特异性PCR,原理上与TaqMan检测法类似,都是基于终端荧光信号的读取判断,每孔反应都是采用双色荧光检测一个SNP位点的两种基因型,不同的SNP对应着不同的荧光信号。KASP技术与TaqMan法类似,它与TaqMan技术不同的是,它不需要每个SNP位点都合成特异的荧光引物,它基于独特的..._kasp是什么

华为p50预装鸿蒙系统,华为p50会不会预装鸿蒙系统_华为p50会预装鸿蒙系统吗-程序员宅基地

文章浏览阅读154次。华为现在比较火的还真就是新开发的鸿蒙系统了,那么在即将上市的华为p50手机上会不会预装鸿蒙系统呢?接下来我们就来一起了解一下华为官方发布的最新消息吧。1.华为p50最新消息相信大家都知道,随着华为鸿蒙OS系统转正日期临近,似乎全网的花粉们都在关注华为鸿蒙OS系统优化、生态建设等等,直接忽略了不断延期发布的华为P50手机,如今华为P50系列手机终于传来了最新的好消息,在经过一系列方案修改以后,终于被..._华为手机p50直接预装鸿蒙系统

随便推点

python用什么软件编程好-初学python编程,有哪些不错的软件值得一用?-程序员宅基地

文章浏览阅读2.1k次。Python编程的软件其实许多,作为一门面向大众的编程言语,许多修正器都有对应的Python插件,当然,也有特地的PythonIDE软件,下面我简单引见几个不错的Python编程软件,既有修正器,也有IDE,感兴味的朋友可以本人下载查验一下:1.VSCode:这是一个轻量级的代码修正器,由微软规划研发,免费、开源、跨途径,轻盈活络,界面精练,支撑常见的自动补全、语法提示、代码高亮、Git等功用,插..._python入门学什么好

pytorch一步一步在VGG16上训练自己的数据集_torch vgg训练自己的数据集-程序员宅基地

文章浏览阅读3.2w次,点赞30次,收藏307次。准备数据集及加载,ImageFolder在很多机器学习或者深度学习的任务中,往往我们要提供自己的图片。也就是说我们的数据集不是预先处理好的,像mnist,cifar10等它已经给你处理好了,更多的是原始的图片。比如我们以猫狗分类为例。在data文件下,有两个分别为train和val的文件夹。然后train下是cat和dog两个文件夹,里面存的是自己的图片数据,val文件夹同train。这样我们的..._torch vgg训练自己的数据集

毕业论文管理系统设计与实现(论文+源码)_kaic_论文系统设计法-程序员宅基地

文章浏览阅读968次。论文+系统+远程调试+重复率低+二次开发+毕业设计_论文系统设计法

在python2与python3中转义字符_Python 炫技操作:五种 Python 转义表示法-程序员宅基地

文章浏览阅读134次。1. 为什么要有转义?ASCII 表中一共有 128 个字符。这里面有我们非常熟悉的字母、数字、标点符号,这些都可以从我们的键盘中输出。除此之外,还有一些非常特殊的字符,这些字符,我通常很难用键盘上的找到,比如制表符、响铃这种。为了能将那些特殊字符都能写入到字符串变量中,就规定了一个用于转义的字符 \ ,有了这个字符,你在字符串中看的字符,print 出来后就不一定你原来看到的了。举个例子>..._pytyhon2、python3对%转义吗

java jar 文件 路径问题_「问答」解决jar包运行时相对路径问题-程序员宅基地

文章浏览阅读1.3k次。我这几天需要做一个Java程序,需要通过jar的形式运行,还要生成文件。最终这个程序是要给被人用的,可能那个用的人还不懂代码。于是我面临一个问题:生成的文件一定不能存绝对路径。刚开始我想得很简单,打绝对路径改成相对路径不就行了吗?于是有了这样的代码:String path = "../test.txt";File file = new File(path);……这个写法本身并没有问题,直接运行代码..._jar启动文件路径中存在!

微信读书vscode插件_曾经我以为 VSCode 是程序员专属的工具,直到发现了这些……...-程序员宅基地

文章浏览阅读598次。如果你知道 VSCode,一说起它,你可能第一个想到的就是把它当做一个代码编辑器,而它的界面应该可能大概率是这样的——如果你恰好又是个程序员,那你可能经常会用到它,不管是 Python、JS 还是 C++ 等各种语言对应的文件,都可以用它来进行简单的编辑和整理,甚至是运行和 debug......但是今天要讲的显然不是这些,经过小美的多方研究,发现了即使是对于大多数并不了解 VSCode,也完全不..._vscode weixin read

推荐文章

热门文章

相关标签