使用MMDetection训练自己的数据集

2023-11-04

github链接:OpenMMLab (github.com)

官方文档:Prerequisites — MMDetection 2.15.1 documentation

MMDetection推荐大家最好还是在linux系统下使用,windows系统上使用起来属实bug太多

下面的教程将会教会大家如何使用MMDetection来训练一个自己的目标检测模型,MMDetection设计的非常nice,准备好数据之后,只需要稍微修改一下配置文件就能完成训练,大多数模型的配置文件在MMDetection都进行了提供,只需要继承这些配置文件并重写其中的一些参数即可。

安装MMDetection

首先,通过下面的命令检查你的nvcc和gcc的版本,其中nvcc是调用gpu的关键,gcc是编译代码的关键。

# Check nvcc version
!nvcc -V
# Check GCC version
!gcc --version

你的电脑将会输出下列信息:

image-20210827194213135

然后大家需要安装mmdetection,mmdetection是openmmlab提供的一个计算机视觉的目标检测组件,他还提供了语义分割,分类等多种计算机视觉组件库,这些组件库基本都依赖与mmcv,安装的时候一定要注意保持mmcv和组件库的版本匹配,比如下图是mmcv和mmdetection的匹配关系。

MMDetection version MMCV version
master mmcv-full>=1.3.8, <1.4.0
2.15.1 mmcv-full>=1.3.8, <1.4.0
2.15.0 mmcv-full>=1.3.8, <1.4.0
2.14.0 mmcv-full>=1.3.8, <1.4.0
2.13.0 mmcv-full>=1.3.3, <1.4.0
2.12.0 mmcv-full>=1.3.3, <1.4.0
2.11.0 mmcv-full>=1.2.4, <1.4.0
2.10.0 mmcv-full>=1.2.4, <1.4.0
2.9.0 mmcv-full>=1.2.4, <1.4.0
2.8.0 mmcv-full>=1.2.4, <1.4.0
2.7.0 mmcv-full>=1.1.5, <1.4.0
2.6.0 mmcv-full>=1.1.5, <1.4.0
2.5.0 mmcv-full>=1.1.5, <1.4.0
2.4.0 mmcv-full>=1.1.1, <1.4.0
2.3.0 mmcv-full==1.0.5
2.3.0rc0 mmcv-full>=1.0.2
2.2.1 mmcv==0.6.2
2.2.0 mmcv==0.6.2
2.1.0 mmcv>=0.5.9, <=0.6.1
2.0.0 mmcv>=0.5.1, <=0.5.8

如果你是jupyter的环境,你可以执行下面的命令完成安装。

# install dependencies: (use cu101 because colab has CUDA 10.1)
!pip install -U torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html

# install mmcv-full thus we could use CUDA operators
!pip install mmcv-full

# Install mmdetection
!rm -rf mmdetection
!git clone https://github.com/open-mmlab/mmdetection.git
%cd mmdetection

!pip install -e .

# install Pillow 7.0.0 back in order to avoid bug in colab
!pip install Pillow==7.0.0

并执行下面的python代码来检查是否安装成功。

# Check Pytorch installation
import torch, torchvision
print(torch.__version__, torch.cuda.is_available())

# Check MMDetection installation
import mmdet
print(mmdet.__version__)

# Check mmcv installation
from mmcv.ops import get_compiling_cuda_version, get_compiler_version
print(get_compiling_cuda_version())
print(get_compiler_version())

如果安装成功之后,将会在你的命令行中输出下列的信息。

image-20210827194727581

或者你可以通过下面的代码来使用他官方提供的maskrnn的模型。

!mkdir checkpoints
!wget -c https://download.openmmlab.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco_bbox_mAP-0.408__segm_mAP-0.37_20200504_163245-42aa3d00.pth \
      -O checkpoints/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco_bbox_mAP-0.408__segm_mAP-0.37_20200504_163245-42aa3d00.pth
from mmdet.apis import inference_detector, init_detector, show_result_pyplot

# Choose to use a config and initialize the detector
config = 'configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco.py'
# Setup a checkpoint file to load
checkpoint = 'checkpoints/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco_bbox_mAP-0.408__segm_mAP-0.37_20200504_163245-42aa3d00.pth'
# initialize the detector
model = init_detector(config, checkpoint, device='cuda:0')
# Use the detector to do inference
img = 'demo/demo.jpg'
result = inference_detector(model, img)
# Let's plot the result
show_result_pyplot(model, img, result, score_thr=0.3)

