1. 从数据集中选出包含所需类别目标的样本
import os
import cv2
import shutil

catogary = ['bridge']  # category names to look for (list)


def customname(fullname):
    """Return the file name of *fullname* without directory or extension."""
    return os.path.basename(os.path.splitext(fullname)[0])


def GetFileFromRoot(dir):
    """Return the full path (including extension) of every file under *dir*, recursively."""
    allfiles = []
    for root, dirs, files in os.walk(dir):
        for file in files:
            allfiles.append(os.path.join(root, file))
    return allfiles


def count_targets(label_path, categories=None, skip_lines=2):
    """Count objects in a DOTA label file whose category is in *categories*.

    label_path: DOTA txt label, one object per line ``x1 y1 ... x4 y4 category [difficult]``.
    categories: names to count; defaults to the module-level ``catogary`` list.
    skip_lines: number of leading lines to ignore (the DOTA v1.0 labels used here
        start with two non-object lines, e.g. image source and gsd).

    Lines with fewer than 9 fields are ignored instead of raising IndexError.
    """
    wanted = set(catogary if categories is None else categories)
    n = 0
    with open(label_path, 'r') as f:
        for i, line in enumerate(f):
            if i < skip_lines:
                continue
            fields = line.split()
            if len(fields) >= 9 and fields[8] in wanted:
                n += 1
    return n


if __name__ == '__main__':
    root = 'E:/Aerial Images/Aerial Images/DOTA/train'
    raw_pic_path = os.path.join(root, 'images/images')
    raw_lab_path = os.path.join(root, 'labelTxt-v1.0/labelTxt')
    bridge_pic = os.path.join(root, 'bridge/images')
    bridge_lab = os.path.join(root, 'bridge/labelTxt')
    # The output folders must exist before anything can be copied into them.
    os.makedirs(bridge_pic, exist_ok=True)
    os.makedirs(bridge_lab, exist_ok=True)

    for label_path in GetFileFromRoot(raw_lab_path):
        # Select an image once it contains at least two objects of the wanted categories.
        if count_targets(label_path) < 2:
            continue
        name = customname(label_path)  # label file name without extension
        old_img_path = os.path.join(raw_pic_path, name + '.png')
        if not os.path.exists(old_img_path):
            print('missing image:', old_img_path)
            continue
        new_pic_path = os.path.join(bridge_pic, name + '.png')
        new_lab_path = os.path.join(bridge_lab, name + '.txt')  # original lacked the '.'
        # Byte-for-byte copy: cv2.imread/imwrite would re-encode the PNG and,
        # with the default IMREAD_COLOR flag, drop any alpha channel.
        shutil.copyfile(old_img_path, new_pic_path)
        shutil.copyfile(label_path, new_lab_path)
  2. 删除(移出)数据集中的空白样本(标签中没有任何目标的样本)
import os
import shutil
import xml.dom.minidom


def custombasename(fullname):
    """Return the file name of *fullname* without directory or extension."""
    return os.path.basename(os.path.splitext(fullname)[0])


def GetFileFromThisRootDir(dir, ext=None):
    """Recursively list files under *dir*.

    ext: None to return every file; otherwise a single extension or a collection
        of extensions, with or without the leading dot (e.g. 'txt', '.png',
        ['jpg', 'png']). A plain string is treated as ONE extension, not as a
        set of characters (the original substring test let 'txt' match files
        without any extension).
    """
    if ext is None:
        wanted = None
    else:
        if isinstance(ext, str):
            ext = [ext]
        wanted = {e.lstrip('.') for e in ext}
    allfiles = []
    for root, dirs, files in os.walk(dir):
        for filespath in files:
            filepath = os.path.join(root, filespath)
            extension = os.path.splitext(filepath)[1][1:]
            if wanted is None or extension in wanted:
                allfiles.append(filepath)
    return allfiles


def _is_blank_label(path, label_ext):
    """Return True when the label file at *path* describes no objects."""
    if label_ext.lstrip('.') == 'xml':  # accept both 'xml' and '.xml'
        dom_tree = xml.dom.minidom.parse(path)
        return len(dom_tree.documentElement.getElementsByTagName('object')) == 0
    with open(path, 'r') as f_in:
        # Whitespace-only files carry no objects either.
        return not any(line.strip() for line in f_in)


