利用MMSegmentation微调Mask2Former模型

2023-11-03

前言

  • 本文介绍了专用于语义分隔模型的pythonmmsegmentationgithub项目地址,运行环境为Kaggle notebookGPUP100
  • 针对环境配置、预训练模型推理、在西瓜数据集上微调新sota模型mask2former模型,数据说明
  • 由于西瓜数据集较小,我们最后在组织病理切片肾小球数据集上微调了mask2former模型,数据说明
  • 该教程有部分参考github项目MMSegmentation_Tutorials项目地址

环境配置

  • 跑通代码需要openmimmmsegmentationmmenginemmdetectionmmcv环境,mmcv环境在kaggle配置比较麻烦,需要预配置包,这里我将所有预配置包都打包好了,放到了数据集frozen-packages-mmdetection中,详情页
import IPython.display as display
!pip install -U openmim

!rm -rf mmsegmentation
!git clone https://github.com/open-mmlab/mmsegmentation.git
%cd mmsegmentation
!pip install -v -e .

!pip install "mmdet>=3.0.0rc4"

!pip install -q /kaggle/input/frozen-packages-mmdetection/mmcv-2.0.1-cp310-cp310-linux_x86_64.whl

!pip install wandb
display.clear_output()
  • 实测运行上述代码,在kaggle中可以达到运行项目需求,无报错(2023年7月13日)。
  • 导入常用基础包
import io
import os
import cv2
import glob
import time
import torch
import shutil
import mmcv
import wandb
import random
import mmengine
import numpy as np
from PIL import Image
from tqdm import tqdm
from mmengine import Config

import matplotlib.pyplot as plt
%matplotlib inline

from mmseg.datasets import cityscapes
from mmseg.utils import register_all_modules
register_all_modules()

from mmseg.datasets import CityscapesDataset
from mmengine.model.utils import revert_sync_batchnorm
from mmseg.apis import init_model, inference_model, show_result_pyplot

# 忽略警告
import warnings
warnings.filterwarnings('ignore')

display.clear_output()
  • 创建文件夹,用于放置数据集、模型预训练权重和模型推理输出
# 创建 checkpoint 文件夹,用于存放预训练模型权重文件
os.mkdir('checkpoint')

# 创建 outputs 文件夹,用于存放预测结果
os.mkdir('outputs')

# 创建 data 文件夹,用于存放图片和视频素材
os.mkdir('data')
  • 分别下载pspnet、segformer、mask2former在cityscapes上的预训练权重,并保存在checkpoint文件夹中
# 从Model Zoo预训练模型,下载并保存在 checkpoint 文件夹中
!wget https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth -P checkpoint
!wget https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_8x1_1024x1024_160k_cityscapes/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth -P checkpoint
!wget https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth -P checkpoint
display.clear_output()
  • 下载一些测试模型用的图片以及视频,并存放到data文件夹中。
# 伦敦街景图片
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220713-mmdetection/images/street_uk.jpeg -P data

# 上海驾车街景视频,视频来源:https://www.youtube.com/watch?v=ll8TgCZ0plk
!wget https://zihao-download.obs.cn-east-3.myhuaweicloud.com/detectron2/traffic.mp4 -P data

# 街拍视频,2022年3月30日
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220713-mmdetection/images/street_20220330_174028.mp4 -P data
display.clear_output()

图片推理

命令行推理

  • 使用命令行对图片进行推理,并使用PIL对结果进行可视化
  • 分别使用了pspnet模型和segformer模型进行推理
# pspnet模型
!python demo/image_demo.py \
        data/street_uk.jpeg \
        configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py \
        checkpoint/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth \
        --out-file outputs/B1_uk_pspnet.jpg \
        --device cuda:0 \
        --opacity 0.5

display.clear_output()
Image.open('outputs/B1_uk_pspnet.jpg')

请添加图片描述

# segformer模型
!python demo/image_demo.py \
        data/street_uk.jpeg \
        configs/segformer/segformer_mit-b5_8xb1-160k_cityscapes-1024x1024.py \
        checkpoint/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth \
        --out-file outputs/B1_uk_segformer.jpg \
        --device cuda:0 \
        --opacity 0.5
display.clear_output()
Image.open('outputs/B1_uk_segformer.jpg')

请添加图片描述

  • 可以看到其实segformer的效果比pspnet模型效果要好,基本上能将不同物体分割开。

API推理

  • 使用mmsegmentation的Python API进行图片推理
  • 使用mask2former模型推理,并利用matplotlib对结果进行可视化
img_path = 'data/street_uk.jpeg'
img_pil = Image.open(img_path)
# 模型 config 配置文件
config_file = 'configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py'

# 模型 checkpoint 权重文件
checkpoint_file = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'

model = init_model(config_file, checkpoint_file, device='cuda:0')

if not torch.cuda.is_available():
    model = revert_sync_batchnorm(model)

