【CV学习笔记】onnx篇之DETR

2023-10-27

1、摘要

本次学习内容主要学习了DETR的网络结构、损失函数等知识,明白了DETR是如何做到了端到端的检测,确实是一个十分优雅的框架,同时将DETR利用onnxtime进行推理,对于transformer的理解进一步加深了。

DETR学习链接:https://www.bilibili.com/video/BV1GB4y1X72R?spm_id_from=333.337.search-card.all.click

DETR官方地址:https://github.com/facebookresearch/detr

个人学习地址:https://github.com/Rex-LK/tensorrt_learning
欢迎正在学习或者想学的CV的同学进群一起讨论与学习,v:Rex1586662742,q群:468713665

2、DETR

2.1、简介

DETR是transformer在目标检测领域内的里程碑式的工作,主要实现了端到端的目标检测,避免了计算anchor和nms操作,其网络结构也十分直接明了,下图为论文中详细绘制的DETR网络结构图,大致可以分为如下四个步骤:

1、利用CNN提取特征图

2、encoder用于学习全局的特征

3、decoder生成预测框

4、训练时,采用二分图匹配的方式将ground truth框和预测框做匹配,并计算loss。预测时,直接将第三步生成的预测框中阈值低于0.7的过滤掉。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-y61oGTQH-1655255959113)(Screenshot%20from%202022-06-12%2020-48-11.png)]

原文中图片的输入尺寸是3×800×1066,通过卷积提取特征之后得到了2048×25×34的特征图,特征层尺寸为原图的1/32,然后将2048个特征层映射到256,变成256×25×34,同时positional ecodeing 维度也为256×25×34,位置编码与特征层相加之后,然后将25*34展平,最后得到850×256的特征向量,输入到transformer中,然后通过6个encoder后得到850×256的全局特征,然后输入到decoder中。

在decoder中加入了 object queries,是一个可学习的向量,维度为100×256,其中100代表预测100个预测,然后再将每层的 object querries与 每层的850×256 特征层反复做自注意力操作,就是将object querries 当做querry,将每层decoder得到的输出作为key,最终得到了一个100*×256的特征,然后利用FFN预测出物体的类别以及xywh,利用预测的100个框和ground truth 做最优匹配,采用匈牙利算法计算最后的目标函数。

其中在decoder中第一层没有object quireies,后面五层才有,主要是为了移除冗余的框,在object quireies通信之后,就可以知道其他每个query预测出什么框,然后尽量不要去重复这个框,似的最后不需要进行nms操作。在最后计算loss的时候,为了加速收敛并训练的更稳定,在每一个decoder后(共6个)都加了auxiliary loss。

下面通过论文中给出的推理代码

import torch
from torch import nn
from torchvision.models import resnet50

class DETR(nn.Module):
	def __init__(self, num_classes, hidden_dim, nheads,num_encoder_layers, num_decoder_layers):
		super().__init__()
        #resnet50提取图片特征
        self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
        #将2048个特征层映射到256
        self.conv = nn.Conv2d(2048, hidden_dim, 1)
        #encoder and decoder
        self.transformer = nn.Transformer(hidden_dim, nheads,num_encoder_layers, num_decoder_layers)
        #类别预测
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        #框的预测
        self.linear_bbox = nn.Linear(hidden_dim, 4)
        #object_queries 100×256
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
        #位置编码
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
    
    def forward(self, inputs):
        #第一步提取特征
        #3*800*1066 -> 2048×25×34
        x = self.backbone(inputs)
        
        #256×25×34
        h = self.conv(x)
        H, W = h.shape[-2:]
        # 位置编码
        pos = torch.cat([
        	self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
        	self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)
        h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),self.query_pos.unsqueeze(1))
        # 100×256的特征
        return self.linear_class(h), self.linear_bbox(h).sigmoid()
detr = DETR(num_classes=91, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6)
detr.eval()
inputs = torch.randn(1, 3, 800, 1066)
logits, bboxes = detr(inputs)

2.2、facebook官方源码学习

拿到官方源码的时候其实是很懵的,不知道从哪下手,后来经过一段时间的摸索,可以在我的仓库中找到mypredict.py,来实现一个简单的推理。

 #初始化模型
 detr = detr_resnet50()

其中模型初始化调用的_make_detr这个函数,其中backbone采用的resnet50,注意这里输出的特征层的尺寸大小为2048×50×67,是原文的2倍.

