学习Segformer语义分割模型并训练测试cityscapes数据集

2023-10-27

官方的segformer源码是基于MMCV框架,整体包装较多,自己不便于阅读和学习,我这里使用的是Bubbliiiing大佬github复现的segformer版本。
Bubbliiiing大佬代码下载链接:

https://github.com/bubbliiiing/segformer-pytorch

大佬的代码很优秀简练,注释也很详细,代码里采用的是VOC数据集的格式,因此只需要把cityscapes数据格式修改即可。

一、Segformer模型结构

Segformer特点:transformer + 特征融合 + 轻量级MLP + 选择3*3卷积并舍弃位置编码
在这里插入图片描述

1.OverlapPatchEmbed模块

分割输入图像,使用卷积操作将输入图像分成大小为 patch_size 的块,并使用步幅为 stride 移动这些块以创建重叠块。然后对每个块进行一维向量化,并通过标准化层进行标准化。该模块的输出包含一个形状为 (B, N, C) 的张量,其中 B 是批大小,N 是每个块中像素数量的数量,C 是嵌入维度。此外,该模块还返回 H W,这是输入图像的大小,因为在解码时需要了解原始图像的大小。

class OverlapPatchEmbed(nn.Module):
    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size  = (patch_size, patch_size) #7*7
        self.proj   = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm   = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, H, W

2.Self-Attention模块

关于Self-Attention原理,可以去看这个大佬的文章,讲的很详细:https://zhuanlan.zhihu.com/p/410776234
核心为这个公式:在这里插入图片描述Segformer中做了些改进。

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim        = dim
        self.num_heads  = num_heads
        head_dim        = dim // num_heads
        self.scale      = qk_scale or head_dim ** -0.5

        self.q          = nn.Linear(dim, dim, bias=qkv_bias)
        
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr     = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm   = nn.LayerNorm(dim)
        self.kv         = nn.Linear(dim, dim * 2, bias=qkv_bias)
        
        self.attn_drop  = nn.Dropout(attn_drop)
        
        self.proj       = nn.Linear(dim, dim)
        self.proj_drop  = nn.Dropout(proj_drop)

        self.apply(self._init_weights)

    def forward(self, x, H, W):
        B, N, C = x.shape
        # bs, 16384, 32 => bs, 16384, 32 => bs, 16384, 8, 4 => bs, 8, 16384, 4
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            # bs, 16384, 32 => bs, 32, 128, 128
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            # bs, 32, 128, 128 => bs, 32, 16, 16 => bs, 256, 32
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            # bs, 256, 32 => bs, 256, 64 => bs, 256, 2, 8, 4 => 2, bs, 8, 256, 4
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        # bs, 8, 16384, 4 @ bs, 8, 4, 256 => bs, 8, 16384, 256 
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # bs, 8, 16384, 256  @ bs, 8, 256, 4 => bs, 8, 16384, 4 => bs, 16384, 32
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # bs, 16384, 32 => bs, 16384, 32
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

3.MixFFN模块

在这里插入图片描述
不同于VIT,segformer舍弃了位置编码,使用3x3的卷积构建了MixFFN模块。

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
        super().__init__()
        out_features    = out_features or in_features
        hidden_features = hidden_features or in_features
        
        self.fc1    = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act    = act_layer()
        
        self.fc2    = nn.Linear(hidden_features, out_features)
        
        self.drop   = nn.Dropout(drop)

        self.apply(self._init_weights)
        
    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

4.拼接融合与MLP解码

这部分就是把前面的拼接然后输出。
在这里插入图片描述

    def forward(self, inputs):
        c1, c2, c3, c4 = inputs

        ############## MLP decoder on C1-C4 ###########
        n, _, h, w = c4.shape
        
        _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])
        _c4 = F.interpolate(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])
        _c3 = F.interpolate(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])
        _c2 = F.interpolate(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])

        _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))

        x = self.dropout(_c)
        x = self.linear_pred(x)

        return x

二、cityscapes代码修改

1.数据集文件夹格式

在这里插入图片描述
这里数据集标签图片需要为灰度图或者八位彩图,标签的每个像素点的值就是这个像素点所属的种类。
因此,使用cityscapes几个标签中的 _labelIds.png标签。

数据集划分按自己需求修改voc_annotation.py文件
在这里插入图片描述

2.修改dataloader.py文件

原本的这个标签中的类别是0到33和-1,我做的19类别分割。修改dataloader.py文件:

