【mmdetection】使用自定义的coco格式数据集进行训练及测试

2023-11-08

一、mmdetection简介

项目仓库地址:https://github.com/open-mmlab/mmdetection

香港中文大学-商汤科技联合实验室开源了基于 PyTorch 的检测库——mmdetection。商汤科技和港中大组成的团队在 2018年的COCO 比赛的物体检测(Detection)项目中夺得冠军,而 mmdetection 正是基于 COCO 比赛时的 codebase 重构。

这个开源库提供了已公开发表的多种视觉检测核心模块。通过这些模块的组合,可以迅速搭建出各种著名的检测框架,比如 Faster RCNN,Mask RCNN,R-FCN,RetinaNet , Cascade R-CNN及ssd 等,以及各种新型框架,从而大大加快检测技术研究的效率。遗憾的是现在还没有出yolo网络。

相比 FAIR 此前开源的 Detectron,mmdetection 有以下几大优势:

  1. Performance 稍高
  2. 训练速度稍快: Mask R-CNN 差距比较大,其余的很小。
  3. 所需显存稍小: 显存方面优势比较明显,会小 30% 左右。
  4. 易用性更好: 基于 PyTorch 和基于 Caffe2 的 code 相比,易用性是有代差的。

与 mmdetection 一起开源的还有一个基础库——mmcv。 mmcv 基础库主要分为两个部分:一部分是和 deep learning framework 无关的一些工具函数,比如 IO/Image/Video 相关的一些操作;另一部分是为 PyTorch 写的一套训练工具,可以大大减少用户需要写的代码量,同时让整个流程的定制变得容易。项目仓库地址为:https://github.com/open-mmlab/mmcv
建议也把mmcv仓库下载到本地,方便后面debug的时候查看源码。

二、环境安装

1、安装教程

最好按照官方仓库的安装说明进行,很多博客里面的安装方法都或多或少有点小问题,可能是官方仓库一直在更新但是博客没有更新的缘故。官方安装说明地址:https://github.com/open-mmlab/mmdetection/blob/master/INSTALL.md

简单来说,该仓库目前只支持在linux系统上运行,不支持window; PyTorch的版本要求为:PyTorch 1.0+ or PyTorch-nightly,且要根据其官网的安装命令安装,避免版本冲突问题。

2、运行demo测试环境是否安装成功

因为博主之前使用别的博客的demo代码的时候出现错误,找了半天不知道是什么原因,而当我好好看官方说明的时候才知道这个代码在说明中有,而且已经更新过,所以为了保险期间,这里就不直接贴出代码了,给地址你们自己去看。

测试的demo代码地址为:https://github.com/open-mmlab/mmdetection/blob/master/GETTING_STARTED.md#high-level-apis-for-testing-images 。将代码写入py文件,并存放到mmdetection文件夹目录下,然后运行。但是运行官方代码的前提是你已经下载了相关模型的checkpoint的pth文件,并放在mmdetection文件夹目录下的checkpoints文件夹下。官方提供的所有训练好的pth模型文件下载地址都在MODEL_ZOO.md中。另外随便照一张图片重命名为test.jpg放到mmdetection目录下就可以了。

三、训练自定义的dataset

相信大家用这个mmdetection都不只是为了尝尝鲜试一下的吧,所以这里分享下我训练自定义的数据集的过程记录。

先给大家看一下我的整个mmdetection文件夹的内容。
在这里插入图片描述

1、准备dataset

需要说明的是官方提供的所有代码都默认使用的是coco格式的数据集,所以不想太折腾的话就把自己的数据集转化成coco数据集格式吧。各种类型数据转coco格式脚本见:转换工具箱 。我使用的是其中的labelme2coco.py文件,亲测没有问题。

制作好数据集之后,官方推荐coco数据集按照以下的目录形式存储:

mmdetection
├── mmdet
├── tools
├── configs
├── data
│   ├── coco
│   │   ├── annotations
│   │   ├── train2017
│   │   ├── val2017
│   │   ├── test2017

推荐以软连接的方式创建data文件夹,下面是创建软连接的步骤

cd mmdetection
mkdir data
ln -s $COCO_ROOT data