def _make_detr(backbone_name: str, dilation=False, num_classes=91, mask=False):
    hidden_dim = 256
    backbone = Backbone(backbone_name, train_backbone=False, return_interm_layers=mask, dilation=dilation)
    pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
    backbone_with_pos_enc = Joiner(backbone, pos_enc)
    backbone_with_pos_enc.num_channels = backbone.num_channels
    transformer = Transformer(d_model=hidden_dim, return_intermediate_dec=True)
    detr = DETR(backbone_with_pos_enc, transformer, num_classes=num_classes, num_queries=100)
    if mask:
        return DETRsegm(detr)
    return detr

在Transformer中定义了encoder和decoder

#定义一层encoder
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                        dropout, activation, normalize_before)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
#定义六层encoder
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

#定义一层decoder
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                        dropout, activation, normalize_before)
decoder_norm = nn.LayerNorm(d_model)
#定义六层decoder
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                  return_intermediate=return_intermediate_dec)

然后在DETR中定义了一些基本组件

self.num_queries = num_queries
self.transformer = transformer
hidden_dim = transformer.d_model
self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
self.query_embed = nn.Embedding(num_queries, hidden_dim)
self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
self.backbone = backbone
self.aux_loss = aux_loss

初始化模型之后,接着就是推理过程了

    def forward(self, samples: NestedTensor):
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)
        #利用resnet50提取特征 3×800×1066 -> 2048×50×67
        src, mask = features[-1].decompose()
        assert mask is not None
        #然后经过encoder和decoder得到100*256的预测值
        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
        #类别以及bbox
        #取第六个decoder的结果
        outputs_class = self.class_embed(hs)[-1]
        outputs_coord = self.bbox_embed(hs).sigmoid()[-1][0]
        ...

其中self.transformer的计算代码如下

    def forward(self, src, mask, query_embed, pos_embed):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        #将后两个维度展平 
        # [3350, 1, 256]
        src = src.flatten(2).permute(2, 0, 1)
        # 位置编码 [3350, 1, 256]
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        #decoder中的object queries
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        mask = mask.flatten(1)
        tgt = torch.zeros_like(query_embed)
        #memory  [3350, 1, 256]
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        #hs [6, 100, 1, 256] 六个decoder预测结果,预测时,取第六个decoder的结果
        hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                          pos=pos_embed, query_pos=query_embed)
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)

查看其中decoder的代码,主要是在TransformerDecoderLayer这个类中

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # 对应每除了第一个decoder,其余每个decoder都与objectquerries进行注意力计算
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

通过上述代码,对DETR的推理过程有了一个比较直观的了解,总的来说,推理过程十分简洁,无非还是分为如下四个步骤

1、cnn提取图像特征

2、encoder提取全局特征

3、decoder生成预测框

4、筛选预测框

2.3、onnxruntime

2.3.1 export_onnx

在demo/detr-mian/mypredict.py中,包含了导出onnx以及onnx-simplify的方法,为了将部分后处理代码放到onnx中,对后处理代码进行了如下改写

hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
outputs_class = self.class_embed(hs)[-1]
outputs_coord = self.bbox_embed(hs).sigmoid()[-1][0]
probas = outputs_class.softmax(-1)[0, :, :-1]
pred = torch.cat((probas,outputs_coord),1)
pred = pred.unsqueeze(0)
return pred

通过netron来查看导出的onnx是否存在问题,发现导出的onnx的输出维度为1×100×95,为100个框的类别以及xywh,说明没有问题

在这里插入图片描述

导出onnx后,可以使用onnxruntime来检测导出onnx的正确性,运行infer-onnxruntime.py,如果结果与mypredict.py显示的结果一致,那么就说明导出的onnx正确。