效果如下:

image-20210827194906755

准备数据

官方文档:Tutorial 2: Customize Datasets — MMDetection 2.15.1 documentation

目标检测的数据大多数需要处理成voc或者coco的格式,其中voc的格式是xml文件,bbox是左上角和右下角的坐标,coco是一个json文件,bbox是左上角的坐标和宽高。下面我们将会使用一个小规模的kitti数据集来作为我们使用的数据集,下载地址如下:

# download, decompress the data
!wget https://download.openmmlab.com/mmdetection/data/kitti_tiny.zip
!unzip kitti_tiny.zip > /dev/null

数据集的格式如下:

# Check the directory structure of the tiny data
# Install tree first
!apt-get -q install tree
!tree kitti_tiny

# 数据集格式 images目录是是图片,labels目录下是标签,train和val分别记录了训练和验证使用到的数据
kitti_tiny
├── training
│   ├── image_2
│   │   ├── 000000.jpeg
│   │   ├── 000001.jpeg
│   │   ├── 000002.jpeg
│   │   ├── 000003.jpeg
│   └── label_2
│       ├── 000000.txt
│       ├── 000001.txt
│       ├── 000002.txt
│       ├── 000003.txt
├── train.txt
└── val.txt

可以通过下面的代码来查看一下图片大致是什么样子的

# Let's take a look at the dataset image
import mmcv
import matplotlib.pyplot as plt

img = mmcv.imread('kitti_tiny/training/image_2/000073.jpeg')
plt.figure(figsize=(15, 10))
plt.imshow(mmcv.bgr2rgb(img))
plt.show()

image-20210827195704071

训练模型

准备好数据之后,我们只需要修改我们的配置文件即可完成训练:

首先需要加载基本的配置文件,在configs目录下你可以找到这些配置文件,比如这里我们加载的是faster_rcnn的配置文件。

from mmcv import Config
cfg = Config.fromfile('./configs/faster_rcnn/faster_rcnn_r50_caffe_fpn_mstrain_1x_coco.py')

修改并将修改之后的配置文件保存,在后面推理的时候我们可以直接加载我们的配置文件。

from mmdet.apis import set_random_seed

# Modify dataset type and path
cfg.dataset_type = 'KittiTinyDataset'
cfg.data_root = 'kitti_tiny/'

cfg.data.test.type = 'KittiTinyDataset'
cfg.data.test.data_root = 'kitti_tiny/'
cfg.data.test.ann_file = 'train.txt'
cfg.data.test.img_prefix = 'training/image_2'

cfg.data.train.type = 'KittiTinyDataset'
cfg.data.train.data_root = 'kitti_tiny/'
cfg.data.train.ann_file = 'train.txt'
cfg.data.train.img_prefix = 'training/image_2'

cfg.data.val.type = 'KittiTinyDataset'
cfg.data.val.data_root = 'kitti_tiny/'
cfg.data.val.ann_file = 'val.txt'
cfg.data.val.img_prefix = 'training/image_2'

# modify num classes of the model in box head
cfg.model.roi_head.bbox_head.num_classes = 3
# We can still use the pre-trained Mask RCNN model though we do not need to
# use the mask branch
cfg.load_from = 'checkpoints/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco_bbox_mAP-0.408__segm_mAP-0.37_20200504_163245-42aa3d00.pth'

# Set up working dir to save files and logs.
cfg.work_dir = './tutorial_exps'

# The original learning rate (LR) is set for 8-GPU training.
# We divide it by 8 since we only use one GPU.
cfg.optimizer.lr = 0.02 / 8
cfg.lr_config.warmup = None
cfg.log_config.interval = 10

# Change the evaluation metric since we use customized dataset.
cfg.evaluation.metric = 'mAP'
# We can set the evaluation interval to reduce the evaluation times
cfg.evaluation.interval = 12
# We can set the checkpoint saving interval to reduce the storage cost
cfg.checkpoint_config.interval = 12

# Set seed thus the results are more reproducible
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)