其中,$COCO_ROOT需改为你的coco数据集根目录

2、Training前修改相关文件

首先说明的是我的数据集类别一共有4个,分别是:‘Glass_Insulator’, ‘Composite_Insulator’, ‘Clamp’, ‘Drainage_Plate’。且我跑的模型是’configs/faster_rcnn_r50_fpn_1x.py’

官方提供的代码中都使用的是coco数据集,虽然我们自定义的数据集也已经转换成coco标准格式了,但是像class_name和class_num这些参数是需要修改的,不然跑出来的模型就不会是你想要的。

一些博客例如这个,所提供的方法是按照官方给的定义coco数据集的相关文件,新建文件重新定义自己的数据集和类等,但是其实这是有风险的,我之前按照他们的方法走到最后发现会出现错误,所以最简单便捷且保险的方法是直接修改coco数据集定义文件(官方也是这样建议的)。

1、定义数据种类,需要修改的地方在mmdetection/mmdet/datasets/coco.py。把CLASSES的那个tuple改为自己数据集对应的种类tuple即可。例如:

CLASSES = ('Glass_Insulator', 'Composite_Insulator', 'Clamp', 'Drainage_Plate')

2、接着在mmdetection/mmdet/core/evaluation/class_names.py修改coco_classes数据集类别,这个关系到后面test的时候结果图中显示的类别名称。例如:

def coco_classes():
    return [
        'Glass_Insulator', 'Composite_Insulator', 'Clamp', 'Drainage_Plate'
    ]

3、修改configs/faster_rcnn_r50_fpn_1x.py中的model字典中的num_classes、data字典中的img_scale和optimizer中的lr(学习率)。例如:

num_classes=5,#类别数+1
img_scale=(640,478), #输入图像尺寸的最大边与最小边(train、val、test这三处都要修改)
optimizer = dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001) #当gpu数量为8时,lr=0.02;当gpu数量为4时,lr=0.01;我只要一个gpu,所以设置lr=0.0025

4、在mmdetection的目录下新建work_dirs文件夹

3、Training

python tools/train.py configs/faster_rcnn_r50_fpn_1x.py --gpus 1 --validate --work_dir work_dirs

展示下开始训练的界面:
在这里插入图片描述
训练完之后work_dirs文件夹中会保存下训练过程中的log日志文件、每个epoch的pth文件(这个文件将会用于后面的test测试)

四、Testing

有两个方法可以进行测试。
1、如果只是想看一下效果而不要进行定量指标分析的话,可以运行之前那个demo.py文件,但是要改一下checkpoint_file的地址路径,使用我们上一步跑出来的work_dirs下的pth文件。例如:

checkpoint_file = 'work_dirs/epoch_100.pth'

2、使用test命令。例如:

python tools/test.py configs/faster_rcnn_r50_fpn_1x.py work_dirs/epoch_100.pth --out ./result/result_100.pkl --eval bbox --show

但是使用这个测试命令的时候会报错,报错的情况我也在官方库的issue上提交了,可以查看我的error描述,看看与你的是否一致。

