目标检测pytorch版yolov3五——解码过程和可视化以及predict预测过程

2023-11-03

本篇博客是我学习某位up在b站讲的pytorch版的yolov3后写的,
那位up主的b站的传送门:
https://www.bilibili.com/video/BV1A7411976Z
他的博客的传送门:
https://blog.csdn.net/weixin_44791964/article/details/105310627
他的源码的传送门:
https://github.com/bubbliiiing/yolo3-pytorch
侵删

在这里插入图片描述

解码过程和可视化其实就是画先验框和调整先验框获得最后的预测框。
话不多说,直接上代码(代码都是以13x13的特征层为例来进行解析的)
下面代码是url.py文件里面的代码

#调整先验框的过程就是解码
#decodebox这个类就是对先验框进行调整,每次只能对一个特征层进行解码,
class DecodeBox(nn.Module):
    def __init__(self, anchors, num_classes, img_size):
        super(DecodeBox, self).__init__()
        self.anchors = anchors
        self.num_anchors = len(anchors)
        self.num_classes = num_classes
        self.bbox_attrs = 5 + num_classes
        self.img_size = img_size

    def forward(self, input):
        """
        拿到预测结果以后,就放进这个forward函数,
        这里的input的shape是batchsize, 3x(1+4+num_classes), 13, 13
        3x(1+4+num_classes)分析:
        3是代表3个先验框,1代表先验框内部是否包含有物体,4表示先验框的调整参数,num_classes表示先验框内部物体的种类
        """
        #判断一共有多少张图片
        batch_size = input.size(0)
        #然后得到特征层的宽和高,根据我们的例子,这里的宽和高都是13和13
        input_height = input.size(2)
        input_width = input.size(3)

        # 计算步长,这里的步长其实就是输入进来的图片的大小除以我们输入进来的特征层,这里步长的别名也叫感受野
        """
        步长也就是每一个特征点对应原图上有多少个像素
        如我们的例子,我们将原图划分为13x13的网格,一张原图有412的像素,那么每一个特征点就对应412/13=32个像素点。(这里除以13是因为我们需要分开计算宽和高)
        那么就可以分别计算出高和宽的步长都是32
        """
        stride_h = self.img_size[1] / input_height
        stride_w = self.img_size[0] / input_width
        # 归一到特征层上
        """
        这里就是对先验框的样式进行调整
        """
        scaled_anchors = [(anchor_width / stride_w, anchor_height / stride_h) for anchor_width, anchor_height in self.anchors]

        # 对预测结果进行resize,进行通道转换和reshape
        """
        batchsize, 3x(5+num_classes), 13, 13->batchsize, 3, 13, 13, (5+num_classes)
        下面self.num_anchors表示的是3,也就是先验框的个数
        self.bbox_attrs也就是5+num_classes
        """
        prediction = input.view(batch_size, self.num_anchors,
                                self.bbox_attrs, input_height, input_width).permute(0, 1, 3, 4, 2).contiguous()

        #下面步骤就是获得先验框的调整参数

        # 获得先验框的中心位置的调整参数,先验框的中心其实就是我们划分网格的时候网格与网格之间的交点
        #中心位置就是先验框和预测框中心的偏移距离
        #在这里加上一个sigmoid可以将我们的值固定在0和1之间,这样我们的先验框的中心就只会往右下角的网格偏移了
        x = torch.sigmoid(prediction[..., 0])  
        y = torch.sigmoid(prediction[..., 1])
        # 获得先验框的宽高调整参数,就是先验框的大小调整,调整到预测框的大小
        w = prediction[..., 2]  # Width
        h = prediction[..., 3]  # Height

        # 获得置信度,是否有物体
        conf = torch.sigmoid(prediction[..., 4])
        # 种类置信度
        pred_cls = torch.sigmoid(prediction[..., 5:])  # Cls pred.

        FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
        LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor

        # 下面就是生成网格,生成先验框
        #首先是生成先验框的中心,也就是每个网格相交的网格点,它的shape是:batch_size,3,13,13(也就是13x13的网格,每个网格有三个先验框)
        grid_x = torch.linspace(0, input_width - 1, input_width).repeat(input_width, 1).repeat(
            batch_size * self.num_anchors, 1, 1).view(x.shape).type(FloatTensor)
        grid_y = torch.linspace(0, input_height - 1, input_height).repeat(input_height, 1).t().repeat(
            batch_size * self.num_anchors, 1, 1).view(y.shape).type(FloatTensor)

        # 生成先验框的宽高
        anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0]))
        anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1]))
        anchor_w = anchor_w.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(w.shape)
        anchor_h = anchor_h.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(h.shape)

        # 计算调整后的先验框中心与宽高
        pred_boxes = FloatTensor(prediction[..., :4].shape)
        #对先验框的中心进行调整
        pred_boxes[..., 0] = x.data + grid_x
        pred_boxes[..., 1] = y.data + grid_y
        #对先验框的宽高进行调整
        pred_boxes[..., 2] = torch.exp(w.data) * anchor_w
        pred_boxes[..., 3] = torch.exp(h.data) * anchor_h

        # 用于将输出调整为相对于416x416的大小
        _scale = torch.Tensor([stride_w, stride_h] * 2).type(FloatTensor)
        output = torch.cat((pred_boxes.view(batch_size, -1, 4) * _scale,
                            conf.view(batch_size, -1, 1), pred_cls.view(batch_size, -1, self.num_classes)), -1)

