Detr代码解读(一)数据加载

2023-11-11

导读

源码:Detr
Detr用的数据集格式为coco数据集格式,所以在进行Detr算法前,需要将我们的数据集转为coco数据集,可参考VOC转COCO数据集
在这里插入图片描述
Detr模型如上图所示,简单来讲,就是将图片进行CNN特征提取后送入Transformer模块,然后直接生成预测框,设置个阈值后,生成的预测框就是预测结果,不需要再经历NMS操作。

下载完源码后,直接打开项目,Detr项目依赖了几个作者自己做的第三方库,所以首先打开项目里的README.md文件。
在这里插入图片描述
找到Usage - Object detection部分,按着这部分提供的安装指令,将第三方库安装到编译器环境下。这里注意,可能输入指令会报错,原因可能是你没有安装gitgit的安装可以去查一下。
在这里插入图片描述
环境配置完毕后就可以运行main.py文件开始运行整个项目。
在运行main.py前需要将对它的初始参数进行设置,比如batch_size,数据来源等等。两种方式可以设置,第一,直接在代码每个参数后的default进行设置:
在这里插入图片描述
调试的时候也可以通过Debug里修改运行配置:
在这里插入图片描述
main.py设置完数据源后就可以运行项目了。
如果出现数据文件找不到之类的报错:
在这里插入图片描述
直接翻到本文目录。在这里插入图片描述

数据加载

这一部分先来看数据集是怎么加载到模型里的。之前写过一篇数据加载机理的

数据集加载部分SSDDataset类代码解读-数据采样机理

其实不管是什么检测算法,用的都是DataLoader进行数据采样。
直接在main.py文件里找到DataLoader,代码如下

  data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn, num_workers=args.num_workers)

DataLoader中设置了数据集dataset_train,采样策略batch_sampler(可选择随机采样还是顺序采样),collate_fn收集策略(读取的数据进行初步的处理后,以什么的格式返回给上层的函数)num_workers(告诉DataLoader实例要使用多少个子进程进行数据加载(和CPU有关,和GPU无关))

dataset_train这个实例可以认为是一个数据迭代器,DataLoader告诉dataset_train迭代器需要随机采样16张图片,dataset_train就会迭代16次,每一次去读取一张图片,并对这张图进行初步的处理

dataset_train

ctrl+F直接在main.py中搜索dataset_tarin:

    dataset_train = build_dataset(image_set='train', args=args)
    dataset_val = build_dataset(image_set='val', args=args)

ctrl+鼠标左键进入build_dataset函数:

def build_dataset(image_set, args):
    if args.dataset_file == 'coco':
        return build_coco(image_set, args)
    if args.dataset_file == 'coco_panoptic':
        # to avoid making panopticapi required for coco
        from .coco_panoptic import build as build_coco_panoptic
        return build_coco_panoptic(image_set, args)
    raise ValueError(f'dataset {args.dataset_file} not supported')

本文的数据集文件是coco,所以直接运行前两句,进入build_coco函数
在这里插入图片描述

数据源文件不存在报错解决

这里有个点要注意PATHS是现在数据集下面训练集跟验证集的名称,如果跟你数据集名称不一样,可以在这里改成你数据集下的名称。
在这里插入图片描述

这一步就将数据集的路径跟dataset迭代器联系起来了,之后迭代器就可以根据指令,去读取数据集路径下的数据了。

CocoDetection

这一部分才是dataset_tarin这个数据迭代器的核心内容,

在这里插入图片描述

dataset_train可以根据DataLoader进行采样对应的函数就是CocoDetection中的def __getitem__(self, idx):函数,这个函数会根据DataLoader下发的idx索引进行采样。

在这里插入图片描述
CocoDetection中设置了图象变换的函数,以及是数据处理函数:
在这里插入图片描述
初始值设置好后,返回main.py查找train_one_epoch,数据通过train_one_epoch进行读取。train_one_epoch也是正式的一轮训练。

        train_stats = train_one_epoch(
            model, criterion, data_loader_train, optimizer, device, epoch,
            args.clip_max_norm)