我这里直接复制了我之前使用过的encode_target内容加入进去:

CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
                                                     'has_instances', 'ignore_in_eval', 'color'])
    classes = [
        CityscapesClass('unlabeled',            0, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('ego vehicle',          1, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('rectification border', 2, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('out of roi',           3, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('static',               4, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('dynamic',              5, 19, 'void', 0, False, True, (111, 74, 0)),
        CityscapesClass('ground',               6, 19, 'void', 0, False, True, (81, 0, 81)),
        CityscapesClass('road',                 7, 0, 'flat', 1, False, False, (128, 64, 128)),
        CityscapesClass('sidewalk',             8, 1, 'flat', 1, False, False, (244, 35, 232)),
        CityscapesClass('parking',              9, 19, 'flat', 1, False, True, (250, 170, 160)),
        CityscapesClass('rail track',           10, 19, 'flat', 1, False, True, (230, 150, 140)),
        CityscapesClass('building',             11, 2, 'construction', 2, False, False, (70, 70, 70)),
        CityscapesClass('wall',                 12, 3, 'construction', 2, False, False, (102, 102, 156)),
        CityscapesClass('fence',                13, 4, 'construction', 2, False, False, (190, 153, 153)),
        CityscapesClass('guard rail',           14, 19, 'construction', 2, False, True, (180, 165, 180)),
        CityscapesClass('bridge',               15, 19, 'construction', 2, False, True, (150, 100, 100)),
        CityscapesClass('tunnel',               16, 19, 'construction', 2, False, True, (150, 120, 90)),
        CityscapesClass('pole',                 17, 5, 'object', 3, False, False, (153, 153, 153)),
        CityscapesClass('polegroup',            18, 19, 'object', 3, False, True, (153, 153, 153)),
        CityscapesClass('traffic light',        19, 6, 'object', 3, False, False, (250, 170, 30)),
        CityscapesClass('traffic sign',         20, 7, 'object', 3, False, False, (220, 220, 0)),
        CityscapesClass('vegetation',           21, 8, 'nature', 4, False, False, (107, 142, 35)),
        CityscapesClass('terrain',              22, 9, 'nature', 4, False, False, (152, 251, 152)),
        CityscapesClass('sky',                  23, 10, 'sky', 5, False, False, (70, 130, 180)),
        CityscapesClass('person',               24, 11, 'human', 6, True, False, (220, 20, 60)),
        CityscapesClass('rider',                25, 12, 'human', 6, True, False, (255, 0, 0)),
        CityscapesClass('car',                  26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
        CityscapesClass('truck',                27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
        CityscapesClass('bus',                  28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
        CityscapesClass('caravan',              29, 19, 'vehicle', 7, True, True, (0, 0, 90)),
        CityscapesClass('trailer',              30, 19, 'vehicle', 7, True, True, (0, 0, 110)),
        CityscapesClass('train',                31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
        CityscapesClass('motorcycle',           32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
        CityscapesClass('bicycle',              33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
        CityscapesClass('license plate',        -1, 19, 'vehicle', 7, False, True, (0, 0, 142)),
    ]

    id_to_train_id = np.array([c.train_id for c in classes])
    def encode_target(cls, png):
        return cls.id_to_train_id[np.array(png)]

同时修改def getitem(self, index)函数:
修改一下split,原本的voc的标签和图像名称一样,加个image_name,然后加个png = self.encode_target(png)


    def __getitem__(self, index):
        annotation_line = self.annotation_lines[index]
        name            = annotation_line.split()[0]

        #-------------------------------#
        #   从文件中读取图像
        #-------------------------------#
        image_name = annotation_line.split('_gtFine_labelIds')[0] + '_leftImg8bit'
        jpg = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/JPEGImages"), image_name + ".png"))
        #jpg         = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/JPEGImages"), name + ".png"))
        png         = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/SegmentationClass"), name + ".png"))
        #-------------------------------#
        #   数据增强
        #-------------------------------#
        jpg, png    = self.get_random_data(jpg, png, self.input_shape, random = self.train)

        jpg         = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1])
        png         = np.array(png)
        png = self.encode_target(png)
        #png[png >= self.num_classes] = self.num_classes
        #-------------------------------------------------------#
        #   转化成one_hot的形式
        #   在这里需要+1是因为voc数据集有些标签具有白边部分
        #   我们需要将白边部分进行忽略,+1的目的是方便忽略。
        #-------------------------------------------------------#
        seg_labels  = np.eye(self.num_classes + 1)[png.reshape([-1])]
        seg_labels  = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1))

        return jpg, png, seg_labels

3.获取RGB预测图

get_miou.py文件中生成的图为灰度图,看到的是近似全黑的图。如果想要预测出来的是RGB图,将预测结果中每个像素的类别转换成RGB颜色值。因此另外写了一个映射函数,通过定义一个颜色映射表,将每个类别映射到一个RGB颜色值,并输出保存。

def CityscapesLABELtoRGB():
    # 定义RGB颜色映射关系
    color_map = {
        0: [128, 64, 128],
        1: [244, 35, 232],
        2: [70, 70, 70],
        3: [102, 102, 156],
        4: [190, 153, 153],
        5: [153, 153, 153],
        6: [250, 170, 30],
        7: [220, 220, 0],
        8: [107, 142, 35],
        9: [152, 251, 152],
        10: [70, 130, 180],
        11: [220, 20, 60],
        12: [255, 0, 0],
        13: [0, 0, 142],
        14: [0, 0, 70],
        15: [0, 60, 100],
        16: [0, 80, 100],
        17: [0, 0, 230],
        18: [119, 11, 32],
        19: [0, 0, 0]
    }

    # 加载类别标签图像
    label_path = "miou_out/detection-results"
    rgb_folder_path = "RGB"
    for file_name in os.listdir(label_path):
        # 加载类别标签图像
        rgb_path = os.path.join(label_path, file_name)
        img = Image.open(rgb_path)
        label_arr = np.array(img)

        # 将类别标签转换为RGB标签
        rgb_arr = np.zeros((label_arr.shape[0], label_arr.shape[1], 3), dtype=np.uint8)
        for key, value in color_map.items():
            rgb_arr[label_arr == key] = value

        # 将RGB标签保存为PNG图像

        rgb_path = os.path.join(rgb_folder_path, os.path.splitext(file_name)[0] + ".png")
        label_img = Image.fromarray(rgb_arr)
        label_img.save(rgb_path, "PNG", quality=100, bitdepth=8)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

学习Segformer语义分割模型并训练测试cityscapes数据集 的相关文章

随机推荐

  • 3.2面向对象

    面向对象与面向过程的区别 面向过程指的是工作过程如何执行 而面向对象指的是工作该让谁来完成 面向对象三大思想 OOA 面向对象分析 Object Oriented Analysis OOD 面向对象设计 Object Oriented De
  • 剑指 Offer 15. 二进制中1的个数(java+python)

    编写一个函数 输入是一个无符号整数 以二进制串的形式 返回其二进制表达式中数字位数为 1 的个数 也被称为 汉明重量 提示 请注意 在某些语言 如 Java 中 没有无符号整数类型 在这种情况下 输入和输出都将被指定为有符号整数类型 并且不
  • 一万字关于java数据结构堆的讲解,让你从入门到精通

    目录 java类和接口总览 队列 Queue 1 概念 2 队列的使用 以下是一些常用的队列操作 1 入队操作 2 出队操作 3 判断队列是否为空 4 获取队列大小 5 其它 优先级队列 堆 1 优先级队列概念 Java中的Priority
  • uvicorn启动fastapi项目实现热部署

    一 创建在不同的文件中 通过import引入 将app对象定义在一个模块中 如app py 在主模块 如main py 中导入app并运行 app py from fastapi import FastAPI app FastAPI mai
  • 真题详解(关系模型)-软件设计(六十六)

    真题详解 ICMP 软件设计 六十五 https blog csdn net ke1ying article details 130475620 2017年下半年 内存按字节编址 若存储容量为32K 8bit的存储芯片构成地址从A0000H
  • C++基础知识 - 子类的析构函数

    子类的析构函数 注意 为了防止内存泄露 最好是在基类析构函数上添加virtual关键字 使基类析构函数为虚函数 目的在于 当使用delete释放基类指针时 会实现动态的析构 如果基类指针指向的是基类对象 那么只调用基类的析构函数 如果基类指
  • 【DSP】TMS320F28335的ADC模块

    一 功能说明 12位内建采样保持的模数转换器 模拟输入电平 0 3V 16个转换通道 最快转换时钟频率12 5MHz 奈奎斯特定则 25MHz最高能采样12 5MHz的信号 多触发源 软件 ePWM和GPIO 两种采样模式 级联和双通道模式
  • 张大哥笔记-从零开始自己创建一个网站的操作指南

    随着互联网时代的发展 无论是个人还是企业 都想拥有一个自己的网站 通过网站快速展示自己的商品信息 有很多人不了解一个网站是如何形成的 制作一个网站需要多少时间 具体由哪些细节都是全然不知 他们甚至感觉搭建一个网站是一件非常复杂的事情 其实
  • C语言经典100例题(37)--给10个数排序(选择排序和冒泡排序)

    目录 题目 问题分析 选择排序法 冒泡排序法 代码 选择排序法 冒泡排序法 运行结果 题目 给10个数排序 问题分析 选择排序法 从后9个比较过程中 选择一个最小的与第一个元素交换 下次类推 即用第二个元素与后8个进行比较 并进行交换 每一
  • Appium + IOS 自动化环境搭建教程(实践+总结+踩坑)

    文章目录 前言 IOS 自动化相关框架介绍 自动化测试类工具 内测发布工具 Appium驱动IOS测试原理 关于 WebDriverAgent 基础环境搭建 基础环境 安装内容 前提环境 通用环境 iOS 环境 iOS 真机调试环境配置 I
  • Pandas学习之to_csv()

    用法 df to csv 输出路径 参数1 参数2 参数3 sep 以逗号 作为数据的分隔符 如果分隔符不为逗号 则包含符双引号 就会消失 分隔符为Tab时 写法如下 df to csv new csv sep t na rep NA 确实
  • if-else常见的优化方案

    前言 代码中如果if else比较多 阅读起来比较困难 维护起来也比较困难 很容易出bug 接下来 本文将介绍优化if else代码的八种方案 优化方案一 提前return 去除不必要的else 如果if else代码块包含return语句
  • HBase 比较过滤器

    一 行过滤器 RowFilter 测试RowFilter过滤器 Test public void testRowFilter throws IOException Configuration conf HBaseConfiguration
  • echars-all.js 2.2.7组织结构图及自定义右键菜单的实现思路及源码

    组织结构图一般是树图结构 echars是一个很好的开源数据工具 2 x版本也有对树图的定义 要做组织结构图的需求拿到手里后 在网上也翻阅了很多echarts的官网资料及网友的实现思路 最终在网友及自身的努力下完成了这项任务 以上 先上最终结
  • Linux下gdb attach的使用(调试已在运行的进程)

    在Linux上 执行有多线程的程序时 当程序执行退出操作时有时会遇到卡死现象 如果程序模块多 代码量大 很难快速定位 此时可试试gdb attach方法 测试代码main cpp如下 这里为了使程序退出时产生卡死现象 在第51行时push线
  • 服务器连接超时是怎么回事呢?

    服务器连接超时就是在程序默认的等待时间内没有得到服务器的响应 网络连接超时可能的原因有那些呢 1 网络断开 不过经常显示无法连接 网络阻塞 导致你不能在程序默认等待时间内得到回复数据包 2 网络不稳定 网络无法完整传送服务器信息 系统问题
  • MySQL服务状态查看和监控方式

    运行中的MySQL状态查看可以通过以下方式进行监控和衡量不同指标 show global status like Uptime 服务器已运行的秒数 重启服务器后会置零 以下内容使用的Uptime均为此字段 1 QPS 每秒查询量 查询命令
  • WEB工程师和设计师必学的10个IOS 8新鲜改变

    概述 简介 iOS 8 上的 Safari 的更新 iPhone 6 和 iPhone 6 Plus 新 Api 支持 Safari 新功能和支持 iOS 8 原生优化 Safari 插件 新的设计 视频增强 iOS 8上的JS Bug 和
  • 【数学模型】基于Matlab模拟超市排队系统

    1 内容介绍 日常生活的超市购物中往往在收银台处排队等待 收银台过少 则会造成排队过长 收银台过多 又占用过多超市空间资源和人力资源 将整个超市购物过程所耗费的时间可抽象为若干个随机过程和排队系统 其中 消费者到达超市的时间是随机的 消费者
  • 学习Segformer语义分割模型并训练测试cityscapes数据集

    官方的segformer源码是基于MMCV框架 整体包装较多 自己不便于阅读和学习 我这里使用的是Bubbliiiing大佬github复现的segformer版本 Bubbliiiing大佬代码下载链接 https github com