Replace the document detection model

paddle_detection/deploy/end2end_ppyoloe/README.md (new file)
@@ -0,0 +1,99 @@

# Export ONNX Model

## Download pretrained Paddle models

* [ppyoloe-s](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_300e_coco.pdparams)
* [ppyoloe-m](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_m_300e_coco.pdparams)
* [ppyoloe-l](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams)
* [ppyoloe-x](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_x_300e_coco.pdparams)
* [ppyoloe-s-400e](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_400e_coco.pdparams)
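
If you prefer to script the download, here is a minimal sketch using `requests` (which the inference scripts in this directory already depend on); the file name and URL are taken from the ppyoloe-s entry above:

```python
# Download one of the pretrained weight files listed above.
import requests

url = 'https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_300e_coco.pdparams'
resp = requests.get(url, timeout=60)
resp.raise_for_status()
with open('ppyoloe_crn_s_300e_coco.pdparams', 'wb') as f:
    f.write(resp.content)
```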

## Export the Paddle model for deployment

```shell
python ./tools/export_model.py \
    -c configs/ppyoloe/ppyoloe_crn_s_300e_coco.yml \
    -o weights=ppyoloe_crn_s_300e_coco.pdparams \
       trt=True \
       exclude_nms=True \
       TestReader.inputs_def.image_shape=[3,640,640] \
    --output_dir ./

# if you want to try the ppyoloe-s-400e model
python ./tools/export_model.py \
    -c configs/ppyoloe/ppyoloe_crn_s_400e_coco.yml \
    -o weights=ppyoloe_crn_s_400e_coco.pdparams \
       trt=True \
       exclude_nms=True \
       TestReader.inputs_def.image_shape=[3,640,640] \
    --output_dir ./
```
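
The export step writes a static-graph model directory. A small sanity check, assuming the default file names that `end2end.py` later expects (`model.pdmodel`, `model.pdiparams`):

```python
# Verify the exported directory before converting it to ONNX.
from pathlib import Path

model_dir = Path('ppyoloe_crn_s_300e_coco')
for name in ('model.pdmodel', 'model.pdiparams'):
    assert (model_dir / name).is_file(), f'missing {name} in {model_dir}'
print('export looks complete:', sorted(p.name for p in model_dir.iterdir()))
```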

## Check requirements

```shell
pip install "onnx>=1.10.0"
pip install paddle2onnx
pip install onnx-simplifier
pip install onnx-graphsurgeon --index-url https://pypi.ngc.nvidia.com
# if you use the cuda-python inference script, install it
pip install cuda-python
# if you use the cupy inference script, install it
pip install cupy-cuda117  # cupy-cuda110 through cupy-cuda117 are all available
```
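
A quick way to confirm the Python packages are importable (version attributes may vary by distribution, so this is only a rough check):

```python
# Print the versions of the packages used by the export and inference scripts.
import onnx, paddle2onnx, onnxsim, onnx_graphsurgeon, tensorrt

for mod in (onnx, paddle2onnx, onnxsim, onnx_graphsurgeon, tensorrt):
    print(mod.__name__, getattr(mod, '__version__', 'unknown'))
```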

## Export script

```shell
python ./deploy/end2end_ppyoloe/end2end.py \
    --model-dir ppyoloe_crn_s_300e_coco \
    --save-file ppyoloe_crn_s_300e_coco.onnx \
    --opset 11 \
    --batch-size 1 \
    --topk-all 100 \
    --iou-thres 0.6 \
    --conf-thres 0.4

# if you want to try the ppyoloe-s-400e model
python ./deploy/end2end_ppyoloe/end2end.py \
    --model-dir ppyoloe_crn_s_400e_coco \
    --save-file ppyoloe_crn_s_400e_coco.onnx \
    --opset 11 \
    --batch-size 1 \
    --topk-all 100 \
    --iou-thres 0.6 \
    --conf-thres 0.4
```

#### Description of all arguments

- `--model-dir` : path of the exported PP-YOLOE model directory.
- `--save-file` : path to save the exported ONNX model.
- `--opset` : ONNX opset version.
- `--img-size` : image size used when exporting PP-YOLOE.
- `--batch-size` : batch size used when exporting PP-YOLOE.
- `--topk-all` : maximum number of kept objects per image.
- `--iou-thres` : IoU threshold for the NMS algorithm.
- `--conf-thres` : confidence threshold for the NMS algorithm.
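
After export, the ONNX graph should expose a single `image` input and the four outputs that `end2end.py` attaches via `EfficientNMS_TRT`; a minimal inspection sketch (file name taken from the command above):

```python
# Inspect the exported model's inputs and outputs.
import onnx

model = onnx.load('ppyoloe_crn_s_300e_coco.onnx')
print('inputs :', [i.name for i in model.graph.input])   # expected: ['image']
print('outputs:', [o.name for o in model.graph.output])  # expected: num_dets, det_boxes, det_scores, det_classes
```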

### TensorRT backend (TensorRT version >= 8.0.0)

#### TensorRT engine export

```shell
/path/to/trtexec \
    --onnx=ppyoloe_crn_s_300e_coco.onnx \
    --saveEngine=ppyoloe_crn_s_300e_coco.engine \
    --fp16  # optional: build a TensorRT FP16 engine

# if you want to try the ppyoloe-s-400e model
/path/to/trtexec \
    --onnx=ppyoloe_crn_s_400e_coco.onnx \
    --saveEngine=ppyoloe_crn_s_400e_coco.engine \
    --fp16  # optional: build a TensorRT FP16 engine
```
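
If you would rather build the engine from Python instead of `trtexec`, here is a hedged sketch using the TensorRT 8 builder API (file names assumed from the commands above; `--fp16` becomes the `FP16` builder flag):

```python
# Build a TensorRT engine from the exported ONNX model (TensorRT >= 8.0).
import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
trt.init_libnvinfer_plugins(logger, namespace='')  # registers EfficientNMS_TRT

builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
with open('ppyoloe_crn_s_300e_coco.onnx', 'rb') as f:
    assert parser.parse(f.read()), parser.get_error(0)

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)  # optional, mirrors trtexec --fp16
engine = builder.build_serialized_network(network, config)
assert engine is not None, 'engine build failed'
with open('ppyoloe_crn_s_300e_coco.engine', 'wb') as f:
    f.write(engine)
```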

#### TensorRT image inference

```shell
# cuda-python inference script
python ./deploy/end2end_ppyoloe/cuda-python.py ppyoloe_crn_s_300e_coco.engine
# cupy inference script
python ./deploy/end2end_ppyoloe/cupy-python.py ppyoloe_crn_s_300e_coco.engine

# if you want to try the ppyoloe-s-400e model
python ./deploy/end2end_ppyoloe/cuda-python.py ppyoloe_crn_s_400e_coco.engine
# or
python ./deploy/end2end_ppyoloe/cupy-python.py ppyoloe_crn_s_400e_coco.engine
```
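
Both inference scripts draw detections on the 640x640 letterboxed frame. If you need boxes in the original image's coordinate system, a minimal sketch that undoes the letterbox using the `ratio` and `(dw, dh)` values the scripts already compute (this helper is illustrative and not part of the scripts):

```python
import numpy as np

def boxes_to_original(boxes, ratio, dwdh):
    """Map [x1, y1, x2, y2] boxes from the letterboxed frame back to the source image."""
    boxes = np.asarray(boxes, dtype=np.float32).copy()
    dw, dh = dwdh
    boxes[:, [0, 2]] -= dw  # remove horizontal padding
    boxes[:, [1, 3]] -= dh  # remove vertical padding
    boxes /= ratio          # undo the resize
    return boxes
```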

paddle_detection/deploy/end2end_ppyoloe/cuda-python.py (new file)
@@ -0,0 +1,161 @@
import sys
import requests
import cv2
import random
import time
import numpy as np
import tensorrt as trt
from cuda import cudart
from pathlib import Path
from collections import OrderedDict, namedtuple


def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleup=True, stride=32):
    # Resize and pad image while meeting stride-multiple constraints
    shape = im.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better val mAP)
        r = min(r, 1.0)

    # Compute padding
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding

    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return im, r, (dw, dh)


w = Path(sys.argv[1])

assert w.exists() and w.suffix in ('.engine', '.plan'), 'Wrong engine path'

names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
         'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
         'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
         'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
         'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
         'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
         'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
         'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
         'hair drier', 'toothbrush']
colors = {name: [random.randint(0, 255) for _ in range(3)] for i, name in enumerate(names)}

url = 'https://oneflow-static.oss-cn-beijing.aliyuncs.com/tripleMu/image1.jpg'
file = requests.get(url)
img = cv2.imdecode(np.frombuffer(file.content, np.uint8), 1)

_, stream = cudart.cudaStreamCreate()

mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 3, 1, 1)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 3, 1, 1)

# Infer TensorRT Engine
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
logger = trt.Logger(trt.Logger.ERROR)
trt.init_libnvinfer_plugins(logger, namespace="")
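# init_libnvinfer_plugins registers TensorRT's built-in plugins, including the
# EfficientNMS_TRT op that end2end.py appended to the graph; without it the
# engine below could not be deserialized.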
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
    model = runtime.deserialize_cuda_engine(f.read())
bindings = OrderedDict()
fp16 = False  # default updated below
for index in range(model.num_bindings):
    name = model.get_binding_name(index)
    dtype = trt.nptype(model.get_binding_dtype(index))
    shape = tuple(model.get_binding_shape(index))
    data = np.empty(shape, dtype=np.dtype(dtype))
    _, data_ptr = cudart.cudaMallocAsync(data.nbytes, stream)
    bindings[name] = Binding(name, dtype, shape, data, data_ptr)
    if model.binding_is_input(index) and dtype == np.float16:
        fp16 = True
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
context = model.create_execution_context()

image = img.copy()
image, ratio, dwdh = letterbox(image, auto=False)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

image_copy = image.copy()

image = image.transpose((2, 0, 1))
image = np.expand_dims(image, 0)
image = np.ascontiguousarray(image)

im = image.astype(np.float32)
im /= 255
im -= mean
im /= std

_, image_ptr = cudart.cudaMallocAsync(im.nbytes, stream)
cudart.cudaMemcpyAsync(image_ptr, im.ctypes.data, im.nbytes,
                       cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)

# warmup for 10 times
for _ in range(10):
    tmp = np.random.randn(1, 3, 640, 640).astype(np.float32)
    _, tmp_ptr = cudart.cudaMallocAsync(tmp.nbytes, stream)
    binding_addrs['image'] = tmp_ptr
    context.execute_v2(list(binding_addrs.values()))

start = time.perf_counter()
binding_addrs['image'] = image_ptr
context.execute_v2(list(binding_addrs.values()))
print(f'Cost {(time.perf_counter() - start) * 1000}ms')
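
# EfficientNMS_TRT outputs (shapes set in end2end.py): num_dets [batch, 1],
# det_boxes [batch, topk, 4], det_scores [batch, topk], det_classes [batch, topk].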
nums = bindings['num_dets'].data
boxes = bindings['det_boxes'].data
scores = bindings['det_scores'].data
classes = bindings['det_classes'].data

cudart.cudaMemcpyAsync(nums.ctypes.data,
                       bindings['num_dets'].ptr,
                       nums.nbytes,
                       cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost,
                       stream)
cudart.cudaMemcpyAsync(boxes.ctypes.data,
                       bindings['det_boxes'].ptr,
                       boxes.nbytes,
                       cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost,
                       stream)
cudart.cudaMemcpyAsync(scores.ctypes.data,
                       bindings['det_scores'].ptr,
                       scores.nbytes,
                       cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost,
                       stream)
cudart.cudaMemcpyAsync(classes.ctypes.data,
                       bindings['det_classes'].ptr,
                       classes.nbytes,
                       cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost,
                       stream)

cudart.cudaStreamSynchronize(stream)
cudart.cudaStreamDestroy(stream)

for i in binding_addrs.values():
    cudart.cudaFree(i)

num = int(nums[0][0])
box_img = boxes[0, :num].round().astype(np.int32)
score_img = scores[0, :num]
clss_img = classes[0, :num]
for i, (box, score, clss) in enumerate(zip(box_img, score_img, clss_img)):
    name = names[int(clss)]
    color = colors[name]
    cv2.rectangle(image_copy, box[:2].tolist(), box[2:].tolist(), color, 2)
    cv2.putText(image_copy, name, (int(box[0]), int(box[1]) - 2), cv2.FONT_HERSHEY_SIMPLEX,
                0.75, [225, 255, 255], thickness=2)

cv2.imshow('Result', cv2.cvtColor(image_copy, cv2.COLOR_RGB2BGR))
cv2.waitKey(0)

paddle_detection/deploy/end2end_ppyoloe/cupy-python.py (new file)
@@ -0,0 +1,131 @@
import sys
import requests
import cv2
import random
import time
import numpy as np
import cupy as cp
import tensorrt as trt
from PIL import Image
from collections import OrderedDict, namedtuple
from pathlib import Path


def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleup=True, stride=32):
    # Resize and pad image while meeting stride-multiple constraints
    shape = im.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better val mAP)
        r = min(r, 1.0)

    # Compute padding
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding

    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return im, r, (dw, dh)


names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
         'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
         'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
         'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
         'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
         'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
         'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
         'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
         'hair drier', 'toothbrush']
colors = {name: [random.randint(0, 255) for _ in range(3)] for i, name in enumerate(names)}

url = 'https://oneflow-static.oss-cn-beijing.aliyuncs.com/tripleMu/image1.jpg'
file = requests.get(url)
img = cv2.imdecode(np.frombuffer(file.content, np.uint8), 1)

w = Path(sys.argv[1])

assert w.exists() and w.suffix in ('.engine', '.plan'), 'Wrong engine path'

mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 3, 1, 1)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 3, 1, 1)

mean = cp.asarray(mean)
std = cp.asarray(std)

# Infer TensorRT Engine
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
logger = trt.Logger(trt.Logger.INFO)
trt.init_libnvinfer_plugins(logger, namespace="")
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
    model = runtime.deserialize_cuda_engine(f.read())
bindings = OrderedDict()
fp16 = False  # default updated below
for index in range(model.num_bindings):
    name = model.get_binding_name(index)
    dtype = trt.nptype(model.get_binding_dtype(index))
    shape = tuple(model.get_binding_shape(index))
    data = cp.empty(shape, dtype=cp.dtype(dtype))
    bindings[name] = Binding(name, dtype, shape, data, int(data.data.ptr))
    if model.binding_is_input(index) and dtype == np.float16:
        fp16 = True
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
context = model.create_execution_context()

image = img.copy()
image, ratio, dwdh = letterbox(image, auto=False)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

image_copy = image.copy()

image = image.transpose((2, 0, 1))
image = np.expand_dims(image, 0)
image = np.ascontiguousarray(image)

im = cp.asarray(image)
im = im.astype(cp.float32)
im /= 255
im -= mean
im /= std

# warmup for 10 times
for _ in range(10):
    tmp = cp.random.randn(1, 3, 640, 640).astype(cp.float32)
    binding_addrs['image'] = int(tmp.data.ptr)
    context.execute_v2(list(binding_addrs.values()))

start = time.perf_counter()
binding_addrs['image'] = int(im.data.ptr)
context.execute_v2(list(binding_addrs.values()))
print(f'Cost {(time.perf_counter() - start) * 1000}ms')

nums = bindings['num_dets'].data
boxes = bindings['det_boxes'].data
scores = bindings['det_scores'].data
classes = bindings['det_classes'].data
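
# The outputs above are CuPy arrays resident on the GPU, so no explicit memcpy
# is needed; converting them to Python ints / lists below copies the values to
# the host.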
num = int(nums[0][0])
box_img = boxes[0, :num].round().astype(cp.int32)
score_img = scores[0, :num]
clss_img = classes[0, :num]
for i, (box, score, clss) in enumerate(zip(box_img, score_img, clss_img)):
    name = names[int(clss)]
    color = colors[name]
    cv2.rectangle(image_copy, box[:2].tolist(), box[2:].tolist(), color, 2)
    cv2.putText(image_copy, name, (int(box[0]), int(box[1]) - 2), cv2.FONT_HERSHEY_SIMPLEX,
                0.75, [225, 255, 255], thickness=2)

cv2.imshow('Result', cv2.cvtColor(image_copy, cv2.COLOR_RGB2BGR))
cv2.waitKey(0)

paddle_detection/deploy/end2end_ppyoloe/end2end.py (new file)
@@ -0,0 +1,97 @@
import argparse
import onnx
import onnx_graphsurgeon as gs
import numpy as np

from pathlib import Path
from paddle2onnx.legacy.command import program2onnx
from collections import OrderedDict


def main(opt):
    model_dir = Path(opt.model_dir)
    save_file = Path(opt.save_file)
    assert model_dir.exists() and model_dir.is_dir()
    if save_file.is_dir():
        save_file = (save_file / model_dir.stem).with_suffix('.onnx')
    elif save_file.is_file() and save_file.suffix != '.onnx':
        save_file = save_file.with_suffix('.onnx')
    input_shape_dict = {'image': [opt.batch_size, 3, *opt.img_size],
                        'scale_factor': [opt.batch_size, 2]}
    program2onnx(str(model_dir), str(save_file),
                 'model.pdmodel', 'model.pdiparams',
                 opt.opset, input_shape_dict=input_shape_dict)
    onnx_model = onnx.load(save_file)
    try:
        import onnxsim
        onnx_model, check = onnxsim.simplify(onnx_model)
        assert check, 'assert check failed'
    except Exception as e:
        print(f'Simplifier failure: {e}')
    onnx.checker.check_model(onnx_model)
    graph = gs.import_onnx(onnx_model)
    graph.fold_constants()
    graph.cleanup().toposort()
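
    # Locate the PP-YOLOE head outputs in the simplified graph: the Mul feeding a
    # Div yields the decoded boxes, and the Concat feeding Reshape -> ReduceSum
    # yields the raw class scores, transposed to [batch, anchors, classes] below.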
    mul = concat = None
    for node in graph.nodes:
        if node.op == 'Div' and node.i(0).op == 'Mul':
            mul = node.i(0)
        if node.op == 'Concat' and node.o().op == 'Reshape' and node.o().o().op == 'ReduceSum':
            concat = node

    assert mul.outputs[0].shape[1] == concat.outputs[0].shape[2], 'Something wrong in outputs shape'

    anchors = mul.outputs[0].shape[1]
    classes = concat.outputs[0].shape[1]

    scores = gs.Variable(name='scores', shape=[opt.batch_size, anchors, classes], dtype=np.float32)
    graph.layer(op='Transpose', name='lastTranspose',
                inputs=[concat.outputs[0]],
                outputs=[scores],
                attrs=OrderedDict(perm=[0, 2, 1]))

    graph.inputs = [graph.inputs[0]]

    attrs = OrderedDict(
        plugin_version="1",
        background_class=-1,
        max_output_boxes=opt.topk_all,
        score_threshold=opt.conf_thres,
        iou_threshold=opt.iou_thres,
        score_activation=False,
        box_coding=0, )
    outputs = [gs.Variable("num_dets", np.int32, [opt.batch_size, 1]),
               gs.Variable("det_boxes", np.float32, [opt.batch_size, opt.topk_all, 4]),
               gs.Variable("det_scores", np.float32, [opt.batch_size, opt.topk_all]),
               gs.Variable("det_classes", np.int32, [opt.batch_size, opt.topk_all])]
    graph.layer(op='EfficientNMS_TRT', name="batched_nms",
                inputs=[mul.outputs[0], scores],
                outputs=outputs,
                attrs=attrs)
    graph.outputs = outputs
    graph.cleanup().toposort()
    onnx.save(gs.export_onnx(graph), save_file)


def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-dir', type=str,
                        default=None,
                        help='paddle static model directory')
    parser.add_argument('--save-file', type=str,
                        default=None,
                        help='onnx model save path')
    parser.add_argument('--opset', type=int, default=11, help='opset version')
    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size')
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    parser.add_argument('--topk-all', type=int, default=100, help='topk objects for every image')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='iou threshold for NMS')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='conf threshold for NMS')
    opt = parser.parse_args()
    opt.img_size *= 2 if len(opt.img_size) == 1 else 1
    return opt


if __name__ == '__main__':
    opt = parse_opt()
    main(opt)