374 lines
15 KiB
Python
374 lines
15 KiB
Python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import os
|
|
import yaml
|
|
from collections import OrderedDict
|
|
|
|
import paddle
|
|
from ppdet.data.source.category import get_categories
|
|
|
|
from ppdet.utils.logger import setup_logger
|
|
# Module-level logger shared by the export helpers in this file
# (used for info/warning/error messages during to_static and config export).
logger = setup_logger('ppdet.engine')
|
|
|
|
# TensorRT ``min_subgraph_size`` exported for each supported architecture.
#
# Iteration order is significant: ``_dump_infer_config`` walks these entries
# in order and takes the FIRST key that is a substring of the configured
# architecture name (so e.g. 'YOLO' is matched before 'YOLOX').
TRT_MIN_SUBGRAPH = dict(
    YOLO=3,
    PPYOLOE=3,
    SSD=60,
    RCNN=40,
    RetinaNet=40,
    S2ANet=80,
    EfficientDet=40,
    Face=3,
    TTFNet=60,
    FCOS=16,
    SOLOv2=60,
    HigherHRNet=3,
    HRNet=3,
    DeepSORT=3,
    ByteTrack=10,
    CenterTrack=5,
    JDE=10,
    FairMOT=5,
    GFL=16,
    PicoDet=3,
    CenterNet=5,
    TOOD=5,
    YOLOX=8,
    YOLOF=40,
    METRO_Body=3,
    DETR=3,
    CLRNet=3,
)

# Architectures exported with label_arch 'keypoint_arch'.
KEYPOINT_ARCH = [
    'HigherHRNet',
    'TopDownHRNet',
]

# Multi-object-tracking architectures: a tracker section is exported and,
# unless the metric is COCO/VOC, the MOT test reader/dataset are used.
MOT_ARCH = [
    'JDE',
    'FairMOT',
    'DeepSORT',
    'ByteTrack',
    'CenterTrack',
]

# Lane-detection architectures (label_arch 'lane_arch').
LANE_ARCH = [
    'CLRNet',
]
|
|
|
|
def _build_input_spec(fields):
    """Build a ``to_static`` input spec from ``(name, shape[, dtype])`` tuples.

    Returns a one-element list holding an insertion-ordered
    ``{name: paddle.static.InputSpec}`` dict. ``dtype`` defaults to
    ``'float32'`` when a tuple has only two items.
    """
    specs = {}
    for field in fields:
        name, shape = field[0], field[1]
        dtype = field[2] if len(field) > 2 else 'float32'
        specs[name] = paddle.static.InputSpec(
            name=name, shape=list(shape), dtype=dtype)
    return [specs]


# Input fields shared by the padded-ground-truth detector configs below
# (picodet_s_320_coco_lcnet and ppyoloe_crn_s_300e_coco).
_PAD_GT_DETECTION_FIELDS = (
    ('im_id', [-1, 1]),
    ('is_crowd', [-1, -1, 1]),
    ('gt_class', [-1, -1, 1], 'int32'),
    ('gt_bbox', [-1, -1, 4]),
    ('curr_iter', [-1]),
    ('image', [-1, 3, -1, -1]),
    ('im_shape', [-1, 2]),
    ('scale_factor', [-1, 2]),
    ('pad_gt_mask', [-1, -1, 1]),
)