根据我的问题描述可以知道使用demo.py来测试是可以出结果的,但是会出现”warnings.warn('Class names are not saved in the checkpoint’s ’ "的警告信息。使用这一步的test命令的时候会报错,程序中断,但是其实问题是一致的,应该是训练中保存下来的pth文件中没有CLASSES信息,所以show不了图片结果。因此需要按照下面的步骤修改下官方代码才可以。
(1) 修改mmdetection/mmdet/tools/test.py中的第29行为:

if show:
    model.module.show_result(data, result, dataset.img_norm_cfg, dataset='coco')

最后展示效果如下:
在这里插入图片描述
在这里插入图片描述
此处的格式化输出称为检测评价矩阵(detection evaluation metrics)。此处摘录COCO数据集文档中对该评价矩阵的简要说明:

Average Precision (AP):
	AP		% AP at IoU=.50:.05:.95 (primary challenge metric) 
	APIoU=.50	% AP at IoU=.50 (PASCAL VOC metric) 
	APIoU=.75	% AP at IoU=.75 (strict metric)
AP Across Scales:
	APsmall		% AP for small objects: area < 322 
	APmedium	% AP for medium objects: 322 < area < 962 
	APlarge		% AP for large objects: area > 962
Average Recall (AR):
	ARmax=1		% AR given 1 detection per image 
	ARmax=10	% AR given 10 detections per image 
	ARmax=100	% AR given 100 detections per image
AR Across Scales:
	ARsmall		% AR for small objects: area < 322 
	ARmedium	% AR for medium objects: 322 < area < 962 
	ARlarge		% AR for large objects: area > 962

如果大家按照我的步骤走下来出现什么问题的话欢迎在评论去留言,不知道我有没记录漏了哪个步骤!

五、实践中的Tips

  1. 如果实践中修改了mmcv的相关代码,需要到mmcv文件夹下打开终端,激活mmdetection环境,并运行"pip install ."后才会生效(这样修改的代码才会同步到anaconda的mmdetection环境配置文件中)
  2. 若想使用tensorboard可视化训练过程,在config文件中修改log_config如下:
log_config = dict(
    interval=10,                           # 每10个batch输出一次信息
    hooks=[
        dict(type='TextLoggerHook'),       # 控制台输出信息的风格
        dict(type='TensorboardLoggerHook')  # 需要安装tensorflow and tensorboard才可以使用
    ])

六、好资料推荐阅读

为了更好地理解mmdetection,现在好的资料还比较少,但还是有的,下面给大家推荐一些可能会对你有帮助的。

  1. mmdetection - 基于PyTorch的开源目标检测系统(重点推荐)
  2. 详解目标检测Faster R-CNN
  3. mmdetection的configs中的各项参数具体解释
  4. Labelimg标注的VOC格式数据集转成coco格式
  5. mmdetection源码阅读笔记(0)–创建模型
  6. mmdetection源码阅读笔记(1)–创建网络
  7. 一个讲解比较详细的源代码的私人博客
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【mmdetection】使用自定义的coco格式数据集进行训练及测试 的相关文章

随机推荐

  • 使用 Stable Diffusion 生成的仿旧照片和二次元图片

    这几天在电脑上运行 Stable Diffusion 玩了玩 这是我机器上的测试页面 https qizhen xyz genimg 这个模型比 Dall E 的小很多 所以才能在配置不高的个人电脑上跑 而且 我的电脑也只能勉强生成小尺寸的
  • 2023深圳杯(东三省)数学建模D题思路 - 基于机理的致伤工具推断

    1 赛题 D题 基于机理的致伤工具推断 致伤工具的推断一直是法医工作中的热点和难点 由于作用位置 作用方式的不同 相同的致伤工具在人体组织上会形成不同的损伤形态 不同的致伤工具也可能形成相同的损伤形态 致伤工具品种繁多 形态各异 但大致可分
  • 基于JAVA+SpringBoot+Vue+ElementUI中学化学实验室耗材管理系统

    全网粉丝20W csdn特邀作者 博客专家 CSDN新星计划导师 java领域优质创作者 博客之星 掘金 华为云 阿里云 InfoQ等平台优质作者 专注于Java技术领域和毕业项目实战 文末获取项目下载方式 一 项目背景介绍 当前 中学的化
  • MATLAB的统计每个列向量的个数

    tabulate 变量名 例子 统计age列向量里面有多少个不同年龄的个数 tabulate age 下面还有很多太长了 没有截图
  • [Unity3D]关于Android真机调测Profiler

    Unity3D 关于Android真机调测Profiler 2013 08 25 13 28 50 转载 标签 android profiler adb it 分类 Unity3d U3D中的Profile也是可以直接在链接安卓设备运行游戏
  • 280场周赛

    6004 得到 0 的操作数 给你两个 非负 整数 num1 和 num2 每一步 操作 中 如果 num1 gt num2 你必须用 num1 减 num2 否则 你必须用 num2 减 num1 例如 num1 5 且 num2 4 应
  • 悬赏百万美金检测Deepfake假视频,数据集470G:比赛很久没这么壕

    2019 12 13 13 51 52 车栗子 发自 凹非寺 量子位 报道 公众号 QbitAI 谁说Kaggle比赛都那么穷 穷不穷 还要看做的是什么任务 比如 有左右两段视频 你能分辨哪个是修过的么 动图结尾公布了答案 右是原始视频 左
  • 刷视频课的脚本

    是不是不想上视频课 是不是被迫要上视频课 是不是视频课很长 是不是如果挂机短短几分钟就会出现自动暂停的情况 是不是还在为这些烦恼 那么 掌声 只需一台空置的电脑 这个代码可以为你解决这些烦恼 话不多说 上代码 import time imp
  • QT moveToThread线程理解

    一 moveToThread创建开启线程步骤 1 创建继承自QObject类 实现槽函数 2 将QObject类通过moveToThread方法移到QThread线程中 使QObject类依附于线程 3 连接信号槽 槽必须是QObject类
  • [LeetCode-70]-Climbing Stairs(爬楼梯,斐波那契数列问题)

    文章目录 题目相关 Solution 题目相关 题目解读 该题就是斐波那契数列问题 可以使用递归方法实现 原题描述 原题链接 You are climbing a stair case It takes n steps to reach t
  • NFS详细介绍

    NFS介绍 网络文件系统 network files system 简称NFS是一种基于TCP传输协议的文件共享习通 NFS的CS体系中的服务端启用协议将文件共享到网络上 然后允许本地NFS客户端通过网络挂载服务端共享的文件 应用场景 为w
  • LeetCode 题 -7. 整数反转

    题目 给出一个 32 位的有符号整数 你需要将这个整数中每位上的数字进行反转 示例 1 输入 123 输出 321 示例 2 输入 123 输出 321 示例 3 输入 120 输出 21 注意 假设我们的环境只能存储得下 32 位的有符号
  • JS逆向之网易云音乐

    文章目录 1 目标网站 2 初步分析 3 定位加密参数生成位置 4 编码测试 4 1 定义AES加密方法 4 2 调用两次AES加密获取params 4 3 获取歌曲的url 4 4 单曲下载初步测试代码 4 5 飙升榜单音乐批量抓取 文章
  • MySql中把一个表的数据插入到另一个表中

    1 如果2张表的字段一致 并且希望插入全部数据 可以用这种方法 INSERT INTO 目标表 SELECT FROM 来源表 例如 insert into insertTest select from insertTest2 2 如果只希
  • 2020年加密货币领域的5大做市商,都有谁?

    什么是加密货币做市 与传统做市商相比 加密货币做市是一个新的事物 本文旨在更好地了解加密货币做市商的行为 首先 让我们通过探索对做市流程的基本了解来研究什么是做市 简而言之 做市是一种交易活动 交易员同时向金融市场上的交易双方 买方和卖方
  • 超详细图解!【MySQL进阶篇】MySQL架构原理

    MySQL体系架构 MySQL Server架构自顶向下大致可以分网络连接层 服务层 存储引擎层和系统文件层 一 网络连接层 客户端连接器 Client Connectors 提供与MySQL服务器建立的支持 目前几乎支持所有主流 的服务端
  • 基于人工蜂群算法的函数寻优算法

    文章目录 一 理论基础 二 算法流程 1 初始化阶段 2 引领蜂阶段 3 跟随蜂阶段 4 侦察蜂阶段 5 食物源 三 MATLAB程序实现 1 清空环境变量 2 问题设定 3 参数设置 4 初始化蜜蜂种群 5 迭代优化 6 结果显示 四 参
  • php 发送邮箱 Email

    步骤一 phpmailer 很好 无论原生还是放到框架下 phpmailer下载地址 https github com PHPMailer PHPMailer
  • 华中农业大学数学实验期末考试答案(matlab)

    1 这题通过生成一个全是1的矩阵 然后加上一个对角阵就行了 A ones 10 10 3 diag 0 9 DET A det A INV A inv A 2 这一题之前我写过一个求线性方程组的小程序 求解线性方程组 3 function
  • 【mmdetection】使用自定义的coco格式数据集进行训练及测试

    目录 一 mmdetection简介 二 环境安装 1 安装教程 2 运行demo测试环境是否安装成功 三 训练自定义的dataset 1 准备dataset 2 Training前修改相关文件 3 Training 四 Testing 五