MMLAB学习-MMCLS项目-训练自己的任务_mmcls 怎么引入自己的类_dzm1204的博客-程序员宅基地

技术标签: 学习  python  # MMLAB  pytorch  

生成完整的配置文件

一般做法可以考虑继承base然后一个个实现模块,也可以先整体跑通在一个个修改,一般采用第二种 方法

在这里插入图片描述

先复制resnet配置文件的绝对路径

在这里插入图片描述

找到tools下的train.py就是入口函数

在这里插入图片描述

这里的路径是默认参数,所以直接给路径复制到参数设置里面去

在这里插入图片描述

在这里插入图片描述

参数配置完可以先试着跑一下,虽然会报错但是也会生成一个文件如下

在这里插入图片描述

会生成这个配置文件,接下来就可以直接复制到pycharm中改这个配置文件

在这里插入图片描述

重命名后,复制到pycharm打开,里面就是生成的自己的网络结构,一开始里面都一些默认配置,我们需要修改一些配置项,最多修改的就是输出大小修改成自己分类的类别数,和指定自己训练集、验证集、测试集的文件路径,工作空间路径以及迭代多少次保存模型和日志文件。修改后一般可以直接跑通

# 数据集存放在MMLAB\mmclassification-master\mmcls\data\flower_data\下
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        # 输出层改成102,因为是102分类
        num_classes=1000,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5)))
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    # 显存不够的情况下这里可以改小一点
    dict(type='RandomResizedCrop', size=224),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1)),
    dict(type='CenterCrop', crop_size=224),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
# 指定数据:第一种方法是根据数据所在的文件夹去指定的
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=2,
    train=dict(
        type='ImageNet',
        # data_prefix='data/imagenet/train',
        # 指定自己的训练集的文件夹路径
        data_prefix='../mmcls/data/flower_data/train',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='RandomResizedCrop', size=224),
            dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='ToTensor', keys=['gt_label']),
            dict(type='Collect', keys=['img', 'gt_label'])
        ]),
    val=dict(
        type='ImageNet',
        # 指定自己验证集的文件夹路径
        data_prefix='../mmcls/data/flower_data/vaild',
        # ann_file='data/imagenet/meta/val.txt', #这句话如果注释掉则就把文件夹的名字当成类别,如果传了标注文件就以标注文件如主
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='Resize', size=(256, -1)),
            dict(type='CenterCrop', crop_size=224),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ]),
    test=dict(
        type='ImageNet',
        # 还没有弄测试集拿验证集先顶替
        data_prefix='../mmcls/data/flower_data/vaild',
        # ann_file='data/imagenet/meta/val.txt',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='Resize', size=(256, -1)),
            dict(type='CenterCrop', crop_size=224),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ]))
# 默认多次做一次评估
evaluation = dict(interval=1, metric='accuracy')
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(policy='step', step=[30, 60, 90])
runner = dict(type='EpochBasedRunner', max_epochs=100)
# 间隔二十保存一次
checkpoint_config = dict(interval=50)
log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
# 默认工作路径  意思就是你保存的模型和保存的日志最后在存在哪
work_dir = './work_dirs/resnet18_8xb32_in1k'
gpu_ids = [0]

构建自己的数据集

用的方式二,所有的数据集都在一个文件夹没有分类。这时候就需要用标签来区分。标签格式: 图片名字 :图片标签

在这里插入图片描述

生成了这样的标签格式后,就需要再写一个数据处理文件,位置在mmcls/datasets下新建一个文件,这里是my_filelist.py

在这里插入图片描述

这里仿照imagenet.py,主要重写下面这个类,写入自己的类

在这里插入图片描述

下一步需要在_init_.py导入自己刚刚写的类

在这里插入图片描述

再去配置文件中去改

在这里插入图片描述

val做同样修改

在这里插入图片描述

修改后就可以开始跑自己的任务了

在这里插入图片描述

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_45893652/article/details/127855891

智能推荐

java线程 kill linux_新手程序员登录服务器杀进程!高级:你别再瞎Kill进程服务了...-程序员宅基地

免费无套路分享 100G Java 视频、pdf 面试学习资料获取方式:【关注 + 转发】后,私信我,回复关键字【666】,即可免费无套路获取哦~以下是资源的部分目录以及内容截图:干货较多,这里仅仅贴出了部分哟~重要的事再说一遍,获取方式:【关注 + 转发】后,私信我,回复关键字【666】,即可免费无套路获取哦~正文开始,前言我们都知道,kill在linux系统中是用于杀死进程。kill pid ..._java 杀linux进程

actionSupport类-程序员宅基地