result = inference_model(model, img_path)
pred_mask = result.pred_sem_seg.data[0].detach().cpu().numpy()

display.clear_output()
img_bgr = cv2.imread(img_path)
plt.figure(figsize=(14, 8))
plt.imshow(img_bgr[:,:,::-1])
plt.imshow(pred_mask, alpha=0.55) # alpha 高亮区域透明度,越小越接近原图
plt.axis('off')
plt.savefig('outputs/B2-1.jpg')
plt.show()

请添加图片描述

  • mask2former作为sota模型,效果确实非常棒!

视频推理

命令行推理

  • 不推荐,速度很慢
!python demo/video_demo.py \
        data/street_20220330_174028.mp4 \
        configs/segformer/segformer_mit-b5_8xb1-160k_cityscapes-1024x1024.py \
        checkpoint/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth \
        --device cuda:0 \
        --output-file outputs/B3_video.mp4 \
        --opacity 0.5

API推理

  • mask2former模型使用API对视频进行推理
# 模型 config 配置文件
config_file = 'configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py'

# 模型 checkpoint 权重文件
checkpoint_file = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'

model = init_model(config_file, checkpoint_file, device='cuda:0')

if not torch.cuda.is_available():
    model = revert_sync_batchnorm(model)

display.clear_output()

input_video = 'data/street_20220330_174028.mp4'

temp_out_dir = time.strftime('%Y%m%d%H%M%S')
os.mkdir(temp_out_dir)
print('创建临时文件夹 {} 用于存放每帧预测结果'.format(temp_out_dir))

# 获取 Cityscapes 街景数据集 类别名和调色板
classes = cityscapes.CityscapesDataset.METAINFO['classes']
palette = cityscapes.CityscapesDataset.METAINFO['palette']

def pridict_single_frame(img, opacity=0.2):

    result = inference_model(model, img)

    # 将分割图按调色板染色
    seg_map = np.array(result.pred_sem_seg.data[0].detach().cpu().numpy()).astype('uint8')
    seg_img = Image.fromarray(seg_map).convert('P')
    seg_img.putpalette(np.array(palette, dtype=np.uint8))

    show_img = (np.array(seg_img.convert('RGB')))*(1-opacity) + img*opacity

    return show_img

# 读入待预测视频
imgs = mmcv.VideoReader(input_video)

prog_bar = mmengine.ProgressBar(len(imgs))

# 对视频逐帧处理
for frame_id, img in enumerate(imgs):

    ## 处理单帧画面
    show_img = pridict_single_frame(img, opacity=0.15)
    temp_path = f'{temp_out_dir}/{frame_id:06d}.jpg' # 保存语义分割预测结果图像至临时文件夹
    cv2.imwrite(temp_path, show_img)

    prog_bar.update() # 更新进度条

# 把每一帧串成视频文件
mmcv.frames2video(temp_out_dir, 'outputs/B3_video.mp4', fps=imgs.fps, fourcc='mp4v')

shutil.rmtree(temp_out_dir) # 删除存放每帧画面的临时文件夹
print('删除临时文件夹', temp_out_dir)

小样本数据集微调mask2former

  • 在西瓜语义分隔数据集上对模型进行微调

下载数据集

!rm -rf Watermelon87_Semantic_Seg_Mask.zip Watermelon87_Semantic_Seg_Mask

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/dataset/watermelon/Watermelon87_Semantic_Seg_Mask.zip

!unzip Watermelon87_Semantic_Seg_Mask.zip >> /dev/null # 解压

!rm -rf Watermelon87_Semantic_Seg_Mask.zip # 删除压缩包

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/watermelon_test1.jpg -P data

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/video_watermelon_2.mp4 -P data

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/video_watermelon_3.mov -P data

# 删除系统自动生成的多余文件
!find . -iname '__MACOSX'
!find . -iname '.DS_Store'
!find . -iname '.ipynb_checkpoints'

# 删除多余文件
!for i in `find . -iname '__MACOSX'`; do rm -rf $i;done
!for i in `find . -iname '.DS_Store'`; do rm -rf $i;done
!for i in `find . -iname '.ipynb_checkpoints'`; do rm -rf $i;done

# 验证多余文件已删除
!find . -iname '__MACOSX'
!find . -iname '.DS_Store'
!find . -iname '.ipynb_checkpoints'

display.clear_output()

可视化探索语义分割数据集

  • 可视化语义信息
# 指定单张图像路径
img_path = 'Watermelon87_Semantic_Seg_Mask/img_dir/train/04_35-2.jpg'
mask_path = 'Watermelon87_Semantic_Seg_Mask/ann_dir/train/04_35-2.png'

img = cv2.imread(img_path)
mask = cv2.imread(mask_path)

# 可视化原图叠加
plt.figure(figsize=(8, 8))
plt.imshow(img[:,:,::-1])
plt.imshow(mask[:,:,0], alpha=0.6) # alpha 高亮区域透明度,越小越接近原图
plt.axis('off')
plt.show()