if __name__ == "__main__":

    data_transform = transforms.Compose([
        transforms.Resize(800),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    img_path = "demo.jpg"
    img_o = Image.open(img_path)
    img = data_transform(img_o).unsqueeze(0)
    image_input = img.numpy()
    session = onnxruntime.InferenceSession("detr_sim.onnx", providers=["CPUExecutionProvider"])
    pred = session.run(["predict"], {"image": image_input})[0]
    scores = torch.from_numpy(pred[0][:,0:91])
    bboxes = torch.from_numpy(pred[0][:,91:])
    keep = scores.max(-1).values > 0.7
    scores = scores[keep]
    bboxes = bboxes[keep]
    print(bboxes)
    fin_bboxes = rescale_bboxes(bboxes, img_o.size)
    plot_results(img_o, scores, fin_bboxes)

可以看出,detr的后处理方式还是很简单的,由于这里转engine还有些许问题,等之后解决了这个问题之后,再进行tensorrt加速,用一张图来看看预测效果把。

在这里插入图片描述

3、总结

本次学习了detr的网络结构,了解了端到端的预测机制,阅读了detr的源码,受益匪浅,对transformer有了进一步的了解,只是遗憾的是暂时没能进行tensorrt加速,后续希望能解决这个问题。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【CV学习笔记】onnx篇之DETR 的相关文章

  • 最大比例

    X星球的某个大奖赛设了M级奖励 每个级别的奖金是一个正整数 并且 相邻的两个级别间的比例是个固定值 也就是说 所有级别的奖金数构成了一个等比数列 比如 16 24 36 54 其等比值为 3 2 现在 我们随机调查了一些获奖者的奖金数 请你
  • 面试题深入思考01-----Arrays.sort()与Collections.sort()

    面试题深入思考01 Arrays sort 与Collections sort 1 Collections sort Collections本质是关于集合的一种工具类 其中包含对集合的各种api 例如排序 反转 交换和复制等 其中sort方
  • word怎么恢复保存前的文件,word文件恢复

    我们在使用word编辑文档时偶尔会有误删除文档的经历 word要怎么恢复保存前的文件呢 本文为你提供了五种解决思路 你可以通过搜索word文档的备份文档 自动恢复文件 临时文件 回收站 第三方数据恢复软件找到文档 方法一 搜索 Word 备
  • katex

    Katex Accents Accent functions inside text Delimiters Delimiter Sizing Environments Letters and Unicode Other Letters Un

随机推荐

  • Android ----蓝牙架构

    蓝牙 1 fromwork 2 service 3 driver Bluetooth apk bluedroid 芯片厂家 fromwork到service直接调用 service到driver利用service调用 fromwork到dr
  • 【机器学习 - 4】:线性回归算法

    文章目录 线性回归 线性回归的理解 损失函数 简单线性回归 封装线性回归算法 线性回归算法 在sklearn中调用线性回归算法 向量化运算 线性回归模型中的误差 均方误差 MSE 均方根误差 平均绝对误差 调用sklearn中的均方根误差和
  • 位置式和增量式PID控制

    PID控制是一个二阶线性控制器 定义 通过调整比例 积分和微分三项参数 使得大多数的工业控制系统获得良好的闭环控制性能 优点 a 技术成熟 b 易被人们熟悉和掌握 c 不需要建立数学模型 d 控制效果好 e 鲁棒性 通常依据控制器输出与执行
  • 测试管理之测试过程

    测试过程 以此文来阐述自己对于测试过程的认识 目录 文章目录 目录 过程分类 测试过程主要分为测试前 测试中 测试完成 发布后 测试前 测试前注意事项 需求评审 参与评审 了解需求背景 需求详情以及需求价值 初步评估需求覆盖面 需求测试工作
  • CSS背景:背景色/背景图像/背景重复/背景附着/简写背景属性(一文搞懂)

    目录 CSS背景 CSS 背景色 实例 其他元素 实例 不透明度 透明度 实例 使用 RGBA 的透明度 实例 CSS 背景图像 实例 实例 实例 CSS 背景重复 实例 实例 CSS background repeat no repeat
  • 程序员版孔乙己

    互联网的格局 是和别处不同的 都是格子衫 稀疏的头发 双肩包 男 写代码的人 傍午傍晚散了工 每每三两人 背着手 沿着软件园溜达一圈 倘肯花点钱 便还会走到星巴克 买一杯咖啡 那样便能再多摸几分钟的鱼 我从十九岁起 便在软件园的星巴克打工
  • Vim 键盘贴纸(打印用)

    Vim是一个类似于Vi的著名的功能强大 高度可定制的文本编辑器 在Vi的基础上改进和增加了很多特性 vim学习过程中需要记住好多键位的使用 下面分享一下vim键位图 打印后贴在键盘上 原版下载地址 密码 ultv 注 作者纯手工打造 难免有
  • 数据相关知识点(数据资产、业务底座、业务中台、企业数仓、即席查询)

    业务底座 企业数仓所提供的支撑能力 业务中台 企业在经营过程中积累起来的 具有一定规模的且能够快速适应变化 能够支撑器企业数字化转型升级的能力 企业数仓 又名企业数据仓库 是一个面向主题的 集成的 非易失的且随时间变化的数据集合 用来支持管
  • 微信小程序自定义弹窗实现详解(可通用)

    本文为自定义弹窗 该内容可满足如下需求 自定义各种布局弹窗 点击弹窗布局外消失弹窗 弹出弹窗时背景阴影半透明 各方向弹出效果 本文为自下而上弹出 wxml 文件中 直接放到wxml的最底部就行了 十分简练
  • Unity之实现拖拽UI功能

    一 unity 图片切割 先把图片导入到Unity中 选中图片你会看到上边的Inspector界面 然后 选择Texture Type类型为Advanced 将Read Write Enabled选上 然后Sprite Mode选择Mult
  • [ Matlab ] 遗传算法求最短路径

    打包下载源代码 实例描述 配送中心数为 1 客户数 k 为 8 车辆总数 m 为 2 车辆载重皆为 8 吨 各客户点需求为 g i 1 2 8 单位为吨 已知客户点与配送中心的距离如表 1 其中 0 表示中心仓库 要求合理安排车辆的运输路线
  • 服务计算——web 技术 - 处理 Request 与 Response

    基于Negroni框架的cloudgo应用 本次实验是基于Negroni框架的应用 我设计了一个简单的四则运算应用 这个应用设计主要分为两部分 中间件设计 以及 main函数的设计 接下来就分别对这两个部分进行介绍 中间件设计 printF
  • VS2017配置Qt开发环境

    VS2017配置Qt开发环境 安装Qt5 12 11 安装Qt插件 在VS2017中进行设置 参考教程 安装Qt5 12 11 安装Qt插件 在VS2017中进行设置 参考教程 Qt下载地址 https download qt io Qt安
  • 前端拖拽自动生成代码_我的前端布局自动化——开始,布局自动生成(一)

    在web前端工作的这些年 历经多次技术变革 不过依然只是一个追随大牛们的小白 此时此刻 尤其加班时 最想做的就是干掉自己职业的东西 做一个可代替前端工作的工具 由此开始了前端自动化探索 这篇文章就是自己的旗子吧 先举起一面变革之旗 不论自己
  • 关于Number.toFixed()的总结

    关于Number toFixed 函数的总结 前言 今天工作中遇到了一个需求 需要将类似于 1 99999 这样的数字格式化为 2 00 这样的两位小数 本来打算自己实现一个类似的功能函数 但是没想到看起来容易 实际实现起来却还是有点复杂的
  • Windows server 2016 云主机创建虚拟机

    Windows server 2016 云主机创建虚拟机 Hyper V 安装失败 处理器没有所需要的虚拟化功能 vmvare workstation play 17 安装 vmvare ok https customerconnect v
  • 【毕设教程】深度学习经典网络 CNN模型:ResNet

    文章目录 0 简介 1 ResNet 介绍 2 深度网络的退化问题 3 残差学习 4 ResNet的网络结构 5 ResNet的TensorFlow实现 6 最后 0 简介 Hi 大家好 这里是丹成学长的毕设系列文章 对毕设有任何疑问都可以
  • HTML-CSS笔记_0424

    HTML CSS 学习笔记源码 链接 https pan baidu com s 1PRorRSlAW0PSHM4grOoapg 提取码 fnr2 HTML 一 网页的基本结构和基础 1 html基础
  • qt 实现UDP通信简单案例

    实现效果 实现功能 创建两个界面 可以通过udp进行通信 并显示通信内容 界面部分由代码实现 并使用qss简单美化 udp通信由创建套接字 绑定端口号 发送和接收数据函数完成 代码实现 创建第一个通信对象 ud1 h ifndef UDPU
  • 【CV学习笔记】onnx篇之DETR

    1 摘要 本次学习内容主要学习了DETR的网络结构 损失函数等知识 明白了DETR是如何做到了端到端的检测 确实是一个十分优雅的框架 同时将DETR利用onnxtime进行推理 对于transformer的理解进一步加深了 DETR学习链接