下面就是predict过程的代码了
下面代码是predict.py 文件的代码

#首先创建了yolo这个类,这个类就是我们在yolo.py文件里面创建的类
yolo = YOLO()

while True:
    img = input('Input image filename:')
    try:
        image = Image.open(img)
    except:
        print('Open Error! Try again!')
        continue
    else:
    	#detect_image对我们输入进来的图片进行一个预测,然后把预测框给画出来,
        r_image = yolo.detect_image(image)
        r_image.show()

下面的代码文件是出现在yolo.py 文件

def detect_image(self, image):
        #首先获得了输入进来的图片的高和宽是多少
        image_shape = np.array(np.shape(image)[0:2])
        #添加灰条
        crop_img = np.array(letterbox_image(image, (self.model_image_size[0],self.model_image_size[1])))
        photo = np.array(crop_img,dtype = np.float32)
        #对图片进行归一化
        photo /= 255.0
        photo = np.transpose(photo, (2, 0, 1))
        photo = photo.astype(np.float32)
        images = []
        images.append(photo)

        images = np.asarray(images)
        images = torch.from_numpy(images)
        if self.cuda:
            images = images.cuda()
        
        with torch.no_grad():
            outputs = self.net(images)
            output_list = []
            for i in range(3):
                output_list.append(self.yolo_decodes[i](outputs[i]))
            #这里的cat就是对我们三个预测结果进行一次堆叠
            output = torch.cat(output_list, 1)
            batch_detections = non_max_suppression(output, self.config["yolo"]["classes"],
                                                    conf_thres=self.confidence,
                                                    nms_thres=0.3)
        try :
            batch_detections = batch_detections[0].cpu().numpy()
        except:
            return image
        top_index = batch_detections[:,4]*batch_detections[:,5] > self.confidence
        top_conf = batch_detections[top_index,4]*batch_detections[top_index,5]
        top_label = np.array(batch_detections[top_index,-1],np.int32)
        top_bboxes = np.array(batch_detections[top_index,:4])
        top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(top_bboxes[:,0],-1),np.expand_dims(top_bboxes[:,1],-1),np.expand_dims(top_bboxes[:,2],-1),np.expand_dims(top_bboxes[:,3],-1)

        # 去掉灰条
        boxes = yolo_correct_boxes(top_ymin,top_xmin,top_ymax,top_xmax,np.array([self.model_image_size[0],self.model_image_size[1]]),image_shape)

        font = ImageFont.truetype(font='model_data/simhei.ttf',size=np.floor(3e-2 * np.shape(image)[1] + 0.5).astype('int32'))

        thickness = (np.shape(image)[0] + np.shape(image)[1]) // self.model_image_size[0]

        for i, c in enumerate(top_label):
            predicted_class = self.class_names[c]
            score = top_conf[i]

            top, left, bottom, right = boxes[i]
            top = top - 5
            left = left - 5
            bottom = bottom + 5
            right = right + 5

            top = max(0, np.floor(top + 0.5).astype('int32'))
            left = max(0, np.floor(left + 0.5).astype('int32'))
            bottom = min(np.shape(image)[0], np.floor(bottom + 0.5).astype('int32'))
            right = min(np.shape(image)[1], np.floor(right + 0.5).astype('int32'))

            # 画框框
            label = '{} {:.2f}'.format(predicted_class, score)
            draw = ImageDraw.Draw(image)
            label_size = draw.textsize(label, font)
            label = label.encode('utf-8')
            print(label)
            
            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])

            for i in range(thickness):
                draw.rectangle(
                    [left + i, top + i, right - i, bottom - i],
                    outline=self.colors[self.class_names.index(predicted_class)])
            draw.rectangle(
                [tuple(text_origin), tuple(text_origin + label_size)],
                fill=self.colors[self.class_names.index(predicted_class)])
            draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font)
            del draw
        return image