请添加图片描述

定义Dataset和Pipeline

  • Dataset部分,可以设定数值对应的具体类别,以及不同类别的标注颜色。图像格式,是否忽略类别0
  • Pipeline部分,可以设定训练、验证的数据处理步骤。以及规定图像裁剪尺寸
custom_dataset = """
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset

@DATASETS.register_module()
class MyCustomDataset(BaseSegDataset):
    # 类别和对应的 RGB配色
    METAINFO = {
        'classes':['background', 'red', 'green', 'white', 'seed-black', 'seed-white'],
        'palette':[[127,127,127], [200,0,0], [0,200,0], [144,238,144], [30,30,30], [251,189,8]]
    }
    
    # 指定图像扩展名、标注扩展名
    def __init__(self,
                 seg_map_suffix='.png',   # 标注mask图像的格式
                 reduce_zero_label=False, # 类别ID为0的类别是否需要除去
                 **kwargs) -> None:
        super().__init__(
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
"""

with io.open('mmseg/datasets/MyCustomDataset.py', 'w', encoding='utf-8') as f:
    f.write(custom_dataset)
  • custom_dataset加入__init__.py文件
custom_init = """
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
from .ade import ADE20KDataset
from .basesegdataset import BaseSegDataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
from .dark_zurich import DarkZurichDataset
from .dataset_wrappers import MultiImageMixDataset
from .decathlon import DecathlonDataset
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .lip import LIPDataset
from .loveda import LoveDADataset
from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .synapse import SynapseDataset
from .MyCustomDataset import MyCustomDataset
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
                         BioMedical3DRandomCrop, BioMedical3DRandomFlip,
                         BioMedicalGaussianBlur, BioMedicalGaussianNoise,
                         BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
                         LoadBiomedicalAnnotation, LoadBiomedicalData,
                         LoadBiomedicalImageFromFile, LoadImageFromNDArray,
                         PackSegInputs, PhotoMetricDistortion, RandomCrop,
                         RandomCutOut, RandomMosaic, RandomRotate,
                         RandomRotFlip, Rerange, ResizeShortestEdge,
                         ResizeToMultiple, RGB2Gray, SegRescale)
from .voc import PascalVOCDataset

# yapf: enable
__all__ = [
    'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip',
    'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
    'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
    'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
    'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
    'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset',
    'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
    'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
    'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
    'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
    'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
    'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
    'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
    'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
    'SynapseDataset', 'MyCustomDataset'
]

"""

with io.open('mmseg/datasets/__init__.py', 'w', encoding='utf-8') as f:
    f.write(custom_init)
  • 定义数据集预处理通道
custom_pipeline = """
# 数据集路径
dataset_type = 'MyCustomDataset' # 数据集类名
data_root = 'Watermelon87_Semantic_Seg_Mask/' # 数据集路径(相对于mmsegmentation主目录)

# 输入模型的图像裁剪尺寸,一般是 128 的倍数,越小显存开销越少
crop_size = (640, 640)

# 训练预处理
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='RandomResize',
        scale=(2048, 1024),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]

# 测试预处理
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

# TTA后处理
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]

# 训练 Dataloader
train_dataloader = dict(
    batch_size=2,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='img_dir/train', seg_map_path='ann_dir/train'),
        pipeline=train_pipeline))

# 验证 Dataloader
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='img_dir/val', seg_map_path='ann_dir/val'),
        pipeline=test_pipeline))

# 测试 Dataloader
test_dataloader = val_dataloader

# 验证 Evaluator
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice', 'mFscore'])

# 测试 Evaluator
test_evaluator = val_evaluator
"""

with io.open('configs/_base_/datasets/custom_pipeline.py', 'w', encoding='utf-8') as f:
    f.write(custom_pipeline)

修改配置文件

  • 主要修改类别个数、预训练权重路径、初始化图片尺寸(一般为128的整数倍)、batch_size、缩放学习率(修改的比例是 base_lr_default * (your_bs / default_bs))、更改学习率衰减策略
  • 关于学习率:主要修改optimizer中的lr,不用修改optim_wrapper
  • 冻结模型的骨干网络,对mask2former来说可以加快训练
cfg = Config.fromfile('configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py')
dataset_cfg = Config.fromfile('configs/_base_/datasets/custom_pipeline.py')
cfg.merge_from_dict(dataset_cfg)
# 类别个数
NUM_CLASS = 6
# 单卡训练时,需要把 SyncBN 改成 BN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.crop_size = (640, 640)
cfg.model.data_preprocessor.size = cfg.crop_size

# 预训练模型权重
cfg.load_from = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'

# 模型 decode/auxiliary 输出头,指定为类别个数
cfg.model.decode_head.num_classes = NUM_CLASS
cfg.model.decode_head.loss_cls.class_weight = [1.0] * NUM_CLASS + [0.1]
cfg.model.backbone.frozen_stages = 4