# We can initialize the logger for training and have a look
# at the final config used for training
print(f'Config:\n{cfg.pretty_text}')
# 保存模型的各种参数(一定要记得嗷)
cfg.dump(F'{cfg.work_dir}/customformat_kitti.py')

然后训练就可以了

from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.apis import train_detector


# Build dataset
datasets = [build_dataset(cfg.data.train)]

# Build the detector
model = build_detector(
    cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
# Add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES

# Create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
train_detector(model, datasets, cfg, distributed=False, validate=True)

经过漫长的训练,你将会得到下面的训练记录,并生成日志文件

image-20210827200237662

?如何从训练日志中获取信息

从日志中,我们可以对训练过程有一个基本的了解,知道检测器的训练效果如何。

首先,加载在 ImageNet 上预训练的 ResNet-50 主干,这是一种常见做法,因为从头开始训练成本更高。日志显示除了 conv1.bias 之外,ResNet-50 主干的所有权重都被加载,它已合并到 conv.weights 中。

其次,由于我们使用的数据集很小,我们加载了一个 Mask R-CNN 模型并对其进行了微调以进行检测。因为我们实际使用的检测器是 Faster R-CNN,所以掩码分支中的权重,例如roi_head.mask_head,是源 state_dict 中的意外键,未加载。原始的 Mask R-CNN 在包含 80 个类的 COCO 数据集上进行训练,但 KITTI Tiny 数据集只有 3 个类。因此,用于分类的预训练Mask R-CNN的最后一个FC层具有不同的权重形状,未使用。

第三,训练后,检测器通过默认的 VOC 式评估进行评估。结果表明,检测器在 val 数据集上达到了 54.1 mAP,不错!

使用训练好的模型

如果你是jupyter的代码,你可以继续执行下列的文件来使用训练好的模型。

img = mmcv.imread('kitti_tiny/training/image_2/000068.jpeg')

model.cfg = cfg
result = inference_detector(model, img)
show_result_pyplot(model, img, result)

如果你是在pycharm等工具中完成的开发,可以参考这篇博客使用你的模型。

使用MMDetection进行目标检测_dejahu的博客-CSDN博客

最后附上完整的训练代码

from mmcv import Config
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.apis import train_detector
from mmdet.apis import set_random_seed
import os.path as osp
import mmcv
import numpy as np
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset
import warnings
warnings.filterwarnings('ignore')

@DATASETS.register_module()
class KittiTinyDataset(CustomDataset):
    CLASSES = ('Car', 'Pedestrian', 'Cyclist')
    def load_annotations(self, ann_file):
        cat2label = {k: i for i, k in enumerate(self.CLASSES)}
        # load image list from file
        image_list = mmcv.list_from_file(self.ann_file)

        data_infos = []
        # convert annotations to middle format
        for image_id in image_list:
            filename = f'{self.img_prefix}/{image_id}.jpeg'
            image = mmcv.imread(filename)
            height, width = image.shape[:2]

            data_info = dict(filename=f'{image_id}.jpeg', width=width, height=height)

            # load annotations
            label_prefix = self.img_prefix.replace('image_2', 'label_2')
            lines = mmcv.list_from_file(osp.join(label_prefix, f'{image_id}.txt'))

            content = [line.strip().split(' ') for line in lines]
            bbox_names = [x[0] for x in content]
            bboxes = [[float(info) for info in x[4:8]] for x in content]

            gt_bboxes = []
            gt_labels = []
            gt_bboxes_ignore = []
            gt_labels_ignore = []

            # filter 'DontCare'
            for bbox_name, bbox in zip(bbox_names, bboxes):
                if bbox_name in cat2label:
                    gt_labels.append(cat2label[bbox_name])
                    gt_bboxes.append(bbox)
                else:
                    gt_labels_ignore.append(-1)
                    gt_bboxes_ignore.append(bbox)

            data_anno = dict(
                bboxes=np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4),
                labels=np.array(gt_labels, dtype=np.long),
                bboxes_ignore=np.array(gt_bboxes_ignore,
                                       dtype=np.float32).reshape(-1, 4),
                labels_ignore=np.array(gt_labels_ignore, dtype=np.long))

            data_info.update(ann=data_anno)
            data_infos.append(data_info)

        return data_infos