def cleandata(path, img_path, blank_label_path, blank_img_path, ext, label_ext):
    """Move a sample whose label has no objects into the blank_* folders.

    path: label file to inspect.
    img_path: folder holding the images; the image is ``<label name> + ext``.
    blank_label_path / blank_img_path: existing folders that receive blank samples.
    ext: image extension including the dot, e.g. '.png'.
    label_ext: 'xml'/'.xml' for VOC XML labels; anything else is read as plain text.
    """
    name = custombasename(path)
    # The label file is already closed here, so it can be moved (required on Windows).
    if _is_blank_label(path, label_ext):
        image_path = os.path.join(img_path, name + ext)  # matching sample image
        shutil.move(image_path, blank_img_path)
        shutil.move(path, blank_label_path)
    print('正在处理 %s' % path)


if __name__ == '__main__':
    root = 'E:/Aerial Images/Aerial Images/trainsplit'
    img_path = os.path.join(root, 'images')  # split image set
    label_path = os.path.join(root, 'labelTxt')  # split labels
    ext = '.png'  # image extension
    label_ext = '.txt'
    # destination folders for blank samples and their labels
    blank_img_path = os.path.join(root, 'blank_images')
    blank_label_path = os.path.join(root, 'blank_labelTxt')
    os.makedirs(blank_img_path, exist_ok=True)
    os.makedirs(blank_label_path, exist_ok=True)
    for path in GetFileFromThisRootDir(label_path):
        cleandata(path, img_path, blank_label_path, blank_img_path, ext, label_ext)
  3. 删除数据中的非目标样本(即提取出含所需目标的样本:代码会把含有所需目标的样本移动到 nontarget_images 和 nontarget_labelTxt 文件夹,注意文件夹名称与其实际内容相反)
import os
import shutil
import xml.dom.minidom

# Target category names. Defined at module level (not only under __main__) so that
# cleandata() also works when this module is imported.
catogory = ['bridge']


def custombasename(fullname):
    """Return the file name of *fullname* without directory or extension."""
    return os.path.basename(os.path.splitext(fullname)[0])


def GetFileFromThisRootDir(dir, ext=None):
    """Recursively list files under *dir*.

    ext: None to return every file; otherwise a single extension or a collection
        of extensions, with or without the leading dot (e.g. 'txt', '.png').
        A plain string is treated as ONE extension rather than a set of characters.
    """
    if ext is None:
        wanted = None
    else:
        if isinstance(ext, str):
            ext = [ext]
        wanted = {e.lstrip('.') for e in ext}
    allfiles = []
    for root, dirs, files in os.walk(dir):
        for filespath in files:
            filepath = os.path.join(root, filespath)
            extension = os.path.splitext(filepath)[1][1:]
            if wanted is None or extension in wanted:
                allfiles.append(filepath)
    return allfiles


def cleandata(path, img_path, nontarget_label_path, nontarget_img_path, ext, label_ext,
              categories=None):
    """Move a sample into the nontarget_* folders if its DOTA label contains a target.

    Despite the folder names, the samples that are MOVED are the ones that contain
    at least one object whose category is in *categories* (default: the
    module-level ``catogory``).

    path: DOTA txt label (``x1 y1 ... x4 y4 category [difficult]`` per line).
    img_path: folder holding the images; the image is ``<label name> + ext``.
    ext: image extension including the dot, e.g. '.png'.
    label_ext: unused; kept for interface compatibility.
    """
    wanted = set(catogory if categories is None else categories)
    name = custombasename(path)
    with open(path, 'r') as f_in:
        # Lines with fewer than 9 fields (blank/header lines) cannot hold an object.
        has_target = any(
            len(fields) >= 9 and fields[8] in wanted
            for fields in (line.split() for line in f_in)
        )
    # The label file is closed before moving it (required on Windows).
    if has_target:
        image_path = os.path.join(img_path, name + ext)  # matching sample image
        shutil.move(image_path, nontarget_img_path)
        shutil.move(path, nontarget_label_path)
    print('正在处理 %s' % path)


