mmclassification 训练自定义数据

2023-11-12

1 mmclassification 安装

        如果环境已安装mmclassification,请跳过该步骤。mmclassification框架安装与调试验证请参考博客:mmclassification安装与调试_Coding的叶子的博客-CSDN博客_mmclassification 安装

2 数据集准备

        mmclassification 的数据集目录主要由标注文件和图片样本组成,其中标注文件存储在meta文件夹中,图片样本存在train、val、test文件夹下,即分别是用于训练、验证和测试的图片样本。图片样本文件按照类别存储在train、val、test文件夹下,同一类别图片存储在同一个子文件夹中,子文件夹的名称为图片所属类别名称。

        meta文件夹中主要包含了train.txt、val.txt和test.txt文件。txt文件中的每一行分别存储了图片样本路径和类别id,如下图所示。

        如果没有meta标注文件,请参考博客:mmclassification 标注文件生成_Coding的叶子的博客-CSDN博客,生成meta文件夹及其文件夹下的txt文件。

         本文示例数据来源于minist手写字体可视化数据集,已按照train、test文件夹进行存储,下载地址为:minist手写数字可视化数据集-深度学习文档类资源-CSDN下载

        将下载的数据集文件夹名称重名为Minist,并且mmclassification工程目录下新建data文件夹,将数据集放到data文件夹下即可。数据集的存储路径不限,需要在下方3.3节中配置相应的路径即可。

3 自定义数据集

3.1 新建MyDataset

        在mmclassification工程目录下的mmcls/datasets/新建mydataset.py文件,自定义数据加载类MyDataset,文件名称mydataset和类名称MyDataset可以自行更改。mydataset.py文件中的内容如下: 

# -*- coding: utf-8 -*-
"""
乐乐感知学堂公众号
@author: https://blog.csdn.net/suiyingy
"""

import numpy as np

from .builder import DATASETS
from .base_dataset import BaseDataset



@DATASETS.register_module()
class MyDataset(BaseDataset):

    def load_annotations(self):
        assert isinstance(self.ann_file, str)

        data_infos = []
        with open(self.ann_file) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for filename, gt_label in samples:
                info = {'img_prefix': self.data_prefix}
                info['img_info'] = {'filename': filename}
                info['gt_label'] = np.array(gt_label, dtype=np.int64)
                data_infos.append(info)
            return data_infos

 3.2 将MyDataset注册到mmclassification框架

        在mmcls/datasets/__init__.py文件中增加上面定义的类MyDataset,如下图所示:

 3.3 新建数据集配置文件

        在mmclassification工程目录configs/_base_/datasets/文件夹下,新建mydataset.py文件,主要用于设置数据集类型、数据增强方式、batch size (samples_per_gpu)、数据集路径和标注文件路径、模型保存周期(interval)。文件内容如下所示:

# -*- coding: utf-8 -*-
"""
乐乐感知学堂公众号
@author: https://blog.csdn.net/suiyingy
"""
dataset_type = 'MyDataset'
classes = ['cat', 'bird', 'dog']  # The category names of your dataset

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]

data = dict(
    train=dict(
        type=dataset_type,
        data_prefix='data/Minist/train',
        ann_file='data/Minist/meta/train.txt',
        classes=classes,
        pipeline=train_pipeline
    ),
    val=dict(
        type=dataset_type,
        data_prefix='data/Minist/test',
        ann_file='data/Minist/meta/test.txt',
        classes=classes,
        pipeline=test_pipeline
    ),
    test=dict(
        type=dataset_type,
        data_prefix='data/Minist/test',
        ann_file='data/Minist/meta/test.txt',
        classes=classes,
        pipeline=test_pipeline
    )
)
evaluation = dict(interval=1, metric='accuracy')

4 修改configs模型配置文件

        以configs/resnet/resnet18_8xb16_cifar10.py配置文件为例,mmclassification的配置文件通常包含以下4个部分:

_base_ = [
    '../_base_/models/resnet18_cifar.py', '../_base_/datasets/cifar10_bs16.py',
    '../_base_/schedules/cifar10_bs128.py', '../_base_/default_runtime.py'
]

        ../_base_/models/resnet18_cifar.py:定义模型参数,主要包括主干网络、neck、head和类别数量。

        ../_base_/datasets/cifar10_bs16.py:定义数据集增强方式和路径,也就是3.3节的配置文件,bs16表示batch size为16,即samples_per_gpu=16。

        ../_base_/schedules/cifar10_bs128.py:定义训练参数,主要包括优化器、学习率、训练总epoch数量。

        ../_base_/default_runtime.py:定义运行参数,主要包括模型保存周期、日志输出周期等。

        configs主要修改的地方为数据配置文件,即把 '../_base_/datasets/cifar10_bs16.py'更换成3.3节中的配置文件'../_base_/datasets/mydataset.py'。即:

5 运行训练程序

        mmcls基本的训练命令为:

python tools/train.py 模型配置文件

        示例:

python tools/train.py configs/resnet/resnet18_8xb16_cifar10.py

        这里已经把resnet18_8xb16_cifar10.py文件按照第4节进行了修改。

6 运行结果

 【python三维深度学习】python三维点云从基础到深度学习_Coding的叶子的博客-CSDN博客_python 三维点云

更多三维、二维感知算法和金融量化分析算法请关注“乐乐感知学堂”微信公众号,并将持续进行更新。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

mmclassification 训练自定义数据 的相关文章