打开train_one_epoch函数在for samples, targets in metric_logger.log_every(data_loader, print_freq, header):这行指令中进行数据集的采样。通过for inDataLoader进行访问,DataLoader根据所配置的内容对dataset_train进行采样。

在这里插入图片描述
将调用DataLoaderfor in写在了metric_logger.log_every里:
在这里插入图片描述

采样

DataLoader随机生成一个索引传给数据迭代器CocoDetection,通过def __getitem__(self, idx):对这个索引对应的数据进行处理:
下图中的super(CocoDetection, self).__getitem__(idx)这个函数是COCO自己的第三方库里的一个函数,可以根据索引直接获取数据集中对应的图片跟他的标注。img就是这个索引对应的图片,target就是对应的标注。

在这里插入图片描述
在这里插入图片描述
进一步通过img, target = self.prepare(img, target)对图片跟标注进行处理:

        img, target = self.prepare(img, target)

数据处理

处理target

对输入self.prepare的图片跟标注进行处理,首先先获取图片id,然后获取标注信息,获取目标类别,然后将这些转为tensor格式。
ConvertCocoPolysToMask将每个image对应的target转化成为一个dict,这个dict中保存了该图片的所有标注信息,其中对于目标检测的有用信息是boxes和labels,其如下:
在这里插入图片描述
因为研究的是目标检测,而不是关键点检测跟实例分割所以跳过跟keypoints、self.return_masks有关的代码,如下面两段代码:

        if self.return_masks:
            segmentations = [obj["segmentation"] for obj in anno]
            masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

然后将获取到的标注信息重新整理放到target字典中

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        target["image_id"] = image_id

此外又做了一个标注内容面积的计算,还有尺寸信息等,是为了之后利用COCOAPI计算检测的map值准备的。

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])

得到的target内如如下:
在这里插入图片描述
通过img, target = self.prepare(img, target)使得target从标注的格式变成dict类型:
在这里插入图片描述

处理img

图片处理这边用到transforms在这里插入图片描述
这是一种常见的图片操作函数,直接看make_coco_transforms函数
在这里插入图片描述
简单来讲就是将输入进来的图片转为tensor格式然后进行标准化,再判断其是train模式还是val模式,如果是tarin模式,则先对图片进行翻转,再随机的缩放等操作。最后将图片返回。具体代码如下:

def make_coco_transforms(image_set):

    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]

    if image_set == 'train':
        return T.Compose([
            T.RandomHorizontalFlip(),#随机水平翻转给定的PIL.image 概率为0.5
            T.RandomSelect(#随机选择两个字操作之一
                T.RandomResize(scales, max_size=1333),#随机选择sizes中的一个值作为图象短边,并保持比例,最大不超过max_size
                T.Compose([
                    T.RandomResize([400, 500, 600]),
                    T.RandomSizeCrop(384, 600),#对图片进行随机尺寸裁剪,最后缩放到统一大小
                    T.RandomResize(scales, max_size=1333),#
                ])
            ),
            normalize,
        ])

    if image_set == 'val':
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            normalize,
        ])

    raise ValueError(f'unknown {image_set}')

收集策略collate_fn(batch)

DataLoadertrainset训练数据集进行一个batch的采样后,将采到的数据传入collate_fn(batch)函数,将这一批数据进行整理一下返回给上一级。
如下图所示,DataLoader一批采样两张图片:
在这里插入图片描述
这里对图片又进行了nested_tensor_from_tensor_list操作,对于每一批次的图像,首先找出每一批次图片的H,W的最大值Hmax,Wmax,然后将原始图像填充为3HmaxWmax大小,并将图像部分置为False,填充部分置为True.最后将图像数据tensor和mask打包为nesttensor格式。
这样子使得每批次输出的图片的大小一致。

def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)