# ``paddle.jit.to_static`` input specs, keyed by the config ``filename`` that
# ``apply_to_static`` looks up.
TO_STATIC_SPEC = {
    'yolov3_darknet53_270e_coco': _build_input_spec([
        ('im_id', [-1, 1]),
        ('is_crowd', [-1, 50]),
        ('gt_bbox', [-1, 50, 4]),
        ('curr_iter', [-1]),
        ('image', [-1, 3, -1, -1]),
        ('im_shape', [-1, 2]),
        ('scale_factor', [-1, 2]),
        ('target0', [-1, 3, 86, -1, -1]),
        ('target1', [-1, 3, 86, -1, -1]),
        ('target2', [-1, 3, 86, -1, -1]),
    ]),
    'tinypose_128x96': _build_input_spec([
        ('center', [-1, 2]),
        ('scale', [-1, 2]),
        ('im_id', [-1, 1]),
        ('image', [-1, 3, 128, 96]),
        ('score', [-1]),
        ('rotate', [-1]),
        ('target', [-1, 17, 32, 24]),
        ('target_weight', [-1, 17, 1]),
    ]),
    'fcos_r50_fpn_1x_coco': _build_input_spec(
        [
            ('im_id', [-1, 1]),
            ('curr_iter', [-1]),
            ('image', [-1, 3, -1, -1]),
            ('im_shape', [-1, 2]),
            ('scale_factor', [-1, 2]),
        ] + [
            # Levels 0..4 (sizes 160, 80, 40, 20, 10), each contributing
            # reg_target, labels (int32) and centerness, in that order.
            field
            for level, size in enumerate((160, 80, 40, 20, 10))
            for field in (
                ('reg_target{}'.format(level), [-1, size, size, 4]),
                ('labels{}'.format(level), [-1, size, size, 1], 'int32'),
                ('centerness{}'.format(level), [-1, size, size, 1]),
            )
        ]),
    'picodet_s_320_coco_lcnet': _build_input_spec(_PAD_GT_DETECTION_FIELDS),
    'ppyoloe_crn_s_300e_coco': _build_input_spec(_PAD_GT_DETECTION_FIELDS),
}
|
|
|
|
|
|
def apply_to_static(config, model):
    """Wrap ``model`` with ``paddle.jit.to_static`` and return the result.

    The input spec is looked up in ``TO_STATIC_SPEC`` by
    ``config['filename']``. When the key is absent or has no entry, ``None``
    is passed as ``input_spec``.
    """
    input_spec = TO_STATIC_SPEC.get(config.get('filename', None), None)
    static_model = paddle.jit.to_static(model, input_spec=input_spec)
    logger.info("Successfully to apply @to_static with specs: {}".format(
        input_spec))
    return static_model
|
|
|
|
|
|
def _prune_input_spec(input_spec, program, targets):
    """Keep only the input specs whose variables survive pruning.

    A clone of ``program`` is pruned to ``targets``. Entries of
    ``input_spec[0]`` are kept, in their original order, only when a variable
    with the same name still exists in the pruned program's global block.
    Only the first element of ``input_spec`` is considered.

    Pruning runs in static-graph mode. Dynamic mode is restored on the
    originally active device before returning, even if pruning raises.

    Args:
        input_spec (list): one-element list holding a ``{name: spec}`` dict.
        program: static program to clone and prune (not modified).
        targets: pruning targets forwarded to ``Program._prune``.

    Returns:
        list: ``[{name: spec}]`` restricted to the surviving names.
    """
    # Remember the active device so it can be re-applied in static mode and
    # restored when switching back to dynamic mode.
    device = paddle.get_device()
    paddle.enable_static()
    try:
        paddle.set_device(device)
        pruned_program = program.clone()._prune(targets=targets)
        global_block = pruned_program.global_block()
        pruned_input_spec = [{
            name: spec
            for name, spec in input_spec[0].items()
            if global_block.has_var(name)
        }]
    finally:
        # Without this, an exception during clone/prune would leave the whole
        # process stuck in static-graph mode.
        paddle.disable_static(place=device)
    return pruned_input_spec