if __name__ == '__main__':
    root = r'H:\DOTA\dota\trainsplit'
    img_path = os.path.join(root, 'images')  # split image set
    label_path = os.path.join(root, 'labelTxt')  # split labels
    ext = '.png'  # image extension
    label_ext = '.txt'
    # destination folders for the samples that contain a target
    nontarget_img_path = os.path.join(root, 'nontarget_images')
    nontarget_label_path = os.path.join(root, 'nontarget_labelTxt')
    os.makedirs(nontarget_img_path, exist_ok=True)
    os.makedirs(nontarget_label_path, exist_ok=True)
    for path in GetFileFromThisRootDir(label_path):
        cleandata(path, img_path, nontarget_label_path, nontarget_img_path, ext, label_ext)
  4. 将DOTA数据集标签格式从txt转换成xml(VOC格式)
import os
import cv2
from xml.dom.minidom import Document

category_set = ['bridge']  # only objects of these categories are written to the XML


def custombasename(fullname):
    """Return the file name of *fullname* without directory or extension."""
    return os.path.basename(os.path.splitext(fullname)[0])


def limit_value(a, b):
    """Clamp coordinate *a* into [1, b - 1] (b = image width or height)."""
    if a < 1:
        a = 1
    if a >= b:
        a = b - 1
    return a


def readlabeltxt(txtpath, height, width, hbb=True):
    """Read a DOTA txt label and return the boxes of the categories in ``category_set``.

    Each object line is ``x1 y1 x2 y2 x3 y3 x4 y4 category [difficult]``.
    Lines with fewer than 9 fields (e.g. the ``imagesource:``/``gsd:`` header of
    unsplit DOTA labels) are skipped. Coordinates are truncated to int and clamped
    with limit_value.

    Returns [xmin, ymin, xmax, ymax, label] per object when *hbb* is True,
    otherwise [x1, y1, x2, y2, x3, y3, x4, y4, label].
    """
    print(txtpath)
    with open(txtpath, 'r') as f_in:
        lines = f_in.readlines()
    boxes = []
    for line in lines:
        splitline = line.split()
        if len(splitline) < 9:
            continue
        label = splitline[8]
        if label not in category_set:  # only write the requested categories
            continue
        coords = [int(float(v)) for v in splitline[:8]]
        xs = coords[0::2]
        ys = coords[1::2]
        if hbb:  # horizontal box enclosing the four corners
            boxes.append([limit_value(min(xs), width), limit_value(min(ys), height),
                          limit_value(max(xs), width), limit_value(max(ys), height),
                          label])
        else:  # oriented box: keep all four (clamped) corners
            box = []
            for x, y in zip(xs, ys):
                box += [limit_value(x, width), limit_value(y, height)]
            boxes.append(box + [label])
    return boxes


def _add_element(doc, parent, tag):
    """Create an empty <tag> element, append it to *parent* and return it."""
    element = doc.createElement(tag)
    parent.appendChild(element)
    return element


def _add_text_element(doc, parent, tag, text):
    """Create <tag>text</tag>, append it to *parent* and return it."""
    element = _add_element(doc, parent, tag)
    element.appendChild(doc.createTextNode(text))
    return element


