SSD(pytorch)自建数据集训练及测试

2023-10-30

一、数据集准备

SSD代码:GitHub - amdegroot/ssd.pytorch: A PyTorch Implementation of Single Shot MultiBox Detector

采用的VOC格式的数据集,在data文件夹下新建文件夹VOCdevkit/VOC2007,数据集放在该路径下。数据集包括Annotations(放xml文件)、ImageSets、JPEGImages(放图片),ImageSets下又Main,包含test.txt、train.txt、val.txt、trainval.txt,用于划分数据集。

yolo格式数据集转VOC格式的代码如下:

24行:更改类别名,顺序要按yolo标注的顺序写

67、101行:更改图片格式

107行:更改图片的路径

216、218、220行:更改文件夹路径地址

# -*- coding: utf-8 -*-

import os
import xml.etree.ElementTree as ET
from xml.dom.minidom import Document
import cv2

'''
import xml
xml.dom.minidom.Document().writexml()
def writexml(self,
             writer: Any,
             indent: str = "",
             addindent: str = "",
             newl: str = "",
             encoding: Any = None) -> None
'''

class YOLO2VOCConvert:
    def __init__(self, txts_path, xmls_path, imgs_path):
        self.txts_path = txts_path   # 标注的yolo格式标签文件路径
        self.xmls_path = xmls_path   # 转化为voc格式标签之后保存路径
        self.imgs_path = imgs_path   # 读取读片的路径各图片名字,存储到xml标签文件中
        self.classes = ['pedestrian', 'cyclist', 'car', 'large vehicle']

    # 从所有的txt文件中提取出所有的类别, yolo格式的标签格式类别为数字 0,1,...
    # writer为True时,把提取的类别保存到'./Annotations/classes.txt'文件中
    def search_all_classes(self, writer=False):
        # 读取每一个txt标签文件,取出每个目标的标注信息
        all_names = set()
        txts = os.listdir(self.txts_path)
        # 使用列表生成式过滤出只有后缀名为txt的标签文件
        txts = [txt for txt in txts if txt.split('.')[-1] == 'txt']
        print(len(txts), txts)
        # 11 ['0002030.txt', '0002031.txt', ... '0002039.txt', '0002040.txt']
        for txt in txts:
            txt_file = os.path.join(self.txts_path, txt)
            with open(txt_file, 'r') as f:
                objects = f.readlines()
                for object in objects:
                    object = object.strip().split(' ')
                    print(object)  # ['2', '0.506667', '0.553333', '0.490667', '0.658667']
                    all_names.add(int(object[0]))
            # print(objects)  # ['2 0.506667 0.553333 0.490667 0.658667\n', '0 0.496000 0.285333 0.133333 0.096000\n', '8 0.501333 0.412000 0.074667 0.237333\n']

        print("所有的类别标签:", all_names, "共标注数据集:%d张" % len(txts))

        return list(all_names)

    def yolo2voc(self):
        # 创建一个保存xml标签文件的文件夹
        if not os.path.exists(self.xmls_path):
            os.mkdir(self.xmls_path)

        # 把上面的两个循环改写成为一个循环:
        imgs = os.listdir(self.imgs_path)
        txts = os.listdir(self.txts_path)
        txts = [txt for txt in txts if not txt.split('.')[0] == "classes"]  # 过滤掉classes.txt文件
        print(txts)
        # 注意,这里保持图片的数量和标签txt文件数量相等,且要保证名字是一一对应的   (后面改进,通过判断txt文件名是否在imgs中即可)
        if len(imgs) == len(txts):   # 注意:./Annotation_txt 不要把classes.txt文件放进去
            map_imgs_txts = [(img, txt) for img, txt in zip(imgs, txts)]
            txts = [txt for txt in txts if txt.split('.')[-1] == 'txt']
            print(len(txts), txts)
            for img_name, txt_name in map_imgs_txts:
                # 读取图片的尺度信息
                img_name=txt_name.split('.')[0] + '.jpg'
                print("读取图片:", img_name)
                img = cv2.imread(os.path.join(self.imgs_path, img_name))
                height_img, width_img, depth_img = img.shape
                print(height_img, width_img, depth_img)   # h 就是多少行(对应图片的高度), w就是多少列(对应图片的宽度)

                # 获取标注文件txt中的标注信息
                all_objects = []
                txt_file = os.path.join(self.txts_path, txt_name)
                with open(txt_file, 'r') as f:
                    objects = f.readlines()
                    for object in objects:
                        object = object.strip().split(' ')
                        all_objects.append(object)
                        print(object)  # ['2', '0.506667', '0.553333', '0.490667', '0.658667']

                # 创建xml标签文件中的标签
                xmlBuilder = Document()
                # 创建annotation标签,也是根标签
                annotation = xmlBuilder.createElement("annotation")

                # 给标签annotation添加一个子标签
                xmlBuilder.appendChild(annotation)

                # 创建子标签folder
                folder = xmlBuilder.createElement("folder")
                # 给子标签folder中存入内容,folder标签中的内容是存放图片的文件夹,例如:JPEGImages
                folderContent = xmlBuilder.createTextNode(self.imgs_path.split('/')[-1])  # 标签内存
                folder.appendChild(folderContent)  # 把内容存入标签
                annotation.appendChild(folder)   # 把存好内容的folder标签放到 annotation根标签下

                # 创建子标签filename
                filename = xmlBuilder.createElement("filename")
                # 给子标签filename中存入内容,filename标签中的内容是图片的名字,例如:000250.jpg
                filenameContent = xmlBuilder.createTextNode(txt_name.split('.')[0] + '.jpg')  # 标签内容
                filename.appendChild(filenameContent)
                annotation.appendChild(filename)
                
                #path
                path = xmlBuilder.createElement("path")
                pathContent = xmlBuilder.createTextNode('/home/seucar/Sunyx/ssd.pytorch-master/data/VOCdevkit/VOC2007/JPEGImages/'+txt_name.split('.')[0] + '.jpg')
                path.appendChild(pathContent)
                annotation.appendChild(path)
                
                #source
                source=xmlBuilder.createElement("source")
                database = xmlBuilder.createElement("database")
                databaseContent = xmlBuilder.createTextNode('Unknown')
                database.appendChild(databaseContent)
                source.appendChild(database)
                annotation.appendChild(source)

                # 把图片的shape存入xml标签中
                size = xmlBuilder.createElement("size")
                # 给size标签创建子标签width
                width = xmlBuilder.createElement("width")  # size子标签width
                widthContent = xmlBuilder.createTextNode(str(width_img))
                width.appendChild(widthContent)
                size.appendChild(width)   # 把width添加为size的子标签
                # 给size标签创建子标签height
                height = xmlBuilder.createElement("height")  # size子标签height
                heightContent = xmlBuilder.createTextNode(str(height_img))  # xml标签中存入的内容都是字符串
                height.appendChild(heightContent)
                size.appendChild(height)  # 把width添加为size的子标签
                # 给size标签创建子标签depth
                depth = xmlBuilder.createElement("depth")  # size子标签width
                depthContent = xmlBuilder.createTextNode(str(depth_img))
                depth.appendChild(depthContent)
                size.appendChild(depth)  # 把width添加为size的子标签
                annotation.appendChild(size)   # 把size添加为annotation的子标签
                
                #segmented
                segmented=xmlBuilder.createElement("segmented")
                segmentedContent = xmlBuilder.createTextNode('0')
                segmented.appendChild(segmentedContent)
                annotation.appendChild(segmented)

                # 每一个object中存储的都是['2', '0.506667', '0.553333', '0.490667', '0.658667']一个标注目标
                for object_info in all_objects:
                    # 开始创建标注目标的label信息的标签
                    object = xmlBuilder.createElement("object")  # 创建object标签
                    # 创建label类别标签
                    # 创建name标签
                    imgName = xmlBuilder.createElement("name")  # 创建name标签
                    imgNameContent = xmlBuilder.createTextNode(self.classes[int(object_info[0])])
                    imgName.appendChild(imgNameContent)
                    object.appendChild(imgName)  # 把name添加为object的子标签

                    # 创建pose标签
                    pose = xmlBuilder.createElement("pose")
                    poseContent = xmlBuilder.createTextNode("Unspecified")
                    pose.appendChild(poseContent)
                    object.appendChild(pose)  # 把pose添加为object的标签

                    # 创建truncated标签
                    truncated = xmlBuilder.createElement("truncated")
                    truncatedContent = xmlBuilder.createTextNode("0")
                    truncated.appendChild(truncatedContent)
                    object.appendChild(truncated)

                    # 创建difficult标签
                    difficult = xmlBuilder.createElement("difficult")
                    difficultContent = xmlBuilder.createTextNode("0")
                    difficult.appendChild(difficultContent)
                    object.appendChild(difficult)

                    # 先转换一下坐标
                    # (objx_center, objy_center, obj_width, obj_height)->(xmin,ymin, xmax,ymax)
                    x_center = float(object_info[1])*width_img + 1
                    y_center = float(object_info[2])*height_img + 1
                    xminVal = int(x_center - 0.5*float(object_info[3])*width_img)   # object_info列表中的元素都是字符串类型
                    yminVal = int(y_center - 0.5*float(object_info[4])*height_img)
                    xmaxVal = int(x_center + 0.5*float(object_info[3])*width_img)
                    ymaxVal = int(y_center + 0.5*float(object_info[4])*height_img)

                    # 创建bndbox标签(三级标签)
                    bndbox = xmlBuilder.createElement("bndbox")
                    # 在bndbox标签下再创建四个子标签(xmin,ymin, xmax,ymax) 即标注物体的坐标和宽高信息
                    # 在voc格式中,标注信息:左上角坐标(xmin, ymin) (xmax, ymax)右下角坐标
                    # 1、创建xmin标签
                    xmin = xmlBuilder.createElement("xmin")  # 创建xmin标签(四级标签)
                    xminContent = xmlBuilder.createTextNode(str(xminVal))
                    xmin.appendChild(xminContent)
                    bndbox.appendChild(xmin)
                    # 2、创建ymin标签
                    ymin = xmlBuilder.createElement("ymin")  # 创建ymin标签(四级标签)
                    yminContent = xmlBuilder.createTextNode(str(yminVal))
                    ymin.appendChild(yminContent)
                    bndbox.appendChild(ymin)
                    # 3、创建xmax标签
                    xmax = xmlBuilder.createElement("xmax")  # 创建xmax标签(四级标签)
                    xmaxContent = xmlBuilder.createTextNode(str(xmaxVal))
                    xmax.appendChild(xmaxContent)
                    bndbox.appendChild(xmax)
                    # 4、创建ymax标签
                    ymax = xmlBuilder.createElement("ymax")  # 创建ymax标签(四级标签)
                    ymaxContent = xmlBuilder.createTextNode(str(ymaxVal))
                    ymax.appendChild(ymaxContent)
                    bndbox.appendChild(ymax)

                    object.appendChild(bndbox)
                    annotation.appendChild(object)  # 把object添加为annotation的子标签
                f = open(os.path.join(self.xmls_path, txt_name.split('.')[0]+'.xml'), 'w')
                xmlBuilder.writexml(f, indent='\t', newl='\n', addindent='\t', encoding='utf-8')
                f.close()