处理后的图片如下图所示:
在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Detr代码解读(一)数据加载 的相关文章

  • git pull origin master 返回致命错误:无效的 refspec

    问题是这样的 每当我这样做时 git pull https github com username reponame github io git 接下来是网址 我没有遇到任何问题 但是当我这样做时 git pull origin maste
  • 如何将“develop”分支推送到远程“origin”?

    当我做git flow init它创造了一个master and develop分支机构 当我添加遥控器时git remote add origin email protected cdn cgi l email protection Ne
  • 无法更改 GitHub Pages 中的源分支

    我为 GitHub Pages 创建了一个简单的网站 该网站的源位于 master 分支中 生成的网站 我希望看到发布的 位于 gh pages 分支下 我希望能够在设置中更改网站的来源 但是设置是灰色的 我无法更改它 请参见下面的屏幕截图
  • GitHub 恢复或重置? [复制]

    这个问题在这里已经有答案了 正如您在图片中看到的 我正在功能 forum kolo 3 中工作 我决定完成该功能并将其合并到开发中 但没有将更改推送到远程开发 因此它只是本地更改 然后我意识到这是一个坏主意 现在我想删除这个合并 就像它从未
  • 说它已提交,但在 GitHub 中它没有显示

    我刚刚在 Ubuntu 机器上安装了 Rails 我设置了 git 并创建了一个 ssh 密钥来链接到我的帐户 我创建了一个要提交的存储库 并创建了一个示例项目来测试 名为first app 当我提交时 它说一切都已提交 但我去了 gith
  • iOS CoreData:“数据模型版本编译器”错误

    我在项目中创建了一个数据模型文件 ChatModel xcdatamodeld 然后我合并了github上的分支 project pbxproj 中存在冲突 我修好了它们 然后错误就发生了 Users mac zhongqing ios Z
  • 使用Chrome访问github,无法加载css和js。但IE没问题

    我的 Chrome 版本 50 0 2661 75 m 访问GitHub 无法加载css和javascript 错误 CSS stylesheet from origin https assets cdn github com has be
  • 如何在一台电脑上拥有2个git用户?

    我想练习使用 GitHub 做拉取请求并学习如何观看git 差异不同用户之间 如何在 macOS 的终端上设置另一个用户帐户来执行此操作 如何在用户之间切换 充当第二用户有三个方面 1 GitHub账户 要以其他用户身份使用 GitHub
  • Github 版本如何生成存档文件名?

    我刚刚在 github 上为我的 NFQL 软件创建了版本 这是发布页面 https github com vbajpai nfql releases https github com vbajpai nfql releases 对于最新版
  • 签出现有的远程分支

    我见过不同的方法来检查现有的远程分支 假设我的朋友推送了新分支 bigbug 并且我想签出并将我的本地工作副本切换到该分支 我有以下选项 1 git checkout b bigbug origin bigbug 2 git checkou
  • git 克隆错误:致命:git upload-pack:由于远程端可能的存储库损坏而中止

    我对 git 存储库具有读 写访问权限 但是当我尝试 git clone 时 出现以下错误 x ubuntu temp git clone email protected cdn cgi l email protection Corp ap
  • 在 GitHub 上执行拉取请求时避免不需要的合并提交和其他提交

    我在 Github 上分叉了一个项目 令远程上游为upstream我的远程存储库是origin 我当地的master分支设置为跟踪远程master分支 然后我在本地添加了一些东西master 时不时与上游汇合 直到今天我想发出pull re
  • 将 github 上的包安装到 Spyder 中

    我一直在尝试安装并导入mpl finance来自 github 的包 在我的 Spyder 环境中没有成功 我努力了 pip install e git https github com matplotlib mpl finance git
  • 无法使用 git 推送或获取 [重复]

    这个问题在这里已经有答案了 我可以拉 但无法使用 git 版本 1 9 5 推送或获取 它突然开始给我以下错误 关于如何修复它有什么想法吗 git fetch fatal unable to access https email prote
  • GIT - 推送到 (GitHub) origin master 没有任何作用

    我已经分叉了某人的 GIT 存储库 https github com nippysaurus toodledo objc 将其克隆到我的本地计算机 显示带有以下信息的来源 remote origin Fetch URL https emai
  • 如何在 GitHub Action 中使用不同版本的 PHP 进行测试

    我有一些 PHP 代码 其中包含使用以下命令运行的测试PHPUnit并想对其进行测试GitHub Actions 我在他们的文档中找不到测试 PHP 包的方法 我想使用不同版本的 PHP 进行测试 但他们只有最新的版本7 3安装 您可以添加
  • 如何禁用 GitHub 中的拉取请求?

    我试图了解如何禁用 github 中的 拉取请求 问题一 我们正在尝试使用变基工作流程 这意味着如果不是快速推进 那么使用拉取请求可能会有害 一种解决方案 为我想要禁用拉取请求的分支设置分支权限 或者将我添加为任何进入 master 的内容
  • 从分叉存储库的 GitHub 操作发布评论的解决方法

    我需要在 GitHub 操作完成后向 GitHub 拉取请求发表评论 例如当 FOSS 社区成员提交 PR 时 我知道 当操作从分叉的存储库运行时 令牌没有对父存储库的写访问权限 因此它无法发布评论 人们是否为此找到了任何可行的解决方法 我
  • 在 Windows 上使用 Git - 意外丢失了大量工作。我可以拿回来吗?

    我很困惑 我想我已经失去了几个小时的工作时间 我之前在 Git 中编辑了一个文件 我保存了它 但没有提交 我确实做了一些其他文件更改 并提交并推送了它们 然而 有一个文件被搞乱了 所以我单击了最后一次成功的提交 然后按了 回滚到此提交 令我
  • 如何在同一存储库中的 github 操作之间共享代码?

    假设我想要两个工作流程build yml and release yml在我的仓库中 第一个应该构建项目 假设使用 CMake 第二个应该构建项目并使用构建的二进制文件创建 GitHub 版本 项目构建代码在两个文件之间重复 如何在它们之间