def writeXml(tmp, imgname, w, h, d, bboxes, hbb=True):
    """Write a Pascal-VOC style annotation ``<tmp>/<imgname without ext>.xml``.

    tmp: output folder (created if missing).
    imgname: image file name stored in <filename>.
    w, h, d: image width, height and depth.
    bboxes: boxes as returned by readlabeltxt (label is the last item).
    hbb: write <xmin>/<ymin>/<xmax>/<ymax> if True, else <x0>,<y0> ... <x3>,<y3>.
    """
    doc = Document()
    annotation = _add_element(doc, doc, 'annotation')
    _add_text_element(doc, annotation, 'folder', "VOC2007")
    _add_text_element(doc, annotation, 'filename', imgname)

    source = _add_element(doc, annotation, 'source')
    _add_text_element(doc, source, 'database', "My Database")
    _add_text_element(doc, source, 'annotation', "VOC2007")
    _add_text_element(doc, source, 'image', "flickr")

    owner = _add_element(doc, annotation, 'owner')
    _add_text_element(doc, owner, 'flickrid', "NULL")
    _add_text_element(doc, owner, 'name', "idannel")

    size = _add_element(doc, annotation, 'size')
    _add_text_element(doc, size, 'width', str(w))
    _add_text_element(doc, size, 'height', str(h))
    _add_text_element(doc, size, 'depth', str(d))

    _add_text_element(doc, annotation, 'segmented', "0")

    if hbb:
        coord_tags = ('xmin', 'ymin', 'xmax', 'ymax')
    else:
        coord_tags = ('x0', 'y0', 'x1', 'y1', 'x2', 'y2', 'x3', 'y3')
    for bbox in bboxes:
        object_new = _add_element(doc, annotation, 'object')
        _add_text_element(doc, object_new, 'name', str(bbox[-1]))
        _add_text_element(doc, object_new, 'pose', "Unspecified")
        _add_text_element(doc, object_new, 'truncated', "0")
        _add_text_element(doc, object_new, 'difficult', "0")
        bndbox = _add_element(doc, object_new, 'bndbox')
        # zip stops at the tag tuple, so the trailing label is not written here.
        for tag, value in zip(coord_tags, bbox):
            _add_text_element(doc, bndbox, tag, str(value))

    os.makedirs(tmp, exist_ok=True)  # the original failed if the folder was missing
    xmlname = os.path.splitext(imgname)[0]
    xml_path = os.path.join(tmp, xmlname + '.xml')
    with open(xml_path, 'wb') as f:
        f.write(doc.toprettyxml(indent='\t', encoding='utf-8'))


if __name__ == '__main__':
    data_path = r'E:\Aerial Images\Aerial Images\DOTA\val\bridge\valsplit'
    images_path = os.path.join(data_path, 'images')  # sample images
    labeltxt_path = os.path.join(data_path, 'labelTxt')  # DOTA labels
    anno_new_path = os.path.join(data_path, 'hbbxml')  # new VOC-format output (hbb)
    ext = '.png'  # image extension
    for filename in os.listdir(labeltxt_path):  # one DOTA txt per image
        filepath = os.path.join(labeltxt_path, filename)
        picname = os.path.splitext(filename)[0] + ext
        pic_path = os.path.join(images_path, picname)
        # Read the matching image only to get H, W, D for the XML.
        im = cv2.imread(pic_path)
        if im is None:  # cv2.imread returns None instead of raising
            print('cannot read image', pic_path)
            continue
        (H, W, D) = im.shape
        boxes = readlabeltxt(filepath, H, W, hbb=True)  # horizontal boxes by default
        if len(boxes) == 0:
            print('文件为空', filepath)
        writeXml(anno_new_path, picname, W, H, D, boxes, hbb=True)
        print('正在处理%s' % filename)

需要注意文件夹路径、目标类别、图像格式、注释框格式(hbb还是obb)

  5. 将xml标签转换为csv格式
import os
import glob   #文件操作相关模块,用它可以查找符合自己目的的文件
import pandas as pd
import xml.etree.ElementTree as ET

path = r'E:\Aerial Images\Aerial Images\DOTA\val\bridge\valsplit\hbbxml'


def xml_to_csv(path):
    """Collect every object of the VOC-style ``*.xml`` files in *path* into a DataFrame.

    Columns: filename, width, height, class, xmin, ymin, xmax, ymax (one row per
    <object>). Tags are looked up by name rather than by child position, so files
    whose <object> lacks e.g. <pose>/<truncated>/<difficult> are still read.
    Coordinates written as floats ("30.0") are truncated to int.
    """
    xml_list = []
    for xml_file in sorted(glob.glob(os.path.join(path, '*.xml'))):
        root = ET.parse(xml_file).getroot()  # root element of the XML document
        filename = root.find('filename').text  # image file name
        width = int(root.find('size/width').text)
        height = int(root.find('size/height').text)
        for member in root.findall('object'):
            bndbox = member.find('bndbox')
            xml_list.append((
                filename,
                width,
                height,
                member.find('name').text,  # category
                int(float(bndbox.find('xmin').text)),
                int(float(bndbox.find('ymin').text)),
                int(float(bndbox.find('xmax').text)),
                int(float(bndbox.find('ymax').text)),
            ))
    column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
    return pd.DataFrame(xml_list, columns=column_name)