# 训练 Batch Size
cfg.train_dataloader.batch_size = 2
cfg.test_dataloader = cfg.val_dataloader


cfg.optimizer.lr = cfg.optimizer.lr / 8

# 结果保存目录
cfg.work_dir = './work_dirs'

cfg.train_cfg.max_iters = 4000 # 训练迭代次数
cfg.train_cfg.val_interval = 50 # 评估模型间隔
cfg.default_hooks.logger.interval = 50 # 日志记录间隔
cfg.default_hooks.checkpoint.interval = 50 # 模型权重保存间隔
cfg.default_hooks.checkpoint.max_keep_ckpts = 2 # 最多保留几个模型权重
cfg.default_hooks.checkpoint.save_best = 'mIoU' # 保留指标最高的模型权重

cfg.param_scheduler[0].end = cfg.train_cfg.max_iters
# 随机数种子
cfg['randomness'] = dict(seed=0)

cfg.visualizer.vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
  • 保存配置文件
cfg.dump('custom_mask2former.py')
  • 开始训练
!python tools/train.py custom_mask2former.py
  • 选取最优模型,测试模型精度
# 取最佳模型权重
best_pth = glob.glob('work_dirs/best_mIoU*.pth')[0]
# 测试精度
!python tools/test.py custom_mask2former.py '{best_pth}'
  • 输出:
+------------+-------+-------+-------+--------+-----------+--------+
|   Class    |  IoU  |  Acc  |  Dice | Fscore | Precision | Recall |
+------------+-------+-------+-------+--------+-----------+--------+
| background | 98.55 | 99.12 | 99.27 | 99.27  |   99.42   | 99.12  |
|    red     | 96.54 | 98.83 | 98.24 | 98.24  |   97.65   | 98.83  |
|   green    | 94.37 | 96.08 |  97.1 |  97.1  |   98.14   | 96.08  |
|   white    | 85.96 | 92.67 | 92.45 | 92.45  |   92.24   | 92.67  |
| seed-black | 81.98 | 90.87 |  90.1 |  90.1  |   89.34   | 90.87  |
| seed-white | 65.57 | 69.98 | 79.21 | 79.21  |   91.24   | 69.98  |
+------------+-------+-------+-------+--------+-----------+--------+

可视化训练指标

在这里插入图片描述

肾小球数据集微调模型

  • 在单类别数据集(组织病理切片肾小球)上微调mask2former模型
  • 首先清空工作目录、data文件夹和outputs文件
