更换文档检测模型
This commit is contained in:
97
paddle_detection/deploy/end2end_ppyoloe/end2end.py
Normal file
97
paddle_detection/deploy/end2end_ppyoloe/end2end.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import argparse
|
||||
import onnx
|
||||
import onnx_graphsurgeon as gs
|
||||
import numpy as np
|
||||
|
||||
from pathlib import Path
|
||||
from paddle2onnx.legacy.command import program2onnx
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def main(opt):
|
||||
model_dir = Path(opt.model_dir)
|
||||
save_file = Path(opt.save_file)
|
||||
assert model_dir.exists() and model_dir.is_dir()
|
||||
if save_file.is_dir():
|
||||
save_file = (save_file / model_dir.stem).with_suffix('.onnx')
|
||||
elif save_file.is_file() and save_file.suffix != '.onnx':
|
||||
save_file = save_file.with_suffix('.onnx')
|
||||
input_shape_dict = {'image': [opt.batch_size, 3, *opt.img_size],
|
||||
'scale_factor': [opt.batch_size, 2]}
|
||||
program2onnx(str(model_dir), str(save_file),
|
||||
'model.pdmodel', 'model.pdiparams',
|
||||
opt.opset, input_shape_dict=input_shape_dict)
|
||||
onnx_model = onnx.load(save_file)
|
||||
try:
|
||||
import onnxsim
|
||||
onnx_model, check = onnxsim.simplify(onnx_model)
|
||||
assert check, 'assert check failed'
|
||||
except Exception as e:
|
||||
print(f'Simplifier failure: {e}')
|
||||
onnx.checker.check_model(onnx_model)
|
||||
graph = gs.import_onnx(onnx_model)
|
||||
graph.fold_constants()
|
||||
graph.cleanup().toposort()
|
||||
mul = concat = None
|
||||
for node in graph.nodes:
|
||||
if node.op == 'Div' and node.i(0).op == 'Mul':
|
||||
mul = node.i(0)
|
||||
if node.op == 'Concat' and node.o().op == 'Reshape' and node.o().o().op == 'ReduceSum':
|
||||
concat = node
|
||||
|
||||
assert mul.outputs[0].shape[1] == concat.outputs[0].shape[2], 'Something wrong in outputs shape'
|
||||
|
||||
anchors = mul.outputs[0].shape[1]
|
||||
classes = concat.outputs[0].shape[1]
|
||||
|
||||
scores = gs.Variable(name='scores', shape=[opt.batch_size, anchors, classes], dtype=np.float32)
|
||||
graph.layer(op='Transpose', name='lastTranspose',
|
||||
inputs=[concat.outputs[0]],
|
||||
outputs=[scores],
|
||||
attrs=OrderedDict(perm=[0, 2, 1]))
|
||||
|
||||
graph.inputs = [graph.inputs[0]]
|
||||
|
||||
attrs = OrderedDict(
|
||||
plugin_version="1",
|
||||
background_class=-1,
|
||||
max_output_boxes=opt.topk_all,
|
||||
score_threshold=opt.conf_thres,
|
||||
iou_threshold=opt.iou_thres,
|
||||
score_activation=False,
|
||||
box_coding=0, )
|
||||
outputs = [gs.Variable("num_dets", np.int32, [opt.batch_size, 1]),
|
||||
gs.Variable("det_boxes", np.float32, [opt.batch_size, opt.topk_all, 4]),
|
||||
gs.Variable("det_scores", np.float32, [opt.batch_size, opt.topk_all]),
|
||||
gs.Variable("det_classes", np.int32, [opt.batch_size, opt.topk_all])]
|
||||
graph.layer(op='EfficientNMS_TRT', name="batched_nms",
|
||||
inputs=[mul.outputs[0], scores],
|
||||
outputs=outputs,
|
||||
attrs=attrs)
|
||||
graph.outputs = outputs
|
||||
graph.cleanup().toposort()
|
||||
onnx.save(gs.export_onnx(graph), save_file)
|
||||
|
||||
|
||||
def parse_opt():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model-dir', type=str,
|
||||
default=None,
|
||||
help='paddle static model')
|
||||
parser.add_argument('--save-file', type=str,
|
||||
default=None,
|
||||
help='onnx model save path')
|
||||
parser.add_argument('--opset', type=int, default=11, help='opset version')
|
||||
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size')
|
||||
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
|
||||
parser.add_argument('--topk-all', type=int, default=100, help='topk objects for every images')
|
||||
parser.add_argument('--iou-thres', type=float, default=0.45, help='iou threshold for NMS')
|
||||
parser.add_argument('--conf-thres', type=float, default=0.25, help='conf threshold for NMS')
|
||||
opt = parser.parse_args()
|
||||
opt.img_size *= 2 if len(opt.img_size) == 1 else 1
|
||||
return opt
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
opt = parse_opt()
|
||||
main(opt)
|
||||
Reference in New Issue
Block a user