actionSupport类该类实现了action接口和其他的几个有用的接口,比如数据校验、错误消息本地化等.继承该类后,这些功能便自动获得.一、基本校验public void validate(){ PortfolioService ps = getPortfolioService(); if ( getPassword().length() == 0 ){ ..._actionsupport类

相对熵(KL散度)-程序员宅基地

https://zhuanlan.zhihu.com/p/37452654https://blog.csdn.net/weixinhum/article/details/85064685交叉熵和相对熵相对熵(KL散度)KL 散度:衡量每个近似分布与真实分布之间匹配程度的方法:\[D_{K L}(p \| q)=\sum_{i=1}^{N} p\left(x_{i}\right)..._为什么相对熵是inf

基于Simulink的模糊控制器设计及Matlab源代码_simulink模糊控制-程序员宅基地

现在,我们将这些定义组合起来,构建一个模糊控制器模型。模糊控制器的输入和输出都是模糊变量,其中输入的模糊变量称为“控制量”,输出的模糊变量称为“被控量”。将Fuzzy Membership Function模块的输入设置为误差信号,将其输出连接到Fuzzy Logic Controller模块的输入,将Fuzzy Logic Controller模块的输出连接到一个Scopes模块,然后开始仿真。假设我们要控制一个电机的转速,输入控制量是电机的误差(期望转速与实际转速之差),输出被控量是电机的转速。_simulink模糊控制

DFS客户端访问设置及安全策略-程序员宅基地

除了Windows Server 2003家族中基于服务器的DFS组件外,还有基于客户端的DFS组件。DFS客户端可以将对DFS根目录或DFS链接的引用缓存一段时间,该时间由管理员指定。DFS客户端组件可以在许多不同的Windows平台上运行。Windows Server 2003 家族产品支持下列平台上的目标。   一、 从其他计算机访问DFS目标    表1 支持DFS的操作系统列

Qt放大镜(一)_qt中放大镜_赵民勇的博客-程序员宅基地

版权声明:转载自Qt放大镜代码设计_onlyshi的专栏-程序员宅基地放大镜代码设计一、起因看到自己手机T1上面的那个搜索app的图标是个放大镜,但是锤子科技把它真正做成了也具有放大镜的功能。由于自己刚学Qt,所以也想在电脑上试一下做一个放大镜的小玩意。但是思路有限,对Qt掌握也不是非常好,就很的简单做了一个,提供些思路给有需要的人,但是,做的还不够精细,不够好,希望以后随着对Qt的掌握程度的加深后,会重新做一个更好的。看看做完之后的效果对比图。二、代码实现1、实._qt中放大镜

随便推点

【前端】前端监控体系-程序员宅基地

>对于一个应用来说,除了前期的开发和设计,在项目上线后端维护很重要,其中就包括监控体系的搭建。>系统需要具备发布灰度过程中的监控以及用户问题的反馈和定位等能力。>这些问题可以从2个点解决:数据采集 和 数据上报与监控_前端监控

自定义UI 简易图文混排_图文混排效果图_Notzuonotdied的博客-程序员宅基地

系列文章目录自定义UI 基础知识自定义UI 绘制饼图自定义UI 圆形头像自定义UI 自制表盘文章目录系列文章目录前言创建绘制对象加载图片自定义绘制内容绘制图片绘制文字文本宽高获取测量文字宽度文本绘制的位置实现代码附录源码前言这系列的文章主要是基于扔物线的HenCoderPlus课程的源码来分析学习。扔物线课程源码:ImageTextView.javaAndroid官方文档:自定义绘制这一篇文章主要介绍的是文字的测量,更多的内容可以参考:HenCoder Android 开发_图文混排效果图

TensorFlow2.0教程-AutoGraph-程序员宅基地

TensorFlow2.0教程-AutoGraphtf.function的一个很酷的新功能是AutoGraph,它允许使用自然的Python语法编写图形代码。from __future__ import absolute_import, division, print_functionimport numpy as npimport tensorflow as tffrom tensor...

adb命令在测试中的使用_adb测试命令有什么用-程序员宅基地

输入命令 ./aa.sh如果 没有成功报出问题::Permission denied。就是没有权限。解决办法:修改该文件aa.sh 的权限 :使用命令:chmod 777aa.sh。然后再执行 最上面的操作 就 OK ...._adb测试命令有什么用

运行flutter doctor闪退_Vera-min的博客-程序员宅基地

运行flutter doctor闪退windows cmd/flutter_console.bat 运行flutter/flutter doctor闪退_flutter doctor闪退

什么是前端_... 前端-程序员宅基地

什么是前端前端,也称web前端,通俗一点就是网页,运行在PC端和移动端等浏览器展示给用户的网页前端开发最基本(HTML,CSS,JavaScript)也是最核心,不管是做移动端前端还是PC前端三个都是必不可少的HTMLHTML是超文本标记语言也是制作网页的基础,是用于在Internet上显示Web页面的主要标记语言。网页由HTML组成,用于通过Web浏览器显示文本,图像或其他资源..._... 前端