if __name__ == '__main__':
    # 把yolo的txt标签文件转化为voc格式的xml标签文件
    # yolo格式txt标签文件相对路径
    txts_path1 = './labels'
    # 转化为voc格式xml标签文件存储的相对路径
    xmls_path1 = './Annotations'
    # 存放图片的相对路径
    imgs_path1 = './JPEGImages'

    yolo2voc_obj1 = YOLO2VOCConvert(txts_path1, xmls_path1, imgs_path1)
    labels = yolo2voc_obj1.search_all_classes()
    print('labels: ', labels)
    yolo2voc_obj1.yolo2voc()

二、训练

新建文件夹weights,下载预训练权重VGG16_reducedfc_pth。链接:https://pan.baidu.com/s/1c0K1oNly5FUJjTetTQgf_A 
提取码:9cfh

data/conifg.py修改voc里的num_classes和max_iter,类别为自己的类别数+1(背景),最大迭代次数可以适当减小。

data/VOC0712.py修改VOC_CLASSES

ssd.py中修改32行num_classes以及改变pull_item函数如下(解决img, boxes, labels = self.transform(img, target[:, :4], target[:, 4])这行报错,target可能为空):

 

train.py将.data[0]全部替换为.item(),以及如下:

 可能还有别的地方需要修改,但我忘了具体位置了,但根据报错直接搜都能解决,就不一一列举了

