常用数据集预处理(dota)

2023-11-03

  1. 从数据集中选出自己需要的类别
import os
import cv2
import shutil

# DOTA category names to select; an image is copied when it contains
# at least two objects of any of these categories.
catogary = ['bridge']


def customname(fullname):
    """Return the file name of *fullname* without its folder or extension."""
    stem, _suffix = os.path.splitext(fullname)
    return os.path.basename(stem)

def GetFileFromRoot(dir):
    """Collect the full path (with extension) of every file below *dir*.

    The tree is walked recursively with ``os.walk`` and paths are returned
    in walk order.
    """
    return [
        os.path.join(current_root, name)
        for current_root, _subdirs, names in os.walk(dir)
        for name in names
    ]

if __name__ == '__main__':
    # Copy every DOTA image (and its label file) that contains at least
    # `min_objects` objects of a category listed in `catogary` into a subset folder.
    root = 'E:/Aerial Images/Aerial Images/DOTA/train'
    raw_pic_path = os.path.join(root, 'images/images')
    raw_lab_path = os.path.join(root, 'labelTxt-v1.0/labelTxt')
    bridge_pic = os.path.join(root, 'bridge/images')
    bridge_lab = os.path.join(root, 'bridge/labelTxt')
    min_objects = 2  # an image is selected once it holds this many wanted objects

    # cv2.imwrite / shutil.copyfile cannot write into missing folders.
    os.makedirs(bridge_pic, exist_ok=True)
    os.makedirs(bridge_lab, exist_ok=True)

    label_list = GetFileFromRoot(raw_lab_path)
    for label_path in label_list:
        # `with` closes the label file (the original never closed it).
        with open(label_path, 'r') as f:
            lines = f.readlines()

        n = 0
        for i, line in enumerate(lines):
            if i in (0, 1):  # the first two lines of a raw label are metadata (source, gsd)
                continue
            split_line = line.strip().split(' ')
            if len(split_line) < 9:  # skip blank/malformed lines instead of raising IndexError
                continue
            if split_line[8] in catogary:  # 9th field is the category name
                n = n + 1

        # Decide once per file; the original copied again for every extra match.
        if n < min_objects:
            continue

        name = customname(label_path)  # label file name without extension
        old_img_path = os.path.join(raw_pic_path, name + '.png')
        img = cv2.imread(old_img_path)
        if img is None:  # cv2.imread returns None instead of raising
            print('missing or unreadable image: %s' % old_img_path)
            continue
        cv2.imwrite(os.path.join(bridge_pic, name + '.png'), img)
        # Fixed: the original used `name + 'txt'` (no dot), e.g. "P0001txt".
        shutil.copyfile(label_path, os.path.join(bridge_lab, name + '.txt'))

  2. 删除数据集中的空白样本
import os
import shutil
import xml.dom.minidom
 
def custombasename(fullname):
    """Return *fullname* with its directory and final extension removed."""
    without_ext, _ = os.path.splitext(fullname)
    return os.path.basename(without_ext)
 
def GetFileFromThisRootDir(dir, ext=None):
    """Recursively collect file paths under *dir*, optionally filtered by extension.

    Args:
        dir: Root directory, walked with ``os.walk``.
        ext: ``None`` keeps every file.  Otherwise pass one extension or an
            iterable of extensions, e.g. ``'png'`` or ``['png', '.jpg']``.
            A leading dot is ignored and matching is exact.

    Returns:
        List of full file paths in ``os.walk`` order.
    """
    if ext is None:
        wanted = None
    else:
        if isinstance(ext, str):
            # A bare string used to be tested with `in`, i.e. as a substring:
            # 'png' also matched 'p', 'ng' and files without extension ('').
            ext = [ext]
        wanted = {e.lstrip('.') for e in ext}

    allfiles = []
    for root, _dirs, files in os.walk(dir):
        for filename in files:
            filepath = os.path.join(root, filename)
            extension = os.path.splitext(filepath)[1][1:]  # suffix without the dot
            if wanted is None or extension in wanted:
                allfiles.append(filepath)
    return allfiles
  
def cleandata(path, img_path, blank_label_path, blank_img_path, ext, label_ext):
    """Move a label file that contains no objects, and its image, aside.

    Args:
        path: Label file to inspect.
        img_path: Image directory; the image is ``img_path/<label basename><ext>``.
        blank_label_path: Destination directory for blank label files.
        blank_img_path: Destination directory for the matching images.
        ext: Image suffix including the dot, e.g. ``'.png'``.
        label_ext: ``'xml'`` or ``'.xml'`` for VOC XML labels (blank = no
            ``<object>``); any other value treats the label as plain text
            (blank = no non-whitespace line).
    """
    name = custombasename(path)
    # Accept both 'xml' and '.xml'; the script's caller uses the dotted form.
    if (label_ext or '').lstrip('.').lower() == 'xml':
        annotation = xml.dom.minidom.parse(path).documentElement
        is_blank = len(annotation.getElementsByTagName('object')) == 0
    else:
        # `with` closes the handle; the original leaked it for non-empty files.
        with open(path, 'r') as f_in:
            lines = f_in.readlines()
        # A file holding only blank lines carries no objects either.
        is_blank = not any(line.strip() for line in lines)

    if is_blank:
        image_path = os.path.join(img_path, name + ext)  # the sample image
        shutil.move(image_path, blank_img_path)  # move the image to blank_img_path
        shutil.move(path, blank_label_path)      # move its label to blank_label_path
    print('正在处理 %s'%path)
                                                 
if __name__ == '__main__':
    # Locations of the split DOTA training set (adjust to your machine).
    root = 'E:/Aerial Images/Aerial Images/trainsplit'
    img_path = os.path.join(root, 'images')        # split images
    label_path = os.path.join(root, 'labelTxt')    # split labels
    ext = '.png'        # image suffix
    label_ext = '.txt'  # label suffix

    # Destination folders for samples/labels without any object.
    blank_img_path = os.path.join(root, 'blank_images')
    blank_label_path = os.path.join(root, 'blank_labelTxt')
    for out_dir in (blank_img_path, blank_label_path):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

    label_list = GetFileFromThisRootDir(label_path)
    for path in label_list:
        cleandata(path, img_path, blank_label_path, blank_img_path, ext, label_ext)
  3. 删除数据集中的非目标样本(提取出含所需目标的样本)
import os
import shutil
import xml.dom.minidom

#n = 0

def custombasename(fullname):
    """File stem of *fullname*, e.g. ``'dir/P0001.txt'`` -> ``'P0001'``."""
    head = os.path.splitext(fullname)[0]
    stem = os.path.basename(head)
    return stem


def GetFileFromThisRootDir(dir, ext=None):
    """Recursively collect file paths under *dir*, optionally filtered by extension.

    Args:
        dir: Root directory, walked with ``os.walk``.
        ext: ``None`` keeps every file.  Otherwise pass one extension or an
            iterable of extensions, e.g. ``'png'`` or ``['png', '.jpg']``.
            A leading dot is ignored and matching is exact.

    Returns:
        List of full file paths in ``os.walk`` order.
    """
    if ext is None:
        wanted = None
    else:
        if isinstance(ext, str):
            # A bare string used to be tested with `in`, i.e. as a substring:
            # 'png' also matched 'p', 'ng' and files without extension ('').
            ext = [ext]
        wanted = {e.lstrip('.') for e in ext}

    allfiles = []
    for root, _dirs, files in os.walk(dir):
        for filename in files:
            filepath = os.path.join(root, filename)
            extension = os.path.splitext(filepath)[1][1:]  # suffix without the dot
            if wanted is None or extension in wanted:
                allfiles.append(filepath)
    return allfiles


def cleandata(path, img_path, nontarget_label_path, nontarget_img_path, ext, label_ext,
              categories=None):
    """Move a DOTA label file and its image aside if it contains a wanted category.

    Despite the ``nontarget_*`` parameter names, the files that are MOVED are
    the ones with at least one object whose category is in *categories*;
    every other sample stays in place.

    Args:
        path: DOTA label file (object lines ``x1 y1 ... x4 y4 category [difficult]``).
        img_path: Image directory; the image is ``img_path/<label basename><ext>``.
        nontarget_label_path: Destination directory for matching label files.
        nontarget_img_path: Destination directory for matching images.
        ext: Image suffix including the dot, e.g. ``'.png'``.
        label_ext: Not used; kept for call compatibility.
        categories: Wanted category names.  Defaults to the module-level
            ``catogory`` list defined by the ``__main__`` block.
    """
    if categories is None:
        categories = catogory
    name = custombasename(path)
    # `with` closes the file on every path (it leaked when nothing matched).
    with open(path, 'r') as f_in:
        lines = f_in.readlines()

    for line in lines:
        fields = line.strip().split(' ')
        # Header lines ("imagesource:...", "gsd:...") and blank lines have
        # fewer than 9 fields; indexing them raised IndexError.
        if len(fields) < 9:
            continue
        if fields[8] in categories:  # 9th field is the category name
            image_path = os.path.join(img_path, name + ext)  # the sample image
            shutil.move(image_path, nontarget_img_path)
            shutil.move(path, nontarget_label_path)
            break
    # This print used to sit after an always-taken `break` and never ran.
    print('正在处理 %s' % path)


if __name__ == '__main__':
    # Category names; samples containing them are moved out (read by cleandata).
    catogory = ['bridge']

    root = r'H:\DOTA\dota\trainsplit'
    img_path = os.path.join(root, 'images')        # split images
    label_path = os.path.join(root, 'labelTxt')    # split DOTA labels
    ext = '.png'        # image suffix
    label_ext = '.txt'  # label suffix

    # Destination folders for the moved samples and their labels.
    nontarget_img_path = os.path.join(root, 'nontarget_images')
    nontarget_label_path = os.path.join(root, 'nontarget_labelTxt')
    for out_dir in (nontarget_img_path, nontarget_label_path):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

    label_list = GetFileFromThisRootDir(label_path)
    for path in label_list:
        cleandata(path, img_path, nontarget_label_path, nontarget_img_path, ext, label_ext)
  4. 将 DOTA 数据集标签格式从 txt 转换成 xml
import os
import cv2
from xml.dom.minidom import Document

# readlabeltxt keeps only objects whose category is listed here.
category_set = ['bridge']


def custombasename(fullname):
    """Return the bare file name (no folder, no last extension) of *fullname*."""
    root_part, _extension = os.path.splitext(fullname)
    return os.path.basename(root_part)

def limit_value(a, b):
    """Clamp coordinate *a* into ``[1, b - 1]``, applying the lower bound first.

    Keeps box corners inside an image of size *b* along one axis.  When
    ``b <= 1`` the upper bound wins and the result is ``b - 1``.
    """
    clamped = 1 if a < 1 else a
    return b - 1 if clamped >= b else clamped

def readlabeltxt(txtpath, height, width, hbb=True, categories=None):
    """Read a DOTA label file and return the boxes of the wanted categories.

    Object lines are ``x1 y1 x2 y2 x3 y3 x4 y4 category [difficult]``.
    Coordinates are truncated to ``int`` and clamped with ``limit_value`` to
    ``[1, width - 1]`` (x) and ``[1, height - 1]`` (y).

    Args:
        txtpath: Path of the DOTA label file.
        height: Image height in pixels (bound for y).
        width: Image width in pixels (bound for x).
        hbb: True returns horizontal boxes ``[xmin, ymin, xmax, ymax, label]``;
            False returns the four corners ``[x1, y1, ..., x4, y4, label]``.
        categories: Category names to keep; defaults to ``category_set``.

    Returns:
        List of boxes in the layout selected by *hbb*.
    """
    if categories is None:
        categories = category_set
    print(txtpath)
    boxes = []
    with open(txtpath, 'r') as f_in:
        for line in f_in:
            fields = line.strip().split(' ')
            # Raw DOTA labels start with "imagesource:..."/"gsd:..." header
            # lines and may contain blank lines; both have fewer than 9 fields
            # and used to raise IndexError (split labels have no header).
            if len(fields) < 9:
                continue
            label = fields[8]
            if label not in categories:  # only write the requested categories
                continue
            coords = [int(float(v)) for v in fields[:8]]
            xs, ys = coords[0::2], coords[1::2]
            if hbb:
                # Axis-aligned box enclosing the four corners.
                box = [limit_value(min(xs), width), limit_value(min(ys), height),
                       limit_value(max(xs), width), limit_value(max(ys), height)]
            else:
                # Oriented box: keep all four corners, each clamped.
                box = []
                for x, y in zip(xs, ys):
                    box += [limit_value(x, width), limit_value(y, height)]
            boxes.append(box + [label])
    return boxes


def _append_text_element(doc, parent, tag, text):
    """Append ``<tag>text</tag>`` to *parent* and return the new element."""
    element = doc.createElement(tag)
    parent.appendChild(element)
    element.appendChild(doc.createTextNode(str(text)))
    return element


def writeXml(tmp, imgname, w, h, d, bboxes, hbb=True):
    """Write a Pascal-VOC style annotation file for one image.

    Args:
        tmp: Output directory; created if it does not exist.
        imgname: Image file name; the XML goes to ``tmp/<imgname stem>.xml``.
        w, h, d: Image width, height and depth, written to ``<size>``.
        bboxes: Boxes whose last item is the class label.  With *hbb* each box
            starts with ``xmin, ymin, xmax, ymax``; otherwise it starts with
            the eight corner coordinates ``x0, y0, ..., x3, y3``.
        hbb: Selects the ``<bndbox>`` layout (``xmin``.. vs ``x0``..``y3``).
    """
    doc = Document()
    annotation = doc.createElement('annotation')
    doc.appendChild(annotation)

    _append_text_element(doc, annotation, 'folder', 'VOC2007')
    _append_text_element(doc, annotation, 'filename', imgname)

    source = doc.createElement('source')
    annotation.appendChild(source)
    _append_text_element(doc, source, 'database', 'My Database')
    _append_text_element(doc, source, 'annotation', 'VOC2007')
    _append_text_element(doc, source, 'image', 'flickr')

    owner = doc.createElement('owner')
    annotation.appendChild(owner)
    _append_text_element(doc, owner, 'flickrid', 'NULL')
    _append_text_element(doc, owner, 'name', 'idannel')

    size = doc.createElement('size')
    annotation.appendChild(size)
    _append_text_element(doc, size, 'width', w)
    _append_text_element(doc, size, 'height', h)
    _append_text_element(doc, size, 'depth', d)

    _append_text_element(doc, annotation, 'segmented', '0')

    if hbb:
        coord_tags = ('xmin', 'ymin', 'xmax', 'ymax')
    else:
        coord_tags = ('x0', 'y0', 'x1', 'y1', 'x2', 'y2', 'x3', 'y3')

    for bbox in bboxes:
        object_new = doc.createElement('object')
        annotation.appendChild(object_new)
        _append_text_element(doc, object_new, 'name', bbox[-1])
        _append_text_element(doc, object_new, 'pose', 'Unspecified')
        _append_text_element(doc, object_new, 'truncated', '0')
        _append_text_element(doc, object_new, 'difficult', '0')
        bndbox = doc.createElement('bndbox')
        object_new.appendChild(bndbox)
        for index, tag in enumerate(coord_tags):
            _append_text_element(doc, bndbox, tag, bbox[index])

    # The output folder (e.g. 'hbbxml') is never created by the caller, and
    # open() fails on a missing directory. An empty `tmp` means the cwd.
    if tmp:
        os.makedirs(tmp, exist_ok=True)
    xmlname = os.path.splitext(imgname)[0]
    xml_path = os.path.join(tmp, xmlname + '.xml')
    with open(xml_path, 'wb') as f:
        f.write(doc.toprettyxml(indent='\t', encoding='utf-8'))


if __name__ == '__main__':
    data_path = r'E:\Aerial Images\Aerial Images\DOTA\val\bridge\valsplit'
    images_path = os.path.join(data_path, 'images')  # sample images
    labeltxt_path = os.path.join(data_path, 'labelTxt')  # DOTA label files
    anno_new_path = os.path.join(data_path, 'hbbxml')  # output VOC xml (hbb boxes)
    ext = '.png'  # image suffix
    os.makedirs(anno_new_path, exist_ok=True)  # the xml writer needs an existing folder

    for filename in os.listdir(labeltxt_path):  # one DOTA label file per image
        filepath = os.path.join(labeltxt_path, filename)
        picname = os.path.splitext(filename)[0] + ext
        pic_path = os.path.join(images_path, picname)
        # Read the matching image only to get H, W, D for the xml <size>.
        im = cv2.imread(pic_path)
        if im is None:  # cv2.imread returns None instead of raising
            print('无法读取图片', pic_path)
            continue
        H, W = im.shape[:2]
        # Defensive: a 2-D (single-channel) array would break a 3-way unpack.
        D = im.shape[2] if im.ndim == 3 else 1
        boxes = readlabeltxt(filepath, H, W, hbb=True)  # horizontal boxes (hbb)
        if len(boxes) == 0:
            print('文件为空', filepath)
        writeXml(anno_new_path, picname, W, H, D, boxes, hbb=True)
        print('正在处理%s' % filename)

需要注意修改:文件夹路径、目标类别、图像格式、标注框格式(hbb 还是 obb)。

  5. 将 xml 转换为 csv 格式
import os
import glob   #文件操作相关模块,用它可以查找符合自己目的的文件
import pandas as pd
import xml.etree.ElementTree as ET

# Folder with the VOC xml files.  main() writes label.csv relative to the
# current working directory, so switch into this folder at import time.
path = r'E:\Aerial Images\Aerial Images\DOTA\val\bridge\valsplit\hbbxml'
os.chdir(path)

def xml_to_csv(path):
    """Collect every ``<object>`` of the VOC xml files in *path* into a DataFrame.

    Args:
        path: Directory containing ``*.xml`` annotation files.

    Returns:
        DataFrame with columns ``filename, width, height, class, xmin, ymin,
        xmax, ymax``: one row per object, files processed in sorted order.
    """
    xml_list = []
    # glob.escape: folder names may contain glob metacharacters such as '['.
    # sorted(): glob order is filesystem dependent; sort for reproducible CSVs.
    for xml_file in sorted(glob.glob(os.path.join(glob.escape(path), '*.xml'))):
        root = ET.parse(xml_file).getroot()
        objects = root.findall('object')
        if not objects:
            continue
        filename = root.find('filename').text  # image file name
        size = root.find('size')
        width = int(float(size.find('width').text))
        height = int(float(size.find('height').text))
        for member in objects:
            # Look children up by tag rather than position (member[0],
            # member[4], size[0]): other VOC writers order/add elements differently.
            bndbox = member.find('bndbox')
            xml_list.append((
                filename,
                width,
                height,
                member.find('name').text,  # class
                # int(float()) also accepts coordinates written as "12.0".
                int(float(bndbox.find('xmin').text)),
                int(float(bndbox.find('ymin').text)),
                int(float(bndbox.find('xmax').text)),
                int(float(bndbox.find('ymax').text)),
            ))
    # Column order must match the tuple order above.
    column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
    return pd.DataFrame(xml_list, columns=column_name)


def main():
    """Convert the xml folder ``path`` into ``label.csv`` in the working directory."""
    image_path = path
    xml_df = xml_to_csv(image_path)
    xml_df.to_csv('label.csv', index=False)
    print('Successfully converted xml to csv.')


# Only run the conversion when executed as a script, not when imported.
if __name__ == '__main__':
    main()

需要注意的是:
(1) column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax'] 与 member 中各元素的对应关系。

  6. 将 csv 转换为 tfrecord(用于 TensorFlow 训练的格式)
"""
Usage:
  python csv_to_tfrecord.py --csv_input=data/train_labels.csv  --output_path=train_label.record
  python csv_to_tfrecord.py --csv_input=data/val_labels.csv  --output_path=val_labels.record
"""
import os
import io
import pandas as pd
import tensorflow as tf

from PIL import Image
from collections import namedtuple, OrderedDict
from object_detection.utils import dataset_util

# tf.app was removed in TensorFlow 2.x; tf.compat.v1 exposes the same
# absl-backed flags API on TF >= 1.13 and 2.x (this script already uses tf.io.*).
flags = tf.compat.v1.flags
flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
flags.DEFINE_string('output_path', '', 'Path to the tfrecord output')
FLAGS = flags.FLAGS
# Relative paths used by main() (the image folder and the output record)
# resolve against this directory.
os.chdir('C:\\Users\\DL-1\\models\\research\\object_detection\\')

# Class name -> integer id mapping used for 'image/object/class/label'.
# Extend the default map below or pass `label_map` for additional classes.
def class_text_to_int(row_label, label_map=None):
    """Map a class name from the CSV to its integer id.

    Args:
        row_label: Class name, e.g. ``'bridge'``.
        label_map: Optional ``{name: id}`` dict; defaults to ``{'bridge': 1}``.

    Returns:
        The integer id, or ``None`` when *row_label* is not in the map (the
        original ``else: None`` branch was a no-op that fell through to an
        implicit ``return None``; that result is kept, now explicitly).
    """
    if label_map is None:
        label_map = {'bridge': 1}
    return label_map.get(row_label)

def split(df, group):
    """Group the rows of *df* by the column *group*.

    Returns a list of ``data(filename, object)`` named tuples, one per
    distinct value of *group*, where ``object`` is the sub-DataFrame holding
    that value's rows.
    """
    data = namedtuple('data', ['filename', 'object'])
    grouped = df.groupby(group)
    result = []
    # Iterating the `groups` dict yields every group key exactly once.
    for key in grouped.groups:
        result.append(data(key, grouped.get_group(key)))
    return result


def create_tf_example(group, path):
    """Build a ``tf.train.Example`` for one image and all of its boxes.

    Args:
        group: ``data(filename, object)`` tuple from ``split``; ``object`` holds
            that image's CSV rows (``xmin``, ``ymin``, ``xmax``, ``ymax``, ``class``).
        path: Directory containing the image file.

    Returns:
        ``tf.train.Example`` with the encoded image bytes and box coordinates
        divided by the image width/height.

    Raises:
        ValueError: if a row's class has no id in ``class_text_to_int``.
    """
    image_file = os.path.join(path, '{}'.format(group.filename))
    with tf.io.gfile.GFile(image_file, 'rb') as fid:
        encoded_image = fid.read()
    image = Image.open(io.BytesIO(encoded_image))
    width, height = image.size
    # Take the format from the decoded file instead of hard-coding b'png', so
    # JPEG inputs are tagged correctly too (PIL reports e.g. 'PNG', 'JPEG').
    image_format = (image.format or 'png').lower().encode('utf8')

    filename = '{}'.format(group.filename).encode('utf8')
    xmins, xmaxs, ymins, ymaxs = [], [], [], []
    classes_text, classes = [], []

    for _, row in group.object.iterrows():
        class_id = class_text_to_int(row['class'])
        if class_id is None:
            # Fail with a clear message instead of an opaque protobuf TypeError.
            raise ValueError('unknown class {!r} in {}; add it to class_text_to_int'
                             .format(row['class'], group.filename))
        # Normalise pixel coordinates by the image size.
        xmins.append(row['xmin'] / width)
        xmaxs.append(row['xmax'] / width)
        ymins.append(row['ymin'] / height)
        ymaxs.append(row['ymax'] / height)
        classes_text.append(row['class'].encode('utf8'))
        classes.append(class_id)

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_image),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example


def main(_):
    """Convert ``FLAGS.csv_input`` into a TFRecord file at ``FLAGS.output_path``."""
    # Folder with the images referenced by the CSV, relative to the working
    # directory.
    path = os.path.join(os.getcwd(), 'images/bridge_val')
    examples = pd.read_csv(FLAGS.csv_input)
    grouped = split(examples, 'filename')
    # `with` closes (and flushes) the writer even if an image fails to load.
    with tf.io.TFRecordWriter(FLAGS.output_path) as writer:
        for group in grouped:
            tf_example = create_tf_example(group, path)
            writer.write(tf_example.SerializeToString())
    output_path = os.path.join(os.getcwd(), FLAGS.output_path)
    print('Successfully created the TFRecords: {}'.format(output_path))


if __name__ == '__main__':
    # tf.app.run() no longer exists in TensorFlow 2.x; tf.compat.v1.app.run()
    # parses the command-line flags and then calls main() on TF >= 1.13 and 2.x.
    tf.compat.v1.app.run()

需要注意修改的地方是:

(1) 图像的目录:path = os.path.join(os.getcwd(), 'images/bridge_val')  # 相对于当前工作目录
(2) 图像后缀名(格式):image_format = b'png'
(3) 对应的目标类别:if row_label == 'bridge':

  7. 评估脚本 devkit/dota_evaluation_task2.py

"""
    To use the code, users should to config detpath, annopath and imagesetfile
    detpath is the path for 15 result files, for the format, you can refer to "http://captain.whu.edu.cn/DOTAweb/tasks.html"
    search for PATH_TO_BE_CONFIGURED to config the paths
    Note, the evaluation is on the large scale images
"""
import xml.etree.ElementTree as ET
import os
#import cPickle
import numpy as np
import matplotlib.pyplot as plt

def parse_gt(filename):
    """Parse a DOTA ground-truth label file into horizontal-box objects.

    Object lines are ``x1 y1 x2 y2 x3 y3 x4 y4 name [difficult]``.  Lines with
    fewer than nine fields (the ``imagesource:``/``gsd:`` header of raw DOTA
    labels, blank lines) are skipped.

    Returns:
        List of dicts with keys ``name``, ``difficult`` (int, 0 when absent),
        ``bbox`` = ``[xmin, ymin, xmax, ymax]`` (ints) and ``area``.
    """
    objects = []
    with open(filename, 'r') as f:
        for line in f:
            splitline = line.strip().split(' ')
            if len(splitline) < 9:  # header or blank line: used to raise IndexError
                continue
            coords = [int(float(v)) for v in splitline[:8]]
            xs, ys = coords[0::2], coords[1::2]
            # min/max over all four corners equals corners 1 and 3 for a
            # rectangle listed clockwise from the top-left, and stays correct
            # for any other corner order or for oriented quadrilaterals.
            xmin, ymin, xmax, ymax = min(xs), min(ys), max(xs), max(ys)
            objects.append({
                'name': splitline[8],
                # Lines with extra trailing fields used to leave 'difficult'
                # unset, causing a KeyError later in voc_eval.
                'difficult': int(splitline[9]) if len(splitline) > 9 else 0,
                'bbox': [xmin, ymin, xmax, ymax],
                'area': (xmax - xmin) * (ymax - ymin),
            })
    return objects
def voc_ap(rec, prec, use_07_metric=False):
    """Return the average precision of a precision/recall curve.

    Args:
        rec: Recall values as a numpy array (voc_eval passes cumulative,
            hence non-decreasing, recall).
        prec: Precision values at the same points.
        use_07_metric: True uses the VOC2007 11-point interpolation (recall
            thresholds 0.0, 0.1, ..., 1.0); False integrates the monotone
            precision envelope over every recall step.
    """
    if not use_07_metric:
        # Sentinels make the curve start at recall 0 and end at recall 1.
        padded_rec = np.concatenate(([0.], rec, [1.]))
        padded_prec = np.concatenate(([0.], prec, [0.]))
        # Precision envelope: running maximum taken from the right-hand end.
        padded_prec = np.maximum.accumulate(padded_prec[::-1])[::-1]
        # Sum delta-recall * precision at every index where recall changes.
        steps = np.flatnonzero(padded_rec[1:] != padded_rec[:-1])
        return np.sum((padded_rec[steps + 1] - padded_rec[steps]) * padded_prec[steps + 1])

    total = 0.
    for threshold in np.arange(0., 1.1, 0.1):
        reached = rec >= threshold
        best = np.max(prec[reached]) if np.sum(reached) else 0
        total = total + best / 11.
    return total

def _class_ground_truth(recs, imagenames, classname):
    """Collect the ground-truth boxes of *classname* for every image.

    Returns:
        ``(class_recs, npos)``.  ``class_recs[image]`` holds ``bbox`` (array of
        ``[xmin, ymin, xmax, ymax]``), ``difficult`` (bool array) and ``det``
        (per-box "already matched" flags).  ``npos`` counts non-difficult boxes.
    """
    class_recs = {}
    npos = 0
    for imagename in imagenames:
        R = [obj for obj in recs[imagename] if obj['name'] == classname]
        bbox = np.array([x['bbox'] for x in R])
        # np.bool was deprecated in NumPy 1.20 and removed in 1.24; the builtin
        # bool selects the same dtype on every NumPy version.
        difficult = np.array([x['difficult'] for x in R]).astype(bool)
        det = [False] * len(R)
        npos = npos + sum(~difficult)
        class_recs[imagename] = {'bbox': bbox, 'difficult': difficult, 'det': det}
    return class_recs, npos


def _read_detections(detfile):
    """Read ``image_id score xmin ymin xmax ymax`` lines from *detfile*.

    Blank lines are ignored.  Returns ``(image_ids, confidence, BB)``; for a
    file without detections ``BB`` has shape ``(0, 4)`` so 2-D indexing works.
    """
    with open(detfile, 'r') as f:
        splitlines = [x.strip().split(' ') for x in f if x.strip()]
    if not splitlines:
        return [], np.zeros(0), np.zeros((0, 4))
    image_ids = [x[0] for x in splitlines]
    confidence = np.array([float(x[1]) for x in splitlines])
    BB = np.array([[float(z) for z in x[2:]] for x in splitlines])
    return image_ids, confidence, BB


def voc_eval(detpath,
             annopath,
             imagesetfile,
             classname,
             ovthresh=0.5,
             use_07_metric=False):
    """Evaluate horizontal-box (Task2) detections of one class, PASCAL VOC style.

    Args:
        detpath: Detection file pattern.  ``detpath.format(classname)`` is a
            text file with one ``image_id score xmin ymin xmax ymax`` line
            per detection.
        annopath: Ground-truth pattern.  ``annopath.format(imagename)`` is a
            DOTA label file read by ``parse_gt``.
        imagesetfile: Text file listing one image name per line.
        classname: Category to evaluate.
        ovthresh: IoU threshold for a true positive (default 0.5).
        use_07_metric: Use the VOC2007 11-point AP (default False).

    Returns:
        ``(rec, prec, ap)``: cumulative recall and precision arrays over the
        detections sorted by descending score, and the average precision.
    """
    # Ground truth of every listed image (blank lines in the list are skipped).
    with open(imagesetfile, 'r') as f:
        imagenames = [x.strip() for x in f if x.strip()]
    recs = {}
    for imagename in imagenames:
        recs[imagename] = parse_gt(annopath.format(imagename))

    class_recs, npos = _class_ground_truth(recs, imagenames, classname)

    # Detections, processed from the highest to the lowest confidence.
    image_ids, confidence, BB = _read_detections(detpath.format(classname))
    sorted_ind = np.argsort(-confidence)
    BB = BB[sorted_ind, :]
    image_ids = [image_ids[x] for x in sorted_ind]

    # Go down the detections and mark true and false positives.
    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    for d in range(nd):
        R = class_recs[image_ids[d]]
        bb = BB[d, :].astype(float)
        ovmax = -np.inf
        BBGT = R['bbox'].astype(float)

        if BBGT.size > 0:
            # Intersection with every GT box (VOC "+1" inclusive-pixel convention).
            ixmin = np.maximum(BBGT[:, 0], bb[0])
            iymin = np.maximum(BBGT[:, 1], bb[1])
            ixmax = np.minimum(BBGT[:, 2], bb[2])
            iymax = np.minimum(BBGT[:, 3], bb[3])
            iw = np.maximum(ixmax - ixmin + 1., 0.)
            ih = np.maximum(iymax - iymin + 1., 0.)
            inters = iw * ih

            # Union.
            uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
                   (BBGT[:, 2] - BBGT[:, 0] + 1.) *
                   (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)

            overlaps = inters / uni
            ovmax = np.max(overlaps)
            jmax = np.argmax(overlaps)

        if ovmax > ovthresh:
            if not R['difficult'][jmax]:
                if not R['det'][jmax]:
                    tp[d] = 1.
                    R['det'][jmax] = 1  # each GT box can be matched only once
                else:
                    fp[d] = 1.  # duplicate detection of an already matched box
            # Matches to a "difficult" box count as neither TP nor FP.
        else:
            fp[d] = 1.

    print('npos num:', npos)
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / float(npos)
    # Avoid dividing by zero when the first detection matches a difficult GT box.
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    ap = voc_ap(rec, prec, use_07_metric)

    return rec, prec, ap

def main():
    """Evaluate Task2 detections for every DOTA class and print the per-class AP and mAP."""
    # PATH_TO_BE_CONFIGURED: detection results, ground truth and image list.
    detpath = r'H:\DOTA\Raw_DOTA\evaluate_val_with_val\Task2_val_dt\Task2_{:s}.txt'
    annopath = r'H:\DOTA\Raw_DOTA\evaluate_val_with_val\Task2_val_gt\{:s}.txt'  # change the directory to the path of val/labelTxt, if you want to do evaluation on the valset
    imagesetfile = r'H:\DOTA\Raw_DOTA\evaluate_val_with_val\Task2_val_images\val_bridge_image.txt'

    classnames = ['plane', 'baseball-diamond', 'bridge', 'ground-track-field', 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
                  'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout', 'harbor', 'swimming-pool', 'helicopter']
    classaps = []
    for classname in classnames:
        print('classname:', classname)
        rec, prec, ap = voc_eval(detpath,
                                 annopath,
                                 imagesetfile,
                                 classname,
                                 ovthresh=0.5,
                                 use_07_metric=True)
        print('ap: ', ap)
        classaps.append(ap)

        ## uncomment to plot p-r curve for each category
        # plt.figure(figsize=(8,4))
        # plt.xlabel('recall')
        # plt.ylabel('precision')
        # plt.plot(rec, prec)
        # plt.show()
    # `mean_ap` instead of `map`, which shadowed the builtin map().
    mean_ap = sum(classaps) / len(classnames)
    print('map:', mean_ap)
    classaps = 100 * np.array(classaps)
    print('classaps: ', classaps)


if __name__ == '__main__':
    main()

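需要改动的地方主要有以下三个(详情参考代码中的注释):

  • detpath = r'H:\DOTA\Raw_DOTA\evaluate_val_with_val\Task2_val_dt\Task2_{:s}.txt'
  • annopath = r'H:\DOTA\Raw_DOTA\evaluate_val_with_val\Task2_val_gt\{:s}.txt'
  • imagesetfile = r'H:\DOTA\Raw_DOTA\evaluate_val_with_val\Task2_val_images\val_bridge_image.txt'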
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

常用数据集预处理(dota) 的相关文章

  • 使用 pythonbrew 编译 Python 3.2 和 2.7 时出现问题

    我正在尝试使用构建多个版本的 python蟒蛇酿造 http pypi python org pypi pythonbrew 0 7 3 但我遇到了一些测试失败 这是在运行的虚拟机上 Ubuntu 8 04 32 位 当我使用时会发生这种情
  • 使用 psycopg2 在 python 中执行查询时出现“编程错误:语法错误位于或附近”

    我正在运行 Python v 2 7 和 psycopg2 v 2 5 我有一个 postgresql 数据库函数 它将 SQL 查询作为文本字段返回 我使用以下代码来调用该函数并从文本字段中提取查询 cur2 execute SELECT
  • 在 python 程序中合并第三方库的最佳实践是什么?

    下午好 我正在为我的工作编写一个中小型Python程序 该任务需要我使用 Excel 库xlwt and xlrd 以及一个用于查询 Oracle 数据库的库 称为CX Oracle 我正在通过版本控制系统 即CVS 开发该项目 我想知道围
  • Python 的键盘中断不会中止 Rust 函数 (PyO3)

    我有一个使用 PyO3 用 Rust 编写的 Python 库 它涉及一些昂贵的计算 单个函数调用最多需要 10 分钟 从 Python 调用时如何中止执行 Ctrl C 好像只有执行结束后才会处理 所以本质上没什么用 最小可重现示例 Ca
  • 将数据从 python pandas 数据框导出或写入 MS Access 表

    我正在尝试将数据从 python pandas 数据框导出到现有的 MS Access 表 我想用已更新的数据替换 MS Access 表 在 python 中 我尝试使用 pandas to sql 但收到错误消息 我觉得很奇怪 使用 p
  • Python(Selenium):如何通过登录重定向/组织登录登录网站

    我不是专业程序员 所以请原谅任何愚蠢的错误 我正在做一些研究 我正在尝试使用 Selenium 登录数据库来搜索大约 1000 个术语 我有两个问题 1 重定向到组织登录页面后如何使用 Selenium 登录 2 如何检索数据库 在我解决
  • 使用带有关键字参数的 map() 函数

    这是我尝试使用的循环map功能于 volume ids 1 2 3 4 5 ip 172 12 13 122 for volume id in volume ids my function volume id ip ip 我有办法做到这一点
  • Flask 会话变量

    我正在用 Flask 编写一个小型网络应用程序 当两个用户 在同一网络下 尝试使用应用程序时 我遇到会话变量问题 这是代码 import os from flask import Flask request render template
  • 是否可以忽略一行的pyright检查?

    我需要忽略一行的pyright 检查 有什么特别的评论吗 def create slog group SLogGroup data Optional dict None SLog insert one SLog group group da
  • Spark KMeans 无法处理大数据吗?

    KMeans 有几个参数training http spark apache org docs latest api python pyspark mllib html highlight kmeans pyspark mllib clus
  • 以编程方式停止Python脚本的执行? [复制]

    这个问题在这里已经有答案了 是否可以使用命令在任意行停止执行 python 脚本 Like some code quit quit at this point some more code that s not executed sys e
  • OpenCV 无法从 MacBook Pro iSight 捕获

    几天后 我无法再从 opencv 应用程序内部打开我的 iSight 相机 cap cv2 VideoCapture 0 返回 并且cap isOpened 回报true 然而 cap grab 刚刚返回false 有任何想法吗 示例代码
  • 如何使用 OpencV 从 Firebase 读取图像?

    有没有使用 OpenCV 从 Firebase 读取图像的想法 或者我必须先下载图片 然后从本地文件夹执行 cv imread 功能 有什么办法我可以使用cv imread link of picture from firebase 您可以
  • BeautifulSoup 中的嵌套标签 - Python

    我在网站和 stackoverflow 上查看了许多示例 但找不到解决我的问题的通用解决方案 我正在处理一个非常混乱的网站 我想抓取一些数据 标记看起来像这样 table tbody tr tr tr td td td table tr t
  • 如何在ipywidget按钮中显示全文?

    我正在创建一个ipywidget带有一些文本的按钮 但按钮中未显示全文 我使用的代码如下 import ipywidgets as widgets from IPython display import display button wid
  • Python 的“zip”内置函数的 Ruby 等价物是什么?

    Ruby 是否有与 Python 内置函数等效的东西zip功能 如果不是 做同样事情的简洁方法是什么 一些背景信息 当我试图找到一种干净的方法来进行涉及两个数组的检查时 出现了这个问题 如果我有zip 我可以写这样的东西 zip a b a
  • 对年龄列进行分组/分类

    我有一个数据框说df有一个柱子 Ages gt gt gt df Age 0 22 1 38 2 26 3 35 4 35 5 1 6 54 我想对这个年龄段进行分组并创建一个像这样的新专栏 If age gt 0 age lt 2 the
  • 解释 Python 中的数字范围

    在 Pylons Web 应用程序中 我需要获取一个字符串 例如 关于如何做到这一点有什么建议吗 我是 Python 新手 我还没有找到任何可以帮助解决此类问题的东西 该列表将是 1 2 3 45 46 48 49 50 51 77 使用
  • 在 Qt 中自动调整标签文本大小 - 奇怪的行为

    在 Qt 中 我有一个复合小部件 它由排列在 QBoxLayouts 内的多个 QLabels 组成 当小部件调整大小时 我希望标签文本缩放以填充标签区域 并且我已经在 resizeEvent 中实现了文本大小的调整 这可行 但似乎发生了某
  • 如何将输入读取为数字?

    这个问题的答案是社区努力 help privileges edit community wiki 编辑现有答案以改进这篇文章 目前不接受新的答案或互动 Why are x and y下面的代码中使用字符串而不是整数 注意 在Python 2

随机推荐