更换文档检测模型
This commit is contained in:
266
paddle_detection/configs/rotate/tools/generate_result.py
Normal file
266
paddle_detection/configs/rotate/tools/generate_result.py
Normal file
@@ -0,0 +1,266 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import re
|
||||
import glob
|
||||
|
||||
import numpy as np
|
||||
from multiprocessing import Pool
|
||||
from functools import partial
|
||||
from shapely.geometry import Polygon
|
||||
import argparse
|
||||
|
||||
wordname_15 = [
|
||||
'plane', 'baseball-diamond', 'bridge', 'ground-track-field',
|
||||
'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
|
||||
'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout',
|
||||
'harbor', 'swimming-pool', 'helicopter'
|
||||
]
|
||||
|
||||
wordname_16 = wordname_15 + ['container-crane']
|
||||
|
||||
wordname_18 = wordname_16 + ['airport', 'helipad']
|
||||
|
||||
DATA_CLASSES = {
|
||||
'dota10': wordname_15,
|
||||
'dota15': wordname_16,
|
||||
'dota20': wordname_18
|
||||
}
|
||||
|
||||
|
||||
def rbox_iou(g, p):
|
||||
"""
|
||||
iou of rbox
|
||||
"""
|
||||
g = np.array(g)
|
||||
p = np.array(p)
|
||||
g = Polygon(g[:8].reshape((4, 2)))
|
||||
p = Polygon(p[:8].reshape((4, 2)))
|
||||
g = g.buffer(0)
|
||||
p = p.buffer(0)
|
||||
if not g.is_valid or not p.is_valid:
|
||||
return 0
|
||||
inter = Polygon(g).intersection(Polygon(p)).area
|
||||
union = g.area + p.area - inter
|
||||
if union == 0:
|
||||
return 0
|
||||
else:
|
||||
return inter / union
|
||||
|
||||
|
||||
def py_cpu_nms_poly_fast(dets, thresh):
|
||||
"""
|
||||
Args:
|
||||
dets: pred results
|
||||
thresh: nms threshold
|
||||
|
||||
Returns: index of keep
|
||||
"""
|
||||
obbs = dets[:, 0:-1]
|
||||
x1 = np.min(obbs[:, 0::2], axis=1)
|
||||
y1 = np.min(obbs[:, 1::2], axis=1)
|
||||
x2 = np.max(obbs[:, 0::2], axis=1)
|
||||
y2 = np.max(obbs[:, 1::2], axis=1)
|
||||
scores = dets[:, 8]
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
|
||||
polys = []
|
||||
for i in range(len(dets)):
|
||||
tm_polygon = [
|
||||
dets[i][0], dets[i][1], dets[i][2], dets[i][3], dets[i][4],
|
||||
dets[i][5], dets[i][6], dets[i][7]
|
||||
]
|
||||
polys.append(tm_polygon)
|
||||
polys = np.array(polys)
|
||||
order = scores.argsort()[::-1]
|
||||
|
||||
keep = []
|
||||
while order.size > 0:
|
||||
ovr = []
|
||||
i = order[0]
|
||||
keep.append(i)
|
||||
|
||||
xx1 = np.maximum(x1[i], x1[order[1:]])
|
||||
yy1 = np.maximum(y1[i], y1[order[1:]])
|
||||
xx2 = np.minimum(x2[i], x2[order[1:]])
|
||||
yy2 = np.minimum(y2[i], y2[order[1:]])
|
||||
w = np.maximum(0.0, xx2 - xx1)
|
||||
h = np.maximum(0.0, yy2 - yy1)
|
||||
hbb_inter = w * h
|
||||
hbb_ovr = hbb_inter / (areas[i] + areas[order[1:]] - hbb_inter)
|
||||
h_inds = np.where(hbb_ovr > 0)[0]
|
||||
tmp_order = order[h_inds + 1]
|
||||
for j in range(tmp_order.size):
|
||||
iou = rbox_iou(polys[i], polys[tmp_order[j]])
|
||||
hbb_ovr[h_inds[j]] = iou
|
||||
|
||||
try:
|
||||
if math.isnan(ovr[0]):
|
||||
pdb.set_trace()
|
||||
except:
|
||||
pass
|
||||
inds = np.where(hbb_ovr <= thresh)[0]
|
||||
|
||||
order = order[inds + 1]
|
||||
return keep
|
||||
|
||||
|
||||
def poly2origpoly(poly, x, y, rate):
|
||||
origpoly = []
|
||||
for i in range(int(len(poly) / 2)):
|
||||
tmp_x = float(poly[i * 2] + x) / float(rate)
|
||||
tmp_y = float(poly[i * 2 + 1] + y) / float(rate)
|
||||
origpoly.append(tmp_x)
|
||||
origpoly.append(tmp_y)
|
||||
return origpoly
|
||||
|
||||
|
||||
def nmsbynamedict(nameboxdict, nms, thresh):
|
||||
"""
|
||||
Args:
|
||||
nameboxdict: nameboxdict
|
||||
nms: nms
|
||||
thresh: nms threshold
|
||||
|
||||
Returns: nms result as dict
|
||||
"""
|
||||
nameboxnmsdict = {x: [] for x in nameboxdict}
|
||||
for imgname in nameboxdict:
|
||||
keep = nms(np.array(nameboxdict[imgname]), thresh)
|
||||
outdets = []
|
||||
for index in keep:
|
||||
outdets.append(nameboxdict[imgname][index])
|
||||
nameboxnmsdict[imgname] = outdets
|
||||
return nameboxnmsdict
|
||||
|
||||
|
||||
def merge_single(output_dir, nms, nms_thresh, pred_class_lst):
|
||||
"""
|
||||
Args:
|
||||
output_dir: output_dir
|
||||
nms: nms
|
||||
pred_class_lst: pred_class_lst
|
||||
class_name: class_name
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
class_name, pred_bbox_list = pred_class_lst
|
||||
nameboxdict = {}
|
||||
for line in pred_bbox_list:
|
||||
splitline = line.split(' ')
|
||||
subname = splitline[0]
|
||||
splitname = subname.split('__')
|
||||
oriname = splitname[0]
|
||||
pattern1 = re.compile(r'__\d+___\d+')
|
||||
x_y = re.findall(pattern1, subname)
|
||||
x_y_2 = re.findall(r'\d+', x_y[0])
|
||||
x, y = int(x_y_2[0]), int(x_y_2[1])
|
||||
|
||||
pattern2 = re.compile(r'__([\d+\.]+)__\d+___')
|
||||
|
||||
rate = re.findall(pattern2, subname)[0]
|
||||
|
||||
confidence = splitline[1]
|
||||
poly = list(map(float, splitline[2:]))
|
||||
origpoly = poly2origpoly(poly, x, y, rate)
|
||||
det = origpoly
|
||||
det.append(confidence)
|
||||
det = list(map(float, det))
|
||||
if (oriname not in nameboxdict):
|
||||
nameboxdict[oriname] = []
|
||||
nameboxdict[oriname].append(det)
|
||||
nameboxnmsdict = nmsbynamedict(nameboxdict, nms, nms_thresh)
|
||||
|
||||
# write result
|
||||
dstname = os.path.join(output_dir, class_name + '.txt')
|
||||
with open(dstname, 'w') as f_out:
|
||||
for imgname in nameboxnmsdict:
|
||||
for det in nameboxnmsdict[imgname]:
|
||||
confidence = det[-1]
|
||||
bbox = det[0:-1]
|
||||
outline = imgname + ' ' + str(confidence) + ' ' + ' '.join(
|
||||
map(str, bbox))
|
||||
f_out.write(outline + '\n')
|
||||
|
||||
|
||||
def generate_result(pred_txt_dir,
|
||||
output_dir='output',
|
||||
class_names=wordname_15,
|
||||
nms_thresh=0.1):
|
||||
"""
|
||||
pred_txt_dir: dir of pred txt
|
||||
output_dir: dir of output
|
||||
class_names: class names of data
|
||||
"""
|
||||
pred_txt_list = glob.glob("{}/*.txt".format(pred_txt_dir))
|
||||
|
||||
# step1: summary pred bbox
|
||||
pred_classes = {}
|
||||
for class_name in class_names:
|
||||
pred_classes[class_name] = []
|
||||
|
||||
for current_txt in pred_txt_list:
|
||||
img_id = os.path.split(current_txt)[1]
|
||||
img_id = img_id.split('.txt')[0]
|
||||
with open(current_txt) as f:
|
||||
res = f.readlines()
|
||||
for item in res:
|
||||
item = item.split(' ')
|
||||
pred_class = item[0]
|
||||
item[0] = img_id
|
||||
pred_bbox = ' '.join(item)
|
||||
pred_classes[pred_class].append(pred_bbox)
|
||||
|
||||
pred_classes_lst = []
|
||||
for class_name in pred_classes.keys():
|
||||
print('class_name: {}, count: {}'.format(class_name,
|
||||
len(pred_classes[class_name])))
|
||||
pred_classes_lst.append((class_name, pred_classes[class_name]))
|
||||
|
||||
# step2: merge
|
||||
pool = Pool(len(class_names))
|
||||
nms = py_cpu_nms_poly_fast
|
||||
mergesingle_fn = partial(merge_single, output_dir, nms, nms_thresh)
|
||||
pool.map(mergesingle_fn, pred_classes_lst)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='generate test results')
|
||||
parser.add_argument('--pred_txt_dir', type=str, help='path of pred txt dir')
|
||||
parser.add_argument(
|
||||
'--output_dir', type=str, default='output', help='path of output dir')
|
||||
parser.add_argument(
|
||||
'--data_type', type=str, default='dota10', help='data type')
|
||||
parser.add_argument(
|
||||
'--nms_thresh',
|
||||
type=float,
|
||||
default=0.1,
|
||||
help='nms threshold while merging results')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
|
||||
output_dir = args.output_dir
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
class_names = DATA_CLASSES[args.data_type]
|
||||
|
||||
generate_result(args.pred_txt_dir, output_dir, class_names)
|
||||
print('done!')
|
||||
Reference in New Issue
Block a user