随机推荐

  • 深入理解C++中的move和forward!

    导语 在C 11标准之前 C 中默认的传值类型均为Copy语义 即 不论是指针类型还是值类型 都将会在进行函数调用时被完整的复制一份 对于非指针而言 开销及其巨大 因此在C 11以后 引入了右值和Move语义 极大地提高了效率 本文介绍了在
  • 保险智能理赔-医疗票据OCR识别解决方案

    基于对健康险理赔行业的深刻洞察和理解 以领先的医疗AI数智化能力打通健康险理赔全流程 通过RPA人机协作实现对理赔材料的智能录入和初审工作 释放大量的专业录单和审核人力 减少企业运营成本 面临痛点 1 人工录入成本高 涉及个人证件 保险单据
  • Visual Attention Network(VAN)

    Visual Attention Network阅读翻译笔记 原文 https arxiv org abs 2202 09741 代码 https github com Visual Attention Network 摘要 虽然最初是为自
  • oracle 排序后取中间的数据

    用minus取 3000和5000的差值 排序后取序列为3000 5000的数据 select from table name t where rownum lt 5000 order by t sn asc minus select fr
  • php event监听

  • linux挂马检测,检测网站挂马程序(Python)

    系统管理员通常从svn git中检索代码 部署站点后通常首先会生成该站点所有文件的MD5值 如果上线后网站页面内容被篡改 如挂马 等 可以比对之前生成MD5值快速查找去那些文件被更改 为了使系统管理员第一时间发现 可结合crontab或na
  • 树莓派的四种登陆方式

    参考 树莓派的4种登陆方式 作者 丶PURSUING 发布时间 2021 02 02 09 15 30 网址 https blog csdn net weixin 44742824 article details 113524929 spm
  • 大厂被裁5个月,boss招聘上12000次打招呼,终于拿到offer了

    你可能不相信 但这是一个小伙伴最近的真实经历 他是一个有着5年工作经验的滴滴非核心项目运营 曾经在一家知名的互联网公司工作 收入也不错 但就在今年年初 因为后疫情的影响 公司进行了一轮大规模的裁员 他就成了其中的一员 从那以后 他就开始了漫
  • Android OpenGLES2.0(十六)——3D模型贴图及光照处理(obj+mtl)

    在Android OpenGLES2 0 十四 Obj格式3D模型加载中实现了Obj格式的3D模型的加载 加载的是一个没有贴图 没有光照处理的帽子 为了呈现出立体效果 手动 加了光照 拥有贴图的纹理及光照又该怎么加载呢 模型文件 本篇博客例
  • 【JVM】JVM调优(基础篇)

    一 概述 先来说下JVM调优主要是在调啥 调优就是调节JVM运行时内存大小 gc垃圾回收细节 要想调整JVM运行时内存大小 需要我们知道JVM内存划分知识以及要想调整gc垃圾回收的细节 需要我们知道垃圾回收器工作原理以及它们使用的垃圾回收算
  • FlinkCDC第三部分-同步mysql到mysql,ctrl就完事~(flink版本1.16.2)

    本文介绍了 来源单表 gt 目标源单表同步 多来源单表 gt 目标源单表同步 注 1 16版本 1 17版本都可以使用火焰图 生产上最好关闭 详情见文章末尾 Flink版本 1 16 2 环境 Linux CentOS 7 0 jdk1 8
  • cocos2d-x 源码分析 总目录

    这篇博客用来整理与cocos2d x相关的工作 只要有新的分析 扩展或者修改 都会更改此文章 祝大家愉快 1 源码分析 1 CCScrollView源码分析 http blog csdn net u011225840 article det
  • Python破解12306图片验证码

    不知从何时起 12306的登录验证码竟然变成了按字找图 可以说是又提高了一个等次 竟然把图像识别都用上了 不过有些图片 不得不说有些变态 图片的清晰图就更别说了 明显是从网络上的图库中搬过来的 谁知没多久 网络就惊现破解12306图片验证码
  • 删除 GitHub 提交记录中的文件或敏感数据

    BFG Repo Cleaner 是一个简单 快速的 10 720 倍 工具 代替 git filter branch 在你的 Git 存储库中清除不想要的文件或敏感数据 GitHub 地址 https github com rtyley
  • 代码审计-strpos数组绕过

    ereg 函数用指定的模式搜索一个字符串中指定的字符串 如果匹配成功返回true 否则 则返回false 搜索字母的字符是大小写敏感的 strpos 函数查找字符串在另一字符串中第一次出现的位置
  • CloudOS:物联网开发平台,云上开发,边端交付

    文章目录 一 CloudOS概述 二 CloudOS主要功能 1 云上开发 2 边端交付 3 数据管理 4 安全保障 三 CloudOS应用场景 1 智能家居 2 智慧城市 3 工业物联网 四 总结 欢迎来到云计算技术应用专栏 CloudO
  • 工作中如何做好技术积累(摘自美团点评技术团队博客)

    引言 古人云 活到老 学到老 互联网算是最辛苦的行业之一 加班 对工程师来说已是 家常便饭 同时互联网技术又日新月异 很多工程师都疲于应付 叫苦不堪 以至于长期以来流传一个很广的误解 35岁是程序员工作的终点 如何在繁忙的工作中做好技术积累
  • 由于设备驱动程序的前一个实例仍在内存中,Windows 无法加载这个硬件的设备驱动

    感谢这位 sizzg7796 网友
  • 【Jenkins】部署vue项目(多种方式部署)

    文章目录 Jenkins部署vue项目 先安装node js 上传到linux并解压 配置Jenkins 环境变量 jenkins 创建任务 部署方式 第一种 npm run build 打包的形式 执行脚本 build Steps 第二种
  • Detr代码解读(一)数据加载

    文章目录 导读 数据加载 dataset train 数据源文件不存在报错解决 CocoDetection 采样 数据处理 处理 target 处理 img 收集策略 collate fn batch 导读 源码 Detr Detr用的数据