随机推荐

  • 基于STM32的ESP8266天气时钟(2)--------MCU获取天气数据

    前言 上一章节我们通过串口调试助手 成功获取到天气数据 这一节我们将通过MCU的串口发送 接收功能 实现MCU获取天气数据 传送门 基于STM32F103的网络天气时钟 1 通过串口获取天气 文章目录 1 摘要 2 硬件准备 2 1 WIF
  • fabric.js保存自定义属性重新渲染

    项目开发中 用到了fabric js开发一系列画布功能 其中 在做画布的序列化和反序列化功能时候 自定义的属性并不能够直接保存下来 这时候则需要我们自己做一些事情啦 下面是我的一个解决方法 在添加画布对象到画布上面之前 修改画布的toObj
  • 学习笔记之30个常用的maven命令

    maven 命令的格式为 mvn plugin name goal name 可以接受的参数如下 D 指定参数 如 Dmaven test skip true 跳过单元测试 P 指定 Profile 配置 可以用于区分环境 e 显示mave
  • 用react 写一个lyout页面

    可以使用 React 创建一个布局页面 首先 你需要在你的项目中安装 React 和 ReactDOM 然后你就可以开始创建你的布局页面了 你可以使用 JSX 语法在你的布局页面中创建 HTML 元素 例如 下面的代码创建了一个带有一个标题
  • Linux下libxml库编程(二)

    http leansmall blog 163 com blog static 51617691200811171560790 3 3 重要操作 3 3 1 创建XML文档 创建一个xml文档流程如下 l 用xmlNewDoc函数创建一个文
  • 【2】Midjourney注册

    随着AI技术的问世 2023年可以说是AI爆炸性成长的一年 近期最广为人知的AI服务除了chatgpt外 就是从去年五月就已经问世的AI绘画工具mid journey了 几个AI工具也代表了人工智能的热门阶段 只要输入一段文字 AI就会根据
  • 2021-07-19PHP面试笔试题记录

    1 执行以下代码 输出结果是 正确结果为 echo class b something 2 执行以下代码 输出结果是
  • vue2 cli4 打包chunk文件太多解决办法

    由于项目原因npm run build打包后chunk文件很多下 想减少chunk文件数量 在vue config js文件中添加webpack插件 文件头加var webpack require webpack 这样chunk文件数量就变
  • 华为交换机配置MSTP

    文章目录 1 拓扑图 2 任务描述 3 SW1配置 4 SW2配置 5 SW3配置 6 SW4配置 1 拓扑图 2 任务描述 在交换机SW1 SW2 SW3 SW4上配置MSTP防止二层环路 具体要求如下 VLAN10数据流默认经过SW3转
  • 程序媛菜鸡面经(八 - offer篇)

    投简历 简历是要多投的 但是有时候投多了简历也会有问题 头条 没有面试机会 在看过简历后HR发邮件告知我 从简历上能看出你是一位很优秀的人 但看不出你在前端 技术方面的竞争力 当时投的是旧版简历 于是我回邮问简历有误能否重申 至今未有回音
  • 子网掩码的作用

    IP地址由网络和主机两部分标识组成 IP地址由 网络标识 网络地址 和 主机标识 主机地址 两部分组成 在局域网内相互间通信的网络必须具有相同网络地址 也叫相同的网段 在同一个网段内每个设备的主机地址都不相同 在IPV4中 IP地址由32位
  • Vue中query与params两种传参的区别

    query语法 this router push path 地址 query id 123 这是传递参数 this route query id 这是接受参数 params语法 this router push name 地址 params
  • linux系统哪个好用

    linux系统哪个好用 1 Ubuntu服务器 Ubuntu是众所周知的最佳LinuxServerDistro 它能为您提供出色的用户体验 如果你是Linux世界的新手 选择Ubuntu作为你的服务器发行版将是最好的 使用此服务器 您可以做
  • Mac系统如何在圣诞节让电脑屏幕下雪?

    对于苹果 Mac 电脑上的 终端 应用 可能大家在平时用得不多 所以对它应该都会比较陌生 其实这个终端应用是用于让用户可以直接输入一些系统指令 让它执行相应的操作 比如简单的显示当前目录中的文件 显示日期与时间 删除文件等操作都是可以的 今
  • Android项目Gradle: Download gradle-6.5-bin.zip一直卡住解决方法

    1 首先停止gradle的下载 通过迅雷或浏览器将gradle下载下来 下载地址为 https services gradle org distributions gradle 6 5 bin zip 其他版本的gradle同理 2 打开C
  • 二级MS Office高级应用

    1 在长度为n的有序线性表中进行二分查找 最坏的情况下需要比较的次数是 O log2n 以2为底n对数 解析 当有序线性表为顺序存储时才可以用二分查找 可以证明的是对于长度为n的有序线性表 最坏的情况下 二分查找只需要比较O log2n 次
  • 数据仓库开发之路之一--准备工作

    在数据仓库的开发过程中 需要熟悉大量的概念以及相关工具的使用 还需要了解宏观上的各种开发流程 串联起来完成最终的数据仓库项目的开发 本篇介绍一些准备工作 包括涉及到的工具介绍 以及开发过程的描述 记录学习研究的印记 并和大家讨论研究存在的相
  • conda upgrade --all惹的祸,该怎么解决?

    本想要安装scikit surprise库 由于环境问题 就更新一下 谁知道差点酿成大祸 anaconda不灵了 无论什么语句都报错 jupyter notebook 不能用 navigator也打不开 万念俱灰了 导致我想要重装anaco
  • atx860和java_捷安特XTC800和ATX860有什么区别

    展开全部 区别比较大 简单说 ATX 8xx就是e69da5e887aa62616964757a686964616f31333431353237ATX 6xx的 局部升级 轮组由26寸换为27 5寸 车架外观改进 变速套件等级略微提高 仅此
  • mmclassification 训练自定义数据

    1 mmclassification 安装 如果环境已安装mmclassification 请跳过该步骤 mmclassification框架安装与调试验证请参考博客 mmclassification安装与调试 Coding的叶子的博客 C