本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

目标检测pytorch版yolov3五——解码过程和可视化以及predict预测过程 的相关文章

随机推荐

  • ipynb文件_vscode 里建立ipynb文件

    最近在学习并熟悉vscode的操作使用方法 记录一下 Jupyter介绍 Jupyter Notebook 此前被称为 IPython notebook 是一个交互式笔记本 支持运行 40 多种编程语言 对于希望编写漂亮的交互式文档的人来说
  • 03.1 使用普通表单向Spring控制器提交数据

    03 1 使用普通表单向Spring控制器提交数据 场景 由前台jsp网页 表单 向数据库中添加一条信息 前台jsp
  • IDEA的import类和pom文件头被标红,但可以正常编译打包(四种解决方案)

    IDEA的import类和pom文件头被标红 但可以正常编译打包 四种解决方案 问题背景 方案一 方案二 方案三 方案四 心得 Lyric 雨点从两旁划过 问题背景 昨晚回家没有关电脑 也没关IDEA 今早看IDEA的时候 居然莫名其妙出现
  • 写1清0与写0清零:单片机中断服务函数为什么要用写1清零中断标志位?

    前记 第一次使用risc的单片机 照着datasheet和demo边研究边写 因为之前使用51单片机基本都是照着demo CTRL C V 然后自己改改逻辑 这样一个项目也就差不多了 很多原理其实没搞太清楚 借着这个机会 好好补一补 原理搞
  • Maven安装教程

    一 下载安装包Maven Download Apache Mavenhttps maven apache org download cgi 二 配置maven环境 1 将压缩包放到自己想要存放的目录 2 复制Maven的根路径 注意不是bi
  • Raki的读paper小记:RWKV: Reinventing RNNs for the Transformer Era

    Abstract Introduction Related Work 研究任务 基础模型架构 已有方法和相关工作 RNN CNN Transformer 稀疏注意力 Beltagy等人 2020年 Kitaev等人 2020年 Guo等人
  • GLES3.0中文API-glGetProgramResourceName

    名称 glGetProgramResourceName 查询程序中已索引资源的名称 C 规范 void glGetProgramResourceName GLuint program GLenum programInterface GLui
  • 接口api 之Swagger 一次实战探索

    今天我们来说说什么是Swagger 就是把相关的信息存储在它定义的描述文件里面 yml或json格式 再通过维护这个描述文件可以去更新接口文档 以及生成各端代码 而Springfox swagger 则可以通过扫描代码去生成这个描述文件 好
  • 问题 E: [蓝桥杯2016初赛]交换瓶子

    题目描述 有N个瓶子 编号 1 N 放在架子上 比如有5个瓶子 2 1 3 5 4 要求每次拿起2个瓶子 交换它们的位置 经过若干次后 使得瓶子的序号为 1 2 3 4 5 对于这么简单的情况 显然 至少需要交换2次就可以复位 如果瓶子更多
  • STM32 基本定时器实验

    1 基本定时器简介 时钟源 时钟挂载在APB1总线下 中间有一个倍频器 sys stm32 clock init时钟已经设置APB1总线时钟频率为36M 预分频器分频系数为2 所以挂载在APB1总线的定时器时钟频率为72Mhz 图中对应的时
  • node mysql 连接 时区_Nodejs Date 保存到mysql中时区问题,处理方法

    nodejs中mysql用法 1 建立数据库连接 createConnection Object 方法 该方法接受一个对象作为参数 该对象有四个常用的属性host user password database 与php中链接数据库的参数相同
  • ArrayLIst、HashMap

    底层维护了一个Objec的数组 创建对象时 初始大小是0 第一次新增元素时扩容为10 再次扩容为1 5倍 扩容的时机是内部数组满了之后 再次add才会扩容 非线程安全 线程安全的Vector HashMap jdk7以前为数组 链表 搜索的
  • 数据结构知识点汇总

    1 用链表表示线性表的优点是 便于插入和删除操作 2 单链表中 增加头结点的目的是 方便运算的实现 3 栈和队列的共同特点是 只允许在端点处插入和删除元素 4 栈通常采用的两种存储结构是 线性存储结构和链表存储结构 5 队列具有 先进先出
  • Lamport 逻辑时钟

    分布式系统中按是否存在节点交互可分为三类事件 一类发生于节点内部 二是发送事件 三是接收事件 注意 以下文章中提及的时间戳如无特别说明 都指的是Lamport 逻辑时钟的时间戳 不是物理时钟的时间戳 如果a在进程Pi中 b在进程Pj中 Ci
  • 今日分享积累的5个AI绘画网站,好用且免费

    AI绘画即基于人工智能的绘画技术 让设计师能够以全新的方式创作出惊人的艺术作品 而随着AI绘画技术的发展 市面上也多了很多能免费使用的AI绘画网站 可以为我们提供更多的绘画灵感和创作可能性 接下来我将为大家推荐5个能免费使用的AI绘画网站
  • ngrok搭建服务器(超级详细)

    前言 我一直都在usr local文件下操作 有不懂的同学给我留言 我没有修改源码 只是测试能否生成服务端文件 有需要的同学可以修改源码 使用 ip 做域名时 随机生成的子域名导致地址错误解决办法就是改源码 去掉随机生成 在ngrok目录下
  • WAIC2023:图像内容安全黑科技助力可信AI发展

    目录 0 写在前面 1 AI图像篡改检测 2 生成式图像鉴别 2 1 主干特征提取通道 2 2 注意力模块 2 3 纹理增强模块 3 OCR对抗攻击 4 助力可信AI向善发展 总结 0 写在前面 2023世界人工智能大会 WAIC 已圆满结
  • python insert插入新一列

    mydata insert 1 date data 日期 mydata 原有数据 1 插入第几列 data 插入列名 data 日期 插入列内容 原有数据插入一列 mydata insert 1 date data 日期 mydata 原有
  • 快速构建Kubesphere 3.0并设置Kubesphere 多集群联邦

    这里我们Host选择使用单节点All in One安装模式 可以零配置快速部署 KubeSphere和Kubernetes 我们安装联邦集群需要有一台节点进行管理 Member需要在Kubernetes中安装Kubesphere当作Memb
  • 目标检测pytorch版yolov3五——解码过程和可视化以及predict预测过程

    本篇博客是我学习某位up在b站讲的pytorch版的yolov3后写的 那位up主的b站的传送门 https www bilibili com video BV1A7411976Z 他的博客的传送门 https blog csdn net