三、评价

eval.py的do_python_eval函数做如下修改,可以输出Recall、Precision和mAP(f1也有计算但我没输出,有需要可以自己加)。注意修改recs和precs初始时的类别数(不用加背景)

def do_python_eval(output_dir='output', use_07=True):
    cachedir = os.path.join(devkit_path, 'annotations_cache')
    aps = []
    recs = np.zeros((4, 500000)) #4 represent number of classes
    precs = np.zeros((4, 500000)) #4 represent number of classes
    # The PASCAL VOC metric changed in 2010
    use_07_metric = use_07
    print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)
    print('1')
    for i, cls in enumerate(labelmap):
        filename = get_voc_results_file_template(set_type, cls)
        rec, prec, ap = voc_eval(
           filename, annopath, imgsetpath.format(set_type), cls, cachedir,
           ovthresh=0.1, use_07_metric=use_07_metric)
        aps += [ap]
        #recs += [rec.mean(0)]
        #precs += [prec.max(0)]
        #print(rec.shape)
        
        rec=rec.reshape(len(rec))
        prec=prec.reshape(len(prec))
        r=np.pad(rec,(0,500000-len(rec)),'constant',constant_values=(0,0))
        p=np.pad(prec,(0,500000-len(prec)),'constant',constant_values=(0,0))
        recs[i] = r
        precs[i] = p
        '''pl.plot(rec, prec, lw=2, 
                    label='{} (AP = {:.4f})'
                          ''.format(cls, ap))'''
        print('AP for {} = {:.4f}'.format(cls, ap))
        with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f:
            pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
    eps=1e-16
    f1 = 2 * precs * recs / (precs + recs + eps)
    i = f1.mean(0).argmax()
    precs, recs, f1 = precs[:, i], recs[:, i], f1[:, i]
    '''pl.xlabel('Recall')
    pl.ylabel('Precision')
    plt.grid(True)
    pl.ylim([0.0, 1.05])
    pl.xlim([0.0, 1.0])
    pl.title('Precision-Recall')
    pl.legend(loc="upper left")     
    plt.show()'''
    print('Mean AP = {:.4f}'.format(np.mean(aps)))
    print('recall:',recs)
    print('Precision:',precs)
    print('recall:',format(np.mean(recs)))
    print('Precision:',format(np.mean(precs)))
    print('~~~~~~~~')
    print('Results:')
    for ap in aps:
        print('{:.3f}'.format(ap))
    print('{:.3f}'.format(np.mean(aps)))
    print('~~~~~~~~')
    print('')
    print('--------------------------------------------------------------')
    print('Results computed with the **unofficial** Python eval code.')
    print('Results should be very close to the official MATLAB eval code.')
    print('--------------------------------------------------------------')

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