cfg = Config.fromfile('./configs/faster_rcnn/faster_rcnn_r50_caffe_fpn_mstrain_1x_coco.py')
# Modify dataset type and path
cfg.dataset_type = 'KittiTinyDataset'
cfg.data_root = 'data/kitti_tiny/'
cfg.data.test.type = 'KittiTinyDataset'
cfg.data.test.data_root = 'data/kitti_tiny/'
cfg.data.test.ann_file = 'train.txt'
cfg.data.test.img_prefix = 'training/image_2'
cfg.data.train.type = 'KittiTinyDataset'
cfg.data.train.data_root = 'data/kitti_tiny/'
cfg.data.train.ann_file = 'train.txt'
cfg.data.train.img_prefix = 'training/image_2'
cfg.data.val.type = 'KittiTinyDataset'
cfg.data.val.data_root = 'data/kitti_tiny/'
cfg.data.val.ann_file = 'val.txt'
cfg.data.val.img_prefix = 'training/image_2'
# modify num classes of the model in box head
cfg.model.roi_head.bbox_head.num_classes = 3
# We can still use the pre-trained Mask RCNN model though we do not need to
# use the mask branch
cfg.load_from = 'checkpoints/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco_bbox_mAP-0.408__segm_mAP-0.37_20200504_163245-42aa3d00.pth'
# Set up working dir to save files and logs.
cfg.work_dir = './tutorial_exps'
# The original learning rate (LR) is set for 8-GPU training.
# We divide it by 8 since we only use one GPU.
cfg.optimizer.lr = 0.02 / 8
cfg.lr_config.warmup = None
cfg.log_config.interval = 10
# Change the evaluation metric since we use customized dataset.
cfg.evaluation.metric = 'mAP'
# We can set the evaluation interval to reduce the evaluation times
cfg.evaluation.interval = 12
# We can set the checkpoint saving interval to reduce the storage cost
cfg.checkpoint_config.interval = 12
# Set seed thus the results are more reproducible
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)
# We can initialize the logger for training and have a look
# at the final config used for training
print(f'Config:\n{cfg.pretty_text}')
# 保存模型的各种参数(一定要记得嗷)
cfg.dump(F'{cfg.work_dir}/customformat_kitti.py')

# 训练主要进程
# Build dataset
datasets = [build_dataset(cfg.data.train)]