def main():
    image_path = path
    xml_df = xml_to_csv(image_path)
    # Written into the XML folder, where the original wrote it after os.chdir(path);
    # no process-wide chdir is needed.
    xml_df.to_csv(os.path.join(image_path, 'label.csv'), index=False)
    print('Successfully converted xml to csv.')


if __name__ == '__main__':
    main()

需要注意的是:
(1)、column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax'] 与 member 中各元素的对应关系要一致。

  6. csv到tfrecord(用于TensorFlow训练的格式)
"""
Usage:python csv_to_tfrecord.py --csv_input=data/train_labels.csv  --output_path=train_label.recordpython csv_to_tfrecord.py --csv_input=data/val_labels.csv  --output_path=val_labels.record
"""
import os
import io
import pandas as pd
import tensorflow as tffrom PIL import Image
from collections import namedtuple, OrderedDict
from object_detection.utils import dataset_utilflags = tf.app.flags
flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
flags.DEFINE_string('output_path', '', 'Path to the tfrecord output')
FLAGS = flags.FLAGS
os.chdir('C:\\Users\\DL-1\\models\\research\\object_detection\\')# TO-DO replace this with label map
def class_text_to_int(row_label):if row_label == 'bridge':return 1
#    elif row_label == 'vehicle':
#        return 2else:Nonedef split(df, group):data = namedtuple('data', ['filename', 'object'])gb = df.groupby(group)return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]def create_tf_example(group, path):with tf.io.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:encoded_png = fid.read()encoded_png_io = io.BytesIO(encoded_png)image = Image.open(encoded_png_io)width, height = image.sizefilename = group.filename.encode('utf8')image_format = b'png'xmins = []xmaxs = []ymins = []ymaxs = []classes_text = []classes = []for index, row in group.object.iterrows():xmins.append(row['xmin'] / width)xmaxs.append(row['xmax'] / width)ymins.append(row['ymin'] / height)ymaxs.append(row['ymax'] / height)classes_text.append(row['class'].encode('utf8'))classes.append(class_text_to_int(row['class']))tf_example = tf.train.Example(features=tf.train.Features(feature={'image/height': dataset_util.int64_feature(height),'image/width': dataset_util.int64_feature(width),'image/filename': dataset_util.bytes_feature(filename),'image/source_id': dataset_util.bytes_feature(filename),'image/encoded': dataset_util.bytes_feature(encoded_png),'image/format': dataset_util.bytes_feature(image_format),'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),'image/object/class/text': dataset_util.bytes_list_feature(classes_text),'image/object/class/label': dataset_util.int64_list_feature(classes),}))return tf_exampledef main(_):writer = tf.io.TFRecordWriter(FLAGS.output_path)path = os.path.join(os.getcwd(), 'images/bridge_val')  # 获取当前工作目录examples = pd.read_csv(FLAGS.csv_input)grouped = split(examples, 'filename')for group in grouped:tf_example = create_tf_example(group, 
path)writer.write(tf_example.SerializeToString())writer.close()output_path = os.path.join(os.getcwd(), FLAGS.output_path)print('Successfully created the TFRecords: {}'.format(output_path))if __name__ == '__main__':tf.app.run()

需要注意修改的地方是

(1)、图像的目录:path = os.path.join(os.getcwd(), 'images/bridge_val')  # 获取当前工作目录
(2)、图像后缀名(格式):image_format = b'png'
(3)、对应的目标类别:if row_label == 'bridge':

7. 使用 devkit/dota_evaluation_task2.py 进行评估