|
|
|
|
|
|
def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
|
|
preprocess_list = []
|
|
label_list = []
|
|
if arch != "lane_arch":
|
|
anno_file = dataset_cfg.get_anno()
|
|
|
|
clsid2catid, catid2name = get_categories(metric, anno_file, arch)
|
|
|
|
label_list = [str(cat) for cat in catid2name.values()]
|
|
|
|
fuse_normalize = reader_cfg.get('fuse_normalize', False)
|
|
sample_transforms = reader_cfg['sample_transforms']
|
|
for st in sample_transforms[1:]:
|
|
for key, value in st.items():
|
|
p = {'type': key}
|
|
if key == 'Resize':
|
|
if int(image_shape[1]) != -1:
|
|
value['target_size'] = image_shape[1:]
|
|
value['interp'] = value.get('interp', 1) # cv2.INTER_LINEAR
|
|
if fuse_normalize and key == 'NormalizeImage':
|
|
continue
|
|
p.update(value)
|
|
preprocess_list.append(p)
|
|
batch_transforms = reader_cfg.get('batch_transforms', None)
|
|
if batch_transforms:
|
|
for bt in batch_transforms:
|
|
for key, value in bt.items():
|
|
# for deploy/infer, use PadStride(stride) instead PadBatch(pad_to_stride)
|
|
if key == 'PadBatch':
|
|
preprocess_list.append({
|
|
'type': 'PadStride',
|
|
'stride': value['pad_to_stride']
|
|
})
|
|
break
|
|
elif key == "CULaneResize":
|
|
# cut and resize
|
|
p = {'type': key}
|
|
p.update(value)
|
|
p.update({"cut_height": dataset_cfg.cut_height})
|
|
preprocess_list.append(p)
|
|
break
|
|
|
|
return preprocess_list, label_list
|
|
|
|
|
|
def _parse_tracker(tracker_cfg):
|
|
tracker_params = {}
|
|
for k, v in tracker_cfg.items():
|
|
tracker_params.update({k: v})
|
|
return tracker_params
|
|
|
|
|
|
def _dump_infer_config(config, path, image_shape, model):
    """Build the deployment inference config and write it to ``path`` as YAML.

    Args:
        config (dict): global config. ``architecture`` and ``metric`` are
            always read. Depending on the architecture, the following are
            also read:
            ``export_onnx`` / ``export_eb``;
            the tracker section;
            ``TestReader`` / ``TestDataset`` (or ``TestMOTReader`` /
            ``TestMOTDataset``);
            the architecture's own section;
            the head section for lane and PicoDet models.
        path (str): destination file of the YAML dump.
        image_shape (list): export input shape. ``image_shape[2] == -1``
            marks a dynamic shape, and ``image_shape[1:]`` is forwarded to
            ``_parse_reader`` (presumably ``[N, C, H, W]`` -- confirm with
            the caller).
        model: not used by this function; kept for call-site compatibility.

    Side effects:
        For PicoDet, the head's NMS ``score_threshold`` / ``nms_threshold``
        in ``config`` are overwritten. The same dict object is stored in
        ``infer_cfg['NMS']``, so the dump shows the adjusted values.
        If no ``TRT_MIN_SUBGRAPH`` key matches the architecture, an error is
        logged and the process is terminated with ``os._exit(0)``.
    """
    arch_state = False
    from ppdet.core.config.yaml_helpers import setup_orderdict
    # Presumably registers a YAML representer for OrderedDict so infer_cfg
    # is dumped in insertion order -- confirm in yaml_helpers.
    setup_orderdict()
    use_dynamic_shape = bool(image_shape[2] == -1)
    infer_cfg = OrderedDict({
        'mode': 'paddle',
        'draw_threshold': 0.5,
        'metric': config['metric'],
        'use_dynamic_shape': use_dynamic_shape
    })
    export_onnx = config.get('export_onnx', False)
    export_eb = config.get('export_eb', False)

    infer_arch = config['architecture']
    if 'RCNN' in infer_arch and export_onnx:
        logger.warning(
            "Exporting RCNN model to ONNX only support batch_size = 1")
        infer_cfg['export_onnx'] = True
        infer_cfg['export_eb'] = export_eb

    if infer_arch in MOT_ARCH:
        if infer_arch == 'DeepSORT':
            tracker_cfg = config['DeepSORTTracker']
        elif infer_arch == 'CenterTrack':
            tracker_cfg = config['CenterTracker']
        else:
            tracker_cfg = config['JDETracker']
        infer_cfg['tracker'] = _parse_tracker(tracker_cfg)

    # The first TRT_MIN_SUBGRAPH key contained in the architecture name wins.
    for arch, min_subgraph_size in TRT_MIN_SUBGRAPH.items():
        if arch in infer_arch:
            infer_cfg['arch'] = arch
            infer_cfg['min_subgraph_size'] = min_subgraph_size
            arch_state = True
            break

    if infer_arch == 'PPYOLOEWithAuxHead':
        infer_arch = 'PPYOLOE'

    # These names also contain 'YOLO', which the substring loop above matches
    # first; override with the exact entry.
    if infer_arch in ['PPYOLOE', 'YOLOX', 'YOLOF']:
        infer_cfg['arch'] = infer_arch
        infer_cfg['min_subgraph_size'] = TRT_MIN_SUBGRAPH[infer_arch]
        arch_state = True

    if not arch_state:
        logger.error(
            'Architecture: {} is not supported for exporting model now.\n'.
            format(infer_arch) +
            'Please set TRT_MIN_SUBGRAPH in ppdet/engine/export_utils.py')
        # NOTE(review): exits with status 0 on an error; kept for backward
        # compatibility, but a non-zero status would be more accurate.
        os._exit(0)
    if 'mask_head' in config[config['architecture']] and config[config[
            'architecture']]['mask_head']:
        infer_cfg['mask'] = True
    label_arch = 'detection_arch'
    if infer_arch in KEYPOINT_ARCH:
        label_arch = 'keypoint_arch'

    if infer_arch in LANE_ARCH:
        infer_cfg['arch'] = infer_arch
        infer_cfg['min_subgraph_size'] = TRT_MIN_SUBGRAPH[infer_arch]
        infer_cfg['img_w'] = config['img_w']
        infer_cfg['ori_img_h'] = config['ori_img_h']
        infer_cfg['cut_height'] = config['cut_height']
        label_arch = 'lane_arch'
        head_name = "CLRHead"
        infer_cfg['conf_threshold'] = config[head_name]['conf_threshold']
        infer_cfg['nms_thres'] = config[head_name]['nms_thres']
        infer_cfg['max_lanes'] = config[head_name]['max_lanes']
        infer_cfg['num_points'] = config[head_name]['num_points']

    if infer_arch in MOT_ARCH:
        if config['metric'] in ['COCO', 'VOC']:
            # MOT model run as Detector
            reader_cfg = config['TestReader']
            dataset_cfg = config['TestDataset']
        else:
            # 'metric' in ['MOT', 'MCMOT', 'KITTI']
            label_arch = 'mot_arch'
            reader_cfg = config['TestMOTReader']
            dataset_cfg = config['TestMOTDataset']
    else:
        reader_cfg = config['TestReader']
        dataset_cfg = config['TestDataset']

    infer_cfg['Preprocess'], infer_cfg['label_list'] = _parse_reader(
        reader_cfg, dataset_cfg, config['metric'], label_arch, image_shape[1:])

    if infer_arch == 'PicoDet':
        if hasattr(config, 'export') and config['export'].get(
                'post_process',
                False) and not config['export'].get('benchmark', False):
            infer_cfg['arch'] = 'GFL'
        head_name = 'PicoHeadV2' if config['PicoHeadV2'] else 'PicoHead'
        infer_cfg['NMS'] = config[head_name]['nms']
        # In order to speed up the prediction, the threshold of nms
        # is adjusted here, which can be changed in infer_cfg.yml
        config[head_name]['nms']["score_threshold"] = 0.3
        config[head_name]['nms']["nms_threshold"] = 0.5
        infer_cfg['fpn_stride'] = config[head_name]['fpn_stride']

    # Close the file deterministically; the original passed an unclosed
    # handle to yaml.dump and relied on garbage collection to flush it.
    with open(path, 'w') as f:
        yaml.dump(infer_cfg, f)
    logger.info("Export inference config file to {}".format(path))
|