# Build the detector
model = build_detector(
    cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
# Add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES

# Create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
train_detector(model, datasets, cfg, distributed=False, validate=True)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用MMDetection训练自己的数据集 的相关文章

随机推荐

  • 数据结构-第三章 栈和队列

    Stack and Queue 栈和队列是逻辑上的结构 在物理上可以用数组和链表来实现 1 栈 A stack is a list in which insertions and deletions take place at the sa
  • 逆向爬虫31 某站刷播放

    逆向爬虫31 某站刷播放 目标 利用爬虫模拟某站视频播放 增加视频的播放量 思考 正常用户是如何为视频增加播放量的 进入视频播放页 点击播放按钮 视频开始播放 就会增加一个播放量 因此我们只需要模拟点击播放按钮时 浏览器对服务器发送的数据包
  • python 字符串True,False转换成布尔值True,False

    字符串True False转换成布尔值True False不能用bool函数 因为得到的结果都是布尔值True 可以写个if判断 if ss True ss True elif ss False ss False
  • MySQL基本命令

    登录mysql hhostname Pport uusername p 比如 mysql hlocalhost P3306 uroot p 主机名 端口号 用户名 密码 同一台服务器上前两个省略 显示所有数据库 show databases
  • zabbix监控nginx状态界面

    文章目录 开启状态界面 监控nginx状态界面 开启状态界面 实例 开启status location status stub status on off allow 172 16 0 0 16 deny all 访问状态页面的方式 htt
  • 编译工具 Ninja 介绍

    什么是Ninja Ninja是使用C 写的开源项目 http martine github io ninja 在Unix Linux下通常使用Makefile来控制代码的编译 但是Makefile对于比较大的项目有时候会比较慢 看看上面那副
  • (手工)【sqli-labs26、26a】拼接注入、过滤后注入

    目录 推荐 一 手工 SQL注入基本步骤 二 Less25 GET Error based All your SPACES and COMMENTS belong to us 2 1 简介 过滤 报错回显 2 2 第一步 注入点测试 2 3
  • 性能测试浅谈

    早期的性能测试更关注后端服务的处理能力 一个用户去访问一个页面的请求过程 如上图 数据传输时间 当你从浏览器输入网址 敲下回车 开始 真实的用户场景请不要忽视数据传输时间 想想你给远方的朋友写信 信件需要经过不同的交通运输工具送到朋友手上
  • Python __init__.py 模块详解

    文章目录 1 概述 2 导入演示 2 1 执行顺序 先父后子 2 2 导入所有模块 含子模块 1 概述 1 工具 Pycharm 场景 在创建一个 Python Package 时 会默认在该包下生成一个 init py 文件 2 目的 进
  • matlab中rem与mod函数的区别

    语法格式 rem x y 求整除x y的余数 mod x y 求模 rem x y x y fix x y fix 向0取整 mod x y x y floor x y floor 向左取整 以数抽为准 朝负无穷方向取整 如果x和y的符号相
  • SQLlite

    SQLlite SQLite是一个软件库 实现了自给自足的 无服务器的 零配置的 事务性的 SQL 数据库引擎 一 什么是 SQLite SQLite是一个进程内的库 实现了自给自足的 无服务器的 零配置的 事务性的 SQL 数据库引擎 它
  • uTools使用技巧

    uTools 提高工作效率 学习效率 启动uTools Alt 空格 关键词 任何系统文件 软件 插件 都可以通过 关键词 快速跳转 快速打开文件 软件 输入 控制面板 选中后就能跳转到 控制面板 同样的 程序与功能 cmd 等系统文件 都
  • 电脑恢复还原文件的各种操作方法

    如果你的电脑因操作不慎丢失了重要的数据 先不要给电脑重装系统 一般来说都是可以根据各种类型去找回这些文件的 这里就和大家介绍一下电脑恢复还原文件的各种操作方法吧 1 首先是U盘和内存卡类型的数据 u盘是我们经常使用的移动储存工具了 在对这些
  • 设计模式——网课学习总结

    面向对象 设计模式七大基本原则 单一职责原则 SRP 一个类的功能要单一 提高内聚性 方法要原子性 开放封闭原则 OCP 对扩展性开放 对修改封闭 最重要 总纲 里氏替换原则 LSP 子类继承父类 子类不要改变父类原有的方法 完成新的功能需
  • 概念:COW与MOR

    名词解释 COW 写时复制 MOR 读时合并 CopyOnWrite 思想 写时复制 CopyOnWrite 简称COW 思想是计算机程序设计领域中的一种通用优化策略 其核心思想是 如果有多个调用者 Callers 同时访问相同的资源 如内
  • 8个开源的后台管理系统推荐,用了都说好

    点击上方蓝字 关注我们 1 AG Admin AG Admin是国内首个基于Spring Cloud微服务化开发平台 具有统一授权 认证后台管理系统 其中包含具备用户管理 资源权限管理 网关API管理等多个模块 支持多业务系统并行开发 可以
  • 看这里!java架构师教学视频全百度云

    为了更好的梳理相关知识 咱们先看纯手绘知识体系图 1 1 Kafka知识体系大纲 由于我手绘这些知识体系大纲是用的xmind软件 无法上传 所以都以截图的形式展示 细节处不清楚 毕竟图片形式有限 1 2 RabbitMQ知识体系大纲 1 3
  • MQTT服务器搭建及客户端通信实例

    MQTT服务器 EMQX v3 客户端1 PC Windows10操作系统 客户端2 IOT BOARD RT Thread与正点原子联合开发的STM32L475核心芯片的开发板 1 搭建服务器 在EMQ官网https www emqx i
  • Django DRF全局开启模糊查询

    DjangoFilterBackend或者RestFrameworkFilterBackend在做指定字段查询时 默认为精确查询 如 api v1 brand brands name huawei可以查到 api v1 brand bran
  • 使用MMDetection训练自己的数据集

    github链接 OpenMMLab github com 官方文档 Prerequisites MMDetection 2 15 1 documentation MMDetection推荐大家最好还是在linux系统下使用 windows