"""To use the code, users should to config detpath, annopath and imagesetfiledetpath is the path for 15 result files, for the format, you can refer to "http://captain.whu.edu.cn/DOTAweb/tasks.html"search for PATH_TO_BE_CONFIGURED to config the pathsNote, the evaluation is on the large scale images
"""
import xml.etree.ElementTree as ET
import os
#import cPickle
import numpy as np
import matplotlib.pyplot as pltdef parse_gt(filename):objects = []with open(filename, 'r') as f:lines = f.readlines()splitlines = [x.strip().split(' ')  for x in lines]for splitline in splitlines:object_struct = {}object_struct['name'] = splitline[8]if (len(splitline) == 9):object_struct['difficult'] = 0elif (len(splitline) == 10):object_struct['difficult'] = int(splitline[9])# object_struct['difficult'] = 0object_struct['bbox'] = [int(float(splitline[0])),int(float(splitline[1])),int(float(splitline[4])),int(float(splitline[5]))]w = int(float(splitline[4])) - int(float(splitline[0]))h = int(float(splitline[5])) - int(float(splitline[1]))object_struct['area'] = w * h#print('area:', object_struct['area'])# if object_struct['area'] < (15 * 15):#     #print('area:', object_struct['area'])#     object_struct['difficult'] = 1objects.append(object_struct)return objects
def voc_ap(rec, prec, use_07_metric=False):""" ap = voc_ap(rec, prec, [use_07_metric])Compute VOC AP given precision and recall.If use_07_metric is true, uses theVOC 07 11 point method (default:False)."""if use_07_metric:# 11 point metricap = 0.for t in np.arange(0., 1.1, 0.1):if np.sum(rec >= t) == 0:p = 0else:p = np.max(prec[rec >= t])ap = ap + p / 11.else:# correct AP calculation# first append sentinel values at the endmrec = np.concatenate(([0.], rec, [1.]))mpre = np.concatenate(([0.], prec, [0.]))# compute the precision envelopefor i in range(mpre.size - 1, 0, -1):mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])# to calculate area under PR curve, look for points# where X axis (recall) changes valuei = np.where(mrec[1:] != mrec[:-1])[0]# and sum (\Delta recall) * precap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])return apdef voc_eval(detpath,annopath,imagesetfile,classname,# cachedir,ovthresh=0.5,use_07_metric=False):"""rec, prec, ap = voc_eval(detpath,annopath,imagesetfile,classname,[ovthresh],[use_07_metric])Top level function that does the PASCAL VOC evaluation.detpath: Path to detectionsdetpath.format(classname) should produce the detection results file.annopath: Path to annotationsannopath.format(imagename) should be the xml annotations file.imagesetfile: Text file containing the list of images, one image per line.classname: Category name (duh)cachedir: Directory for caching the annotations[ovthresh]: Overlap threshold (default = 0.5)[use_07_metric]: Whether to use VOC07's 11 point AP computation(default False)"""# assumes detections are in detpath.format(classname)# assumes annotations are in annopath.format(imagename)# assumes imagesetfile is a text file with each line an image name# cachedir caches the annotations in a pickle file# first load gt#if not os.path.isdir(cachedir):#   os.mkdir(cachedir)#cachefile = os.path.join(cachedir, 'annots.pkl')# read list of imageswith open(imagesetfile, 'r') as f:lines = f.readlines()imagenames = [x.strip() for x 
in lines]#print('imagenames: ', imagenames)#if not os.path.isfile(cachefile):# load annotsrecs = {}for i, imagename in enumerate(imagenames):#print('parse_files name: ', annopath.format(imagename))recs[imagename] = parse_gt(annopath.format(imagename))#if i % 100 == 0:#   print ('Reading annotation for {:d}/{:d}'.format(#      i + 1, len(imagenames)) )# save#print ('Saving cached annotations to {:s}'.format(cachefile))#with open(cachefile, 'w') as f:#   cPickle.dump(recs, f)#else:# load#with open(cachefile, 'r') as f:#   recs = cPickle.load(f)# extract gt objects for this classclass_recs = {}npos = 0for imagename in imagenames:R = [obj for obj in recs[imagename] if obj['name'] == classname]bbox = np.array([x['bbox'] for x in R])difficult = np.array([x['difficult'] for x in R]).astype(np.bool)det = [False] * len(R)npos = npos + sum(~difficult)class_recs[imagename] = {'bbox': bbox,'difficult': difficult,'det': det}# read detsdetfile = detpath.format(classname)with open(detfile, 'r') as f:lines = f.readlines()splitlines = [x.strip().split(' ') for x in lines]image_ids = [x[0] for x in splitlines]confidence = np.array([float(x[1]) for x in splitlines])#print('check confidence: ', confidence)BB = np.array([[float(z) for z in x[2:]] for x in splitlines])# sort by confidencesorted_ind = np.argsort(-confidence)sorted_scores = np.sort(-confidence)#print('check sorted_scores: ', sorted_scores)#print('check sorted_ind: ', sorted_ind)BB = BB[sorted_ind, :]image_ids = [image_ids[x] for x in sorted_ind]#print('check imge_ids: ', image_ids)#print('imge_ids len:', len(image_ids))# go down dets and mark TPs and FPsnd = len(image_ids)tp = np.zeros(nd)fp = np.zeros(nd)for d in range(nd):R = class_recs[image_ids[d]]bb = BB[d, :].astype(float)ovmax = -np.infBBGT = R['bbox'].astype(float)if BBGT.size > 0:# compute overlaps# intersectionixmin = np.maximum(BBGT[:, 0], bb[0])iymin = np.maximum(BBGT[:, 1], bb[1])ixmax = np.minimum(BBGT[:, 2], bb[2])iymax = np.minimum(BBGT[:, 3], bb[3])iw = 
np.maximum(ixmax - ixmin + 1., 0.)ih = np.maximum(iymax - iymin + 1., 0.)inters = iw * ih# unionuni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +(BBGT[:, 2] - BBGT[:, 0] + 1.) *(BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)overlaps = inters / uniovmax = np.max(overlaps)## if there exist 2jmax = np.argmax(overlaps)if ovmax > ovthresh:if not R['difficult'][jmax]:if not R['det'][jmax]:tp[d] = 1.R['det'][jmax] = 1else:fp[d] = 1.# print('filename:', image_ids[d])else:fp[d] = 1.# compute precision recallprint('check fp:', fp)print('check tp', tp)print('npos num:', npos)fp = np.cumsum(fp)tp = np.cumsum(tp)rec = tp / float(npos)# avoid divide by zero in case the first detection matches a difficult# ground truthprec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)ap = voc_ap(rec, prec, use_07_metric)return rec, prec, apdef main():# detpath = r'E:\documentation\OneDrive\documentation\DotaEvaluation\evluation_task2\evluation_task2\faster-rcnn-nms_0.3_task2\nms_0.3_task\Task2_{:s}.txt'# annopath = r'I:\dota\testset\ReclabelTxt-utf-8\{:s}.txt'# imagesetfile = r'I:\dota\testset\va.txt'detpath = r'H:\DOTA\Raw_DOTA\evaluate_val_with_val\Task2_val_dt\Task2_{:s}.txt'annopath = r'H:\DOTA\Raw_DOTA\evaluate_val_with_val\Task2_val_gt\{:s}.txt'# change the directory to the path of val/labelTxt, if you want to do evaluation on the valsetimagesetfile = r'H:\DOTA\Raw_DOTA\evaluate_val_with_val\Task2_val_images\val_bridge_image.txt'classnames = ['plane', 'baseball-diamond', 'bridge', 'ground-track-field', 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court','basketball-court', 'storage-tank',  'soccer-ball-field', 'roundabout', 'harbor', 'swimming-pool', 'helicopter']classaps = []map = 0for classname in classnames:print('classname:', classname)rec, prec, ap = voc_eval(detpath,annopath,imagesetfile,classname,ovthresh=0.5,use_07_metric=True)map = map + ap#print('rec: ', rec, 'prec: ', prec, 'ap: ', ap)print('ap: ', ap)classaps.append(ap)## uncomment to plot p-r curve for each category# 
plt.figure(figsize=(8,4))# plt.xlabel('recall')# plt.ylabel('precision')# plt.plot(rec, prec)# plt.show()map = map/len(classnames)print('map:', map)classaps = 100*np.array(classaps)print('classaps: ', classaps)
if __name__ == '__main__':main()

需要改动的地方主要有这三个:detpath(检测结果文件路径)、annopath(标注文件路径)和 imagesetfile(图像名称列表文件),详情参考代码中的注释。