SSD(pytorch)自建数据集训练及测试 的相关文章

随机推荐

  • # Linux 升级g++7.5.0

    Linux 升级g 7 5 0 最近使用需要使用c 14标准 发现自己的编译器版本为gcc 4 8 只支持c 11标准 因此进行升级 升级为可以支持c 17的gcc7 5 0 关于编译器的版本支持问题可以到gcc的官方文档进行查询 下面开始
  • 【机器学习杂烩】投票(Voting)算法在分类问题上的应用(一口气实现Logistic 回归、SVM、KNN、C4.5 决策树、 Adaboost、GBDT 和随机森林,并集成起来)

    投票算法简介 投票是在分类算法中广泛运用的集成学习算法之一 投票主要有硬投票和软投票 两种 硬投票即各基分类器权重相同的投票 其原理为多数投票原则 如果基分类器 的某一分类结果超过半数 则集成算法选择该结果 若无半数结果则无输出 软投票 的
  • (一)Windows下tensorflow安装 win64+Python3.5

    2016年11月29日 TF官方宣布0 12版tensorflow支持原生windows操作系统 不在需要通过Docker进行安装 tf初学者 windows重度依赖用户 可以按照下面这种可行的安装方法进行 一 安装Python 1 通过P
  • JavaWeb远程调试

    目录 远程调试 Centos JDK 工程 启动配置 IDE连接 调试 停止调试 远程调试原理 原理 三种通信方式 Debug参数 第1种Socket Attach调试方法 第2种Socket Listen调试方法 第3种调试方法 断点分类
  • Go_实现HTTP请求和响应

    请求 HTTP 超文本传输协议 请求报文由请求行 请求头部 空行 请求包体4个部分组成 如下图所示 请求行 请求行由请求方法 URL HTTP协议版本组成 它们之间使用空格隔开 常用的HTTP请求方法有GET POST请求 GET GET请
  • Java序列化与JNDI注入

    现阶段公司会进行季度的安全巡检 扫描出来的 Java 相关漏洞 无论是远程代码执行 还是 JNDI 注入 基本都和 Java 的序列化机制有关 本文简单梳理了一下序列化机制相关知识 解释为什么这么多漏洞都和 Java 的序列化有关 以及后续
  • Python:三种格式化字符串的方式,附:小题案例+知识点

    python中常见的格式符 格式符 格式说明 c 将对应的数据格式转化为字符 s 将对应的数据格式化为字符串 d 将对应的数据格式化为整数 u 将对应的数据格式化为无符号整型 o 将对应的数据格式化为八进制数据 x 将对应的数据格式化为十六
  • Python爬取电影信息:Ajax介绍、爬取案例实战 + MongoDB存储

    Ajax介绍 Ajax Asynchronous JavaScript and XML 是一种用于在Web应用程序中实现异步通信的技术 它允许在不刷新整个网页的情况下 通过在后台与服务器进行数据交换 实时更新网页的一部分 Ajax的主要特点
  • 机器学习之文本分类(附带训练集+数据集+所有代码)

    我本次对4类文本进行分类 所有截图代码和数据集最后附带免费下载地址 主要步骤 1 各种读文件 写文件 2 使用jieba分词将中文文本切割 3 对处理之后的文本开始用TF IDF算法进行单词权值的计算 4 去掉停用词 5 贝叶斯预测种类 文
  • 【fpga里Verilog语言的小知识点】

    提示 小白撰写 如有错误希望指出 目录 前言 总结 前言 记录一下学习fpga里Verilog语言的小知识点 持续更新 提示 以下是本篇文章正文内容 1 A B C D 是一个条件运算符 含义是如果 B 为 TRUE 则把 C 连线 A 否
  • C语言实现扫雷(含展开,附源码)

    我相信在做的各位都是玩过扫雷的 游戏玩法就不赘述了 直接进入主题 先思考后敲代码 首先 我将扫雷分为两个棋盘 一个放雷 另一个为玩家猜测盘 这就有同学问了 设置一个棋盘不就完了 这样不就搞复杂了吗 先简短的回答这位同学的问题 因为我的考虑是
  • Docker18:容器互联----link (给新手玩的,进阶方法是 自定义网络)

    容器互联 link 思考一个场景 我们编写了一个微服务 database urla ip 项目不重启 数据库ip换掉了 我们希望可以处理这个问题 可以 名字来进行访问容器 若能实现就可以实现高可用 探究 inspect 其实这个 tomca
  • Sed编辑器

    目录 sed编辑器 命令格式 常用选项 常用操作 打印内容 使用地址 编辑 替换 插入 sed编辑器 sed是一种流编辑器 流编辑器会在编辑器处理数据之前基于预先提供的一组规则来编辑数据流 sed编辑器可以根据命令来处理数据流中的数据 这些
  • matlab的求解方程组函数solve、dsolve、ode系列

    一 求解方程组函数 solve 示例 求解含有参数的方程组 ax by 10 ax by 20 编程实现 syms a b x y s solve a x b y 10 a x b y 20 x y sol x s x sol y s y
  • Kettle变量和自定义java代码的实例应用

    1 kettle properties参数配置数据源连接和FTP连接 由于测试环境和生产环境中数据库连接FTP等配置会在部署过程中变更 所以预先定义成配置项 在配置文件中修改 这样测试和发布将会变得简单 下面以数据库为例说明这类配置的使用
  • sys.exc_info()异常处理Python内置

    sys系统相关的参数和函数 Python内置于系统相关的库 sys exc info 此函数返回一个三值元组 包含当前正在处理的异常信息 返回的信息同时针对当前线程和当前堆栈框架 如果当前堆栈框架不是在处理一个异常 该信息取自调用的堆栈信息
  • 用科学计数法表示二进制小数_数值数据的表示

    在学习之前 我们先来看一段代码的计算 为什么会出现这样的结果呢 std 在计算机中最主要的数据类型有无符号整型 有符号整型 浮点数 布尔值 布尔值 就是真 假 1或者0 在计算机中用最小的1个字节表示 即8位 无符号整型 即我们数学中的正整
  • android本地xml配置,Android Studio创建的Android项目一般需要忽略

    1 idea文件夹 2 gradle文件夹 3 所有的build文件夹 4 所有的 iml文件 5 local properties文件 第一种 在Android Studio中设置 第二种 在项目的 idea workspace xml的
  • Java多线程文件下载器高文件下载速度

    使用多线程下载器提高文件下载速度 在进行大文件下载时 使用多线程技术可以显著提高下载速度 本文将介绍如何使用多线程下载器来实现并行下载 以加快文件下载过程 多线程下载器简介 多线程下载器是一种利用多线程同时下载文件的工具 它将大文件分割成多
  • SSD(pytorch)自建数据集训练及测试

    一 数据集准备 SSD代码 GitHub amdegroot ssd pytorch A PyTorch Implementation of Single Shot MultiBox Detector 采用的VOC格式的数据集 在data文