# 清空工作目录
!rm -r work_dirs/*
# 清空data文件夹
!rm -r data/*
# 清空outputs文件夹
!rm -r outputs/*

可视化探索语义分割数据集

# 指定图像和标注路径
PATH_IMAGE = '/kaggle/input/glomeruli-hubmap-external-1024x1024/images_1024'
PATH_MASKS = '/kaggle/input/glomeruli-hubmap-external-1024x1024/masks_1024'

mask = cv2.imread('/kaggle/input/glomeruli-hubmap-external-1024x1024/masks_1024/VUHSK_1762_29.png')
# 查看类别
np.unique(mask)
  • 输出
array([0, 1], dtype=uint8)
  • 可视化语义分割信息
# n行n列可视化
n = 5

# 标注区域透明度,透明度越小,越接近原图
opacity = 0.65

fig, axes = plt.subplots(nrows=n, ncols=n, sharex=True, figsize=(12,12))

for i, file_name in enumerate(os.listdir(PATH_IMAGE)[:n**2]):
    
    # 载入图像和标注
    img_path = os.path.join(PATH_IMAGE, file_name)
    mask_path = os.path.join(PATH_MASKS, file_name.split('.')[0]+'.png')
    img = cv2.imread(img_path)
    mask = cv2.imread(mask_path)
    
    # 可视化
    axes[i//n, i%n].imshow(img[:,:,::-1])
    axes[i//n, i%n].imshow(mask[:,:,0], alpha=opacity)
    axes[i//n, i%n].axis('off') # 关闭坐标轴显示
fig.suptitle('Image and Semantic Label', fontsize=20)
plt.tight_layout()
plt.savefig('outputs/C2-1.jpg')
plt.show()

请添加图片描述

分割训练集与测试集

  • 新建各类训练、验证文件夹
# 新建图片训练、验证文件夹
!mkdir -p data/images/train
!mkdir -p data/images/val

# 新建mask训练、验证文件夹
!mkdir -p data/masks/train
!mkdir -p data/masks/val
  • 随机打乱数据,并按照90%训练集、10%测试集分割
def copy_file(og_images, og_masks, tr_images, tr_masks, thor):
    # 获取源文件夹中的所有文件名
    file_names = os.listdir(og_images)
    
    # 随机打乱文件名列表
    random.shuffle(file_names)
    
    # 计算分割点
    split_index = int(thor * len(file_names))
    
    # 复制训练集文件
    for file_name in file_names[:split_index]:
        og_image = os.path.join(og_images, file_name)
        og_mask = os.path.join(og_masks, file_name)
        tr_image = os.path.join(tr_images, 'train', file_name)
        tr_mask = os.path.join(tr_masks, 'train', file_name)
        shutil.copyfile(og_image, tr_image)
        shutil.copyfile(og_mask, tr_mask)

    # 复制验证集文件
    for file_name in file_names[split_index:]:
        og_image = os.path.join(og_images, file_name)
        og_mask = os.path.join(og_masks, file_name)
        tr_image = os.path.join(tr_images, 'val', file_name)
        tr_mask = os.path.join(tr_masks, 'val', file_name)
        shutil.copyfile(og_image, tr_image)
        shutil.copyfile(og_mask, tr_mask)

og_images = '/kaggle/input/glomeruli-hubmap-external-1024x1024/images_1024'
og_masks = '/kaggle/input/glomeruli-hubmap-external-1024x1024/masks_1024'

tr_images = 'data/images'
tr_masks = 'data/masks'

copy_file(og_images, og_masks, tr_images, tr_masks, 0.9)

重新定义Dataset和Pipeline

  • 主要是修改类别及对应RGB配色
  • 以及dataload的路径信息
custom_dataset = """
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset

@DATASETS.register_module()
class MyCustomDataset(BaseSegDataset):
    # 类别和对应的RGB配色
    METAINFO = {
        'classes':['normal','sclerotic'],
        'palette':[[127,127,127],[251,189,8]]
    }
    
    # 指定图像扩展名、标注扩展名
    def __init__(self,img_suffix='.png',
                 seg_map_suffix='.png',   # 标注mask图像的格式
                 reduce_zero_label=False, # 类别ID为0的类别是否需要除去
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
"""

with io.open('mmseg/datasets/MyCustomDataset.py', 'w', encoding='utf-8') as f:
    f.write(custom_dataset)
custom_init = """
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
from .ade import ADE20KDataset
from .basesegdataset import BaseSegDataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
from .dark_zurich import DarkZurichDataset
from .dataset_wrappers import MultiImageMixDataset
from .decathlon import DecathlonDataset
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .lip import LIPDataset
from .loveda import LoveDADataset
from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .synapse import SynapseDataset
from .MyCustomDataset import MyCustomDataset
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
                         BioMedical3DRandomCrop, BioMedical3DRandomFlip,
                         BioMedicalGaussianBlur, BioMedicalGaussianNoise,
                         BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
                         LoadBiomedicalAnnotation, LoadBiomedicalData,
                         LoadBiomedicalImageFromFile, LoadImageFromNDArray,
                         PackSegInputs, PhotoMetricDistortion, RandomCrop,
                         RandomCutOut, RandomMosaic, RandomRotate,
                         RandomRotFlip, Rerange, ResizeShortestEdge,
                         ResizeToMultiple, RGB2Gray, SegRescale)
from .voc import PascalVOCDataset

# yapf: enable
__all__ = [
    'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip',
    'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
    'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
    'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
    'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
    'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset',
    'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
    'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
    'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
    'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
    'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
    'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
    'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
    'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
    'SynapseDataset', 'MyCustomDataset'
]

"""

with io.open('mmseg/datasets/__init__.py', 'w', encoding='utf-8') as f:
    f.write(custom_init)
  • 定义数据预处理管道
custom_pipeline = """
# 数据集路径
dataset_type = 'MyCustomDataset' # 数据集类名
data_root = 'data/' # 数据集路径(相对于mmsegmentation主目录)

# 输入模型的图像裁剪尺寸,一般是 128 的倍数,越小显存开销越少
crop_size = (640, 640)

# 训练预处理
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='RandomResize',
        scale=(2048, 1024),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]

# 测试预处理
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

# TTA后处理
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]

# 训练 Dataloader
train_dataloader = dict(
    batch_size=2,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/train', seg_map_path='masks/train'),
        pipeline=train_pipeline))

# 验证 Dataloader
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/val', seg_map_path='masks/val'),
        pipeline=test_pipeline))

# 测试 Dataloader
test_dataloader = val_dataloader

# 验证 Evaluator
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice', 'mFscore'])

# 测试 Evaluator
test_evaluator = val_evaluator
"""

with io.open('configs/_base_/datasets/custom_pipeline.py', 'w', encoding='utf-8') as f:
    f.write(custom_pipeline)

修改配置文件

cfg = Config.fromfile('configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py')
dataset_cfg = Config.fromfile('configs/_base_/datasets/custom_pipeline.py')
cfg.merge_from_dict(dataset_cfg)
  • 更改配置文件
# 类别个数
NUM_CLASS = 2
# 单卡训练时,需要把 SyncBN 改成 BN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.crop_size = (640, 640)
cfg.model.data_preprocessor.size = cfg.crop_size

# 预训练模型权重
cfg.load_from = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'

# 模型 decode/auxiliary 输出头,指定为类别个数
cfg.model.decode_head.num_classes = NUM_CLASS
cfg.model.decode_head.loss_cls.class_weight = [1.0] * NUM_CLASS + [0.1]
cfg.model.backbone.frozen_stages = 4


# 训练 Batch Size
cfg.train_dataloader.batch_size = 2
cfg.test_dataloader = cfg.val_dataloader


cfg.optimizer.lr = cfg.optimizer.lr / 8

# 结果保存目录
cfg.work_dir = './work_dirs'

cfg.train_cfg.max_iters = 40000 # 训练迭代次数
cfg.train_cfg.val_interval = 500 # 评估模型间隔
cfg.default_hooks.logger.interval = 50 # 日志记录间隔
cfg.default_hooks.checkpoint.interval = 2500 # 模型权重保存间隔
cfg.default_hooks.checkpoint.max_keep_ckpts = 2 # 最多保留几个模型权重
cfg.default_hooks.checkpoint.save_best = 'mIoU' # 保留指标最高的模型权重

# 随机数种子
cfg['randomness'] = dict(seed=0)

cfg.visualizer.vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
  • 保存配置文件,并开始训练
cfg.dump('custom_mask2former.py')
!python tools/train.py custom_mask2former.py

可视化训练指标

在这里插入图片描述

评估模型以及测试推理速度

  • 评估模型精度
# 取最佳模型权重
best_pth = glob.glob('work_dirs/best_mIoU*.pth')[0]
# 测试精度
!python tools/test.py custom_mask2former.py '{best_pth}'
  • 输出:
+-----------+-------+-------+-------+--------+-----------+--------+
|   Class   |  IoU  |  Acc  |  Dice | Fscore | Precision | Recall |
+-----------+-------+-------+-------+--------+-----------+--------+
|   normal  | 99.74 | 99.89 | 99.87 | 99.87  |   99.86   | 99.89  |
| sclerotic | 86.41 | 91.87 | 92.71 | 92.71  |   93.57   | 91.87  |
+-----------+-------+-------+-------+--------+-----------+--------+
  • 测试模型推理速度
# 测试FPS
!python tools/analysis_tools/benchmark.py custom_mask2former.py '{best_pth}'
  • 输出:
Done image [50 / 200], fps: 2.24 img / s
Done image [100/ 200], fps: 2.24 img / s
Done image [150/ 200], fps: 2.24 img / s
Done image [200/ 200], fps: 2.24 img / s
Overall fps: 2.24 img / s

Average fps of 1 evaluations: 2.24
The variance of 1 evaluations: 0.0
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

利用MMSegmentation微调Mask2Former模型 的相关文章

  • 语义分割研究现状

    以语义分割热门的数据集Cityscapes的精度作为参考 比较当前语义分割网络效果 可以通过ICNet中的这张图来说明目前大多数方法的精度以及速度 目前MIOU超过80的有PSPNet ResNet38 PSPNet DUC 以及DANet
  • AutoDL跑pycharm代码

    参考文献 AutoDL帮助文档 Pycharm连接远程GPU服务器跑深度学习 哔哩哔哩 bilibili 环境包的安装在linux环境下载非常方便 安装apex 重点是将路径转换正确 参考文献 详解Apex的安装和使用教程 花开山岗红艳艳的
  • 【数据集】——SBD数据集下载链接

    简介 SBD Dataset 是一个语义边界数据集 其包含来自 PASCAL VOC 2011 数据集中 11355 张图片的注释 这些图片均基于 Amazon Mechanical Turk 其中分割之间的冲突均为手动解决 此外 每张图像
  • (CVPR2019)图像语义分割(18) DANet-集成双路注意力机制的场景分割网络

    论文地址 Dual Attention Network for Scene Segmentation 工程地址 github链接 1 介绍 该论文提出新型的场景分割网络DANet 利用自注意力机制进行丰富语义信息的捕获 在带有空洞卷积的FC
  • 基于深度学习的图像分割总结

    一 图像分割类别 随着深度学习的发展 在分割任务中出现了许多优秀的网络 根据实际分割应用任务的不同 可以大致将分割分为三个研究方向 语义分割 实例分割 全景分割 这三种分割在某种意义上是具有一定的联系的 语义分割 像素级别的语义分割 对图像
  • Fully Convolutional Adaptation Networks for Semantic Segmentation

    参考 论文解析之 Fully Convolutional Adaptation Networks for Semantic Segmentation 云 社区 腾讯云 论文网址 Fully Convolutional Adaptation
  • 动手学深度学习_全卷积网络 FCN

    全卷积网络 fully convolutional network FCN 顾名思义 网络中完全使用卷积而不再使用全联接网络 全卷积网络之所以能把输入图片经过卷积后在进行尺寸上的还原 就是利用转置卷积实现的 因此 输出的类别预测与输入图像在
  • COCO数据集的下载、介绍及如何使用(数据载入及数据增广,含代码)

    如何使用COCO数据集 COCO数据集可以说是语义分割等计算机视觉任务中应用较为广泛的一个数据集 具体可以应用到物体识别 语义分割及目标检测等方面 我是在做语义分割方面任务时用到了COCO数据集 但本文主要讲解的是数据载入方面 因此可以通用
  • 【语义分割】10、ISNet: Integrate Image-Level and Semantic-Level Context for Semantic Segmentation

    出处 ICCV2021 文章目录 一 背景 二 动机 三 方法 3 1 整体过程 3 2 Image Level Context Module 3 3 Semantic Level Context Module 3 4 Loss Funct
  • 【计算机视觉

    文章目录 一 SqueezeNet 二 Inception v3 三 Visual Geometry Group 19 Layer CNN 四 MobileNetV1 五 Data efficient Image Transformer 六
  • 训练PyTorch模型遇到显存不足的情况怎么办

    在跑代码的过程中 遇到了这个问题 当前需要分配的显存在600MiB以下 RuntimeError CUDA out of memory Tried to allocate 60 00 MiB GPU 0 10 76 GiB total ca
  • 语义分割系列26-VIT+SETR——Transformer结构如何在语义分割中大放异彩

    SETR Rethinking Semantic Segmentation from a Sequence to Sequence Perspectivewith Transformers 重新思考语义分割范式 使用Transformer实
  • 遥感影像语义分割:数据集制作

    遥感影像语义分割 数据集制作 一 标签标注工具及注意事项 二 影像分块及代码 目前已经有一些已经关于遥感影像解译的公开数据集 我们可以拿这些数据来做深度学习模型训练 但是在实际业务中 我们需要针对特定的需求制作自己的数据集 在这里记录一下做
  • 语义分割算法汇总(长期更新)

    语义分割算法汇总 记录一下各类语义分割算法 便于自己学习 由DFANet Deep Feature Aggregation for Real Time Semantic Segmentation开始 在文章中 作者说明了在Cityscape
  • 【计算机视觉

    文章目录 一 PROMISE12 二 BraTS 2015 三 LIP Look into Person 四 BigEarthNet 五 Stanford Background Standford Background Dataset 六
  • 【语义分割】13、SegNeXt

    文章目录 一 背景 二 方法 2 1 Convolutional Encoder 2 2 Decoder 三 效果 论文 SegNeXt Rethinking Convolutional Attention Design for Seman
  • 憨批的语义分割1——基于Mobile模型的segnet讲解

    憨批的语义分割1 基于Mobile模型的segnet讲解 学习前言 什么是Segnet模型 segnet模型的代码实现 1 主干模型Mobilenet 2 segnet的Decoder解码部分 代码测试 学习前言 最近开始设计新的领域啦 语
  • MMSegmention官方文档阅读系列之三(MMSegmentation 算法库目录结构、了解配置文件信息)

    1 MMSegmentation 算法库目录结构的主要部分 1 mmsegmentation configs 配置文件 base 基配置文件 datasets 数据集相关配置文件 models 模型相关配置文件 schedules 训练日程
  • 图像分割2021

    cvpr2022总结 CVPR 2022 图像分割论文大盘点 大林兄的博客 CSDN博客 图像分割最新论文 尽管近年来实例分割取得了长足的进步 但如何设计具有实时性的高精度算法仍然是一个挑战 本文提出了一种实时实例分割框架OrienMask
  • 憨批的语义分割重制版2——语义分割评价指标mIOU的计算

    憨批的语义分割重制版2 语义分割评价指标mIOU的计算 注意事项 学习前言 什么是mIOU mIOU的计算 1 计算混淆矩阵 2 计算IOU 3 计算mIOU 计算miou 注意事项 这是针对重构了的语义分割网络 而不是之前的那个 所以不要

随机推荐

  • Aspose工具实现word和ppt转pdf功能及遇到的一些问题

    Aspose工具包从word和ppt转到pdf的实现过程 直接放项目地址 说一下实现过程中遇到的坑 直接放项目地址 https github com lichangliu1098 File2Pdf 说一下实现过程中遇到的坑 jar包的引入
  • Zabbix部署详细步骤

    以下是在Ubuntu上安装Zabbix的详细步骤 1 更新系统 使用以下命令更新Ubuntu系统 sudo apt get update sudo apt get upgrade 2 安装依赖项 在安装Zabbix之前 需要先安装一些依赖项
  • 教你快速上手Flex弹性盒布局(容器属性)

    目录 简介 一 Flex布局语法 1 1 display flex 二 Flex属性 三 容器属性 3 1 flex direction 3 2 flex direction row 3 3 flex direction row rever
  • 【华为OD统一考试B卷

    在线OJ 已购买本专栏用户 请私信博主开通账号 在线刷题 运行出现 Runtime Error 0Aborted 请忽略 2023年5月份 华为官方已经将的 2022 0223Q 1 2 3 4 统一修改为 2023A卷和2023B卷 你收
  • Tachyon内存文件系统

    Tachyon内存文件系统 Tachyon是以内存为中心的分布式文件系统 拥有高性能和容错能力 能够为集群框架 如Spark MapReduce 提供可靠的内存级速度的文件共享服务 从软件栈的层次来看 Tachyon是位于现有大数据计算框架
  • 使用GitHub的一些小知识合集

    文章目录 一 FastGithub 1 稳定可靠的github加速神器 FastGithub 2 github加速神器 FastGithub 二 README md添加图片 1 怎么给README md添加图片 两种方法 图文教程 2 在R
  • OpenCV代码提取:morphologyEx函数的实现

    Morphological Operations A set of operations that process images based on shapes Morphological operations apply a struct
  • keil软件安装与破解

    目录 目录 下载 安装 破解 下载 学习51单片机必要的开发工具是 KEIL C51 下载密码dsfs 安装 按照一般安装软件顺序即可 此处随便填写 破解 注意 需以管理员身份 运行 按照如下图片步骤 弹出如下对话框 再打开软件按如下操作
  • 解决pycharm错误:Error updating package list: connect timed out解决

    方法是在 Manage Repositories 中 修改数据来源 默认的是 https pypi python org simple 我们可将其替换为如下的几个数据来源 这些都是国内的pip镜像 清华 https pypi tuna ts
  • Vue3 emits选项将Emit派发事件可以对参数进行验证。

    Vue官方建议我们在组件中所有的emit事件都能在组件的emits选项中声明 emits参数有俩种形式对象和数组 对象里面可以配置带校验emit事件 为null的时候代表不校验 校验的时候 会把emit事件的参数传到校验函数的参数里面 当校
  • Python第三方库之MedPy

    文章目录 1 MedPy简介 2 MedPy安装 3 MedPy常用函数 3 1 medpy io load image 3 2 medpy metric binary dc result reference 3 3 medpy metri
  • < Linux >:环境变量

    目录 环境变量 常见的环境变量 基本概念 查看环境变量内容的方法 测试环境变量PATH 与环境变量相关的命令 Linux操作系统下C C 程序代码中获取环境变量的方式 环境变量的组织方式 环境变量通常具有全局属性 环境变量 问题 注意 可执
  • JavaScript基础语言

    1 JavaScript采用Unicode字符集编写的 区分大小写 但HTML不区分大小写 与JavaScript同名的标签和属性 可以大写也可以小写 2 JavaScript存在两种形式的注释 行尾注释 和 多行注释 3 标识符就是一个名
  • 光束平差法(Bundle Adjust)

    光束平差法 代价函数 代价函数求解 Levenberg Marquardt方法 代码实现 流程图 光束平差法 采用光束平差法对射影空间下的多个相机运动矩阵及非编码元三维结构进行优化 光束平差法一般在各种重建算法的最后一步使用 这种优化方法的
  • 虚幻官方项目《CropOut》技术解析 之 程序化岛屿生成器(IslandGenerator)

    开个新坑详细分析一下虚幻官方发布的 CropOut 文章会同步发布到我在知乎 CSDN的专栏里 文章目录 概要 Create Island 几何体生成部分 随机种子 Step 1 Step 2 Step 3 Step 4 Step 5 St
  • Python—爬虫之BeautifulSoup模块(解析—提取数据)

    Python 爬虫之BeautifulSoup模块 解析 提取数据 安装BeautifulSoup模块 解析数据 提取数据 find 提取出满足条件的第一个数据 find all 提取出满足条件所有数据 Tag 对象的属性和方法 安装Bea
  • .NET混淆器 Dotfuscator保护机制——重命名

    Dotfuscator是一个 NET的Obfuscator 它提供企业级的应用程序保护 大大降低了盗版 知识产权盗窃和篡改的风险 Dotfuscator的分层混淆 加密 水印 自动失效 防调试 防篡改 报警和防御技术 为世界各地成千上万的应
  • Streamlit 讲解专栏(十二):数据可视化-图表绘制详解(下)

    文章目录 1 前言 2 使用st vega lite chart绘制Vega Lite图表 2 1 示例1 绘制散点图 2 2 示例2 自定义主题样式 3 使用st plotly chart函数创建Plotly图表 3 1 st plotl
  • WebSocket(一) -- 原理详解

    1 什么是websocket WebSocket是HTML5下一种新的协议 websocket协议本质上是一个基于tcp的协议 它实现了浏览器与服务器全双工通信 能更好的节省服务器资源和带宽并达到实时通讯的目的 Websocket是一个持久
  • 利用MMSegmentation微调Mask2Former模型

    前言 本文介绍了专用于语义分隔模型的python库mmsegmentation github项目地址 运行环境为Kaggle notebook GPU为P100 针对环境配置 预训练模型推理 在西瓜数据集上微调新sota模型mask2for