更换文档检测模型
This commit is contained in:
23
paddle_detection/deploy/third_engine/demo_mnn/CMakeLists.txt
Normal file
23
paddle_detection/deploy/third_engine/demo_mnn/CMakeLists.txt
Normal file
@@ -0,0 +1,23 @@
|
||||
cmake_minimum_required(VERSION 3.9)
|
||||
project(picodet-mnn)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(MNN_DIR PATHS "./mnn")
|
||||
|
||||
# find_package(OpenCV REQUIRED PATHS "/work/dependence/opencv/opencv-3.4.3/build")
|
||||
find_package(OpenCV REQUIRED)
|
||||
include_directories(
|
||||
${MNN_DIR}/include
|
||||
${MNN_DIR}/include/MNN
|
||||
${CMAKE_SOURCE_DIR}
|
||||
)
|
||||
link_directories(mnn/lib)
|
||||
|
||||
add_library(libMNN SHARED IMPORTED)
|
||||
set_target_properties(
|
||||
libMNN
|
||||
PROPERTIES IMPORTED_LOCATION
|
||||
${CMAKE_SOURCE_DIR}/mnn/lib/libMNN.so
|
||||
)
|
||||
add_executable(picodet-mnn main.cpp picodet_mnn.cpp)
|
||||
target_link_libraries(picodet-mnn MNN ${OpenCV_LIBS} libMNN.so)
|
||||
89
paddle_detection/deploy/third_engine/demo_mnn/README.md
Normal file
89
paddle_detection/deploy/third_engine/demo_mnn/README.md
Normal file
@@ -0,0 +1,89 @@
|
||||
# PicoDet MNN Demo
|
||||
|
||||
本Demo提供的预测代码是根据[Alibaba's MNN framework](https://github.com/alibaba/MNN) 推理库预测的。
|
||||
|
||||
## C++ Demo
|
||||
|
||||
- 第一步:根据[MNN官方编译文档](https://www.yuque.com/mnn/en/build_linux) 编译生成预测库.
|
||||
- 第二步:编译或下载得到OpenCV库,可参考OpenCV官网,为了方便如果环境是gcc8.2 x86环境,可直接下载以下库:
|
||||
```shell
|
||||
wget https://paddledet.bj.bcebos.com/data/opencv-3.4.16_gcc8.2_ffmpeg.tar.gz
|
||||
tar -xf opencv-3.4.16_gcc8.2_ffmpeg.tar.gz
|
||||
```
|
||||
|
||||
- 第三步:准备模型
|
||||
```shell
|
||||
modelName=picodet_s_320_coco_lcnet
|
||||
# 导出Inference model
|
||||
python tools/export_model.py \
|
||||
-c configs/picodet/${modelName}.yml \
|
||||
-o weights=${modelName}.pdparams \
|
||||
--output_dir=inference_model
|
||||
# 转换到ONNX
|
||||
paddle2onnx --model_dir inference_model/${modelName} \
|
||||
--model_filename model.pdmodel \
|
||||
--params_filename model.pdiparams \
|
||||
--opset_version 11 \
|
||||
--save_file ${modelName}.onnx
|
||||
# 简化模型
|
||||
python -m onnxsim ${modelName}.onnx ${modelName}_processed.onnx
|
||||
# 将模型转换至MNN格式
|
||||
python -m MNN.tools.mnnconvert -f ONNX --modelFile picodet_s_320_lcnet_processed.onnx --MNNModel picodet_s_320_lcnet.mnn
|
||||
```
|
||||
为了快速测试,可直接下载:[picodet_s_320_lcnet.mnn](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_s_320_lcnet.mnn)(不带后处理)。
|
||||
|
||||
**注意:**由于MNN里,Matmul算子的输入shape如果不一致计算有问题,带后处理的Demo正在升级中,很快发布。
|
||||
|
||||
## 编译可执行程序
|
||||
|
||||
- 第一步:导入lib包
|
||||
```
|
||||
mkdir mnn && cd mnn && mkdir lib
|
||||
cp /path/to/MNN/build/libMNN.so .
|
||||
cd ..
|
||||
cp -r /path/to/MNN/include .
|
||||
```
|
||||
- 第二步:修改CMakeLists.txt中OpenCV和MNN的路径
|
||||
- 第三步:开始编译
|
||||
``` shell
|
||||
mkdir build && cd build
|
||||
cmake ..
|
||||
make
|
||||
```
|
||||
如果在build目录下生成`picodet-mnn`可执行文件,就证明成功了。
|
||||
|
||||
## 开始运行
|
||||
|
||||
首先新建预测结果存放目录:
|
||||
```shell
|
||||
cp -r ../demo_onnxruntime/imgs .
|
||||
cd build
|
||||
mkdir ../results
|
||||
```
|
||||
|
||||
- 预测一张图片
|
||||
``` shell
|
||||
./picodet-mnn 0 ../picodet_s_320_lcnet_3.mnn 320 320 ../imgs/dog.jpg
|
||||
```
|
||||
|
||||
-测试速度Benchmark
|
||||
|
||||
``` shell
|
||||
./picodet-mnn 1 ../picodet_s_320_lcnet.mnn 320 320
|
||||
```
|
||||
|
||||
## FAQ
|
||||
|
||||
- 预测结果精度不对:
|
||||
请先确认模型输入shape是否对齐,并且模型输出name是否对齐,不带后处理的PicoDet增强版模型输出name如下:
|
||||
```shell
|
||||
# 分类分支 | 检测分支
|
||||
{"transpose_0.tmp_0", "transpose_1.tmp_0"},
|
||||
{"transpose_2.tmp_0", "transpose_3.tmp_0"},
|
||||
{"transpose_4.tmp_0", "transpose_5.tmp_0"},
|
||||
{"transpose_6.tmp_0", "transpose_7.tmp_0"},
|
||||
```
|
||||
可使用[netron](https://netron.app)查看具体name,并修改`picodet_mnn.hpp`中相应`non_postprocess_heads_info`数组。
|
||||
|
||||
## Reference
|
||||
[MNN](https://github.com/alibaba/MNN)
|
||||
203
paddle_detection/deploy/third_engine/demo_mnn/main.cpp
Normal file
203
paddle_detection/deploy/third_engine/demo_mnn/main.cpp
Normal file
@@ -0,0 +1,203 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "picodet_mnn.hpp"
|
||||
#include <iostream>
|
||||
#include <opencv2/core/core.hpp>
|
||||
#include <opencv2/highgui/highgui.hpp>
|
||||
#include <opencv2/imgproc/imgproc.hpp>
|
||||
|
||||
#define __SAVE_RESULT__ // if defined save drawed results to ../results, else
|
||||
// show it in windows
|
||||
|
||||
struct object_rect {
|
||||
int x;
|
||||
int y;
|
||||
int width;
|
||||
int height;
|
||||
};
|
||||
|
||||
std::vector<int> GenerateColorMap(int num_class) {
|
||||
auto colormap = std::vector<int>(3 * num_class, 0);
|
||||
for (int i = 0; i < num_class; ++i) {
|
||||
int j = 0;
|
||||
int lab = i;
|
||||
while (lab) {
|
||||
colormap[i * 3] |= (((lab >> 0) & 1) << (7 - j));
|
||||
colormap[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j));
|
||||
colormap[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j));
|
||||
++j;
|
||||
lab >>= 3;
|
||||
}
|
||||
}
|
||||
return colormap;
|
||||
}
|
||||
|
||||
void draw_bboxes(const cv::Mat &im, const std::vector<BoxInfo> &bboxes,
|
||||
std::string save_path = "None") {
|
||||
static const char *class_names[] = {
|
||||
"person", "bicycle", "car",
|
||||
"motorcycle", "airplane", "bus",
|
||||
"train", "truck", "boat",
|
||||
"traffic light", "fire hydrant", "stop sign",
|
||||
"parking meter", "bench", "bird",
|
||||
"cat", "dog", "horse",
|
||||
"sheep", "cow", "elephant",
|
||||
"bear", "zebra", "giraffe",
|
||||
"backpack", "umbrella", "handbag",
|
||||
"tie", "suitcase", "frisbee",
|
||||
"skis", "snowboard", "sports ball",
|
||||
"kite", "baseball bat", "baseball glove",
|
||||
"skateboard", "surfboard", "tennis racket",
|
||||
"bottle", "wine glass", "cup",
|
||||
"fork", "knife", "spoon",
|
||||
"bowl", "banana", "apple",
|
||||
"sandwich", "orange", "broccoli",
|
||||
"carrot", "hot dog", "pizza",
|
||||
"donut", "cake", "chair",
|
||||
"couch", "potted plant", "bed",
|
||||
"dining table", "toilet", "tv",
|
||||
"laptop", "mouse", "remote",
|
||||
"keyboard", "cell phone", "microwave",
|
||||
"oven", "toaster", "sink",
|
||||
"refrigerator", "book", "clock",
|
||||
"vase", "scissors", "teddy bear",
|
||||
"hair drier", "toothbrush"};
|
||||
|
||||
cv::Mat image = im.clone();
|
||||
int src_w = image.cols;
|
||||
int src_h = image.rows;
|
||||
int thickness = 2;
|
||||
auto colormap = GenerateColorMap(sizeof(class_names));
|
||||
|
||||
for (size_t i = 0; i < bboxes.size(); i++) {
|
||||
const BoxInfo &bbox = bboxes[i];
|
||||
std::cout << bbox.x1 << ". " << bbox.y1 << ". " << bbox.x2 << ". "
|
||||
<< bbox.y2 << ". " << std::endl;
|
||||
int c1 = colormap[3 * bbox.label + 0];
|
||||
int c2 = colormap[3 * bbox.label + 1];
|
||||
int c3 = colormap[3 * bbox.label + 2];
|
||||
cv::Scalar color = cv::Scalar(c1, c2, c3);
|
||||
// cv::Scalar color = cv::Scalar(0, 0, 255);
|
||||
cv::rectangle(image, cv::Rect(cv::Point(bbox.x1, bbox.y1),
|
||||
cv::Point(bbox.x2, bbox.y2)),
|
||||
color, 1, cv::LINE_AA);
|
||||
|
||||
char text[256];
|
||||
sprintf(text, "%s %.1f%%", class_names[bbox.label], bbox.score * 100);
|
||||
|
||||
int baseLine = 0;
|
||||
cv::Size label_size =
|
||||
cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.4, 1, &baseLine);
|
||||
|
||||
int x = bbox.x1;
|
||||
int y = bbox.y1 - label_size.height - baseLine;
|
||||
if (y < 0)
|
||||
y = 0;
|
||||
if (x + label_size.width > image.cols)
|
||||
x = image.cols - label_size.width;
|
||||
|
||||
cv::rectangle(image, cv::Rect(cv::Point(x, y),
|
||||
cv::Size(label_size.width,
|
||||
label_size.height + baseLine)),
|
||||
color, -1);
|
||||
|
||||
cv::putText(image, text, cv::Point(x, y + label_size.height),
|
||||
cv::FONT_HERSHEY_SIMPLEX, 0.4, cv::Scalar(255, 255, 255), 1,
|
||||
cv::LINE_AA);
|
||||
}
|
||||
|
||||
if (save_path == "None") {
|
||||
cv::imshow("image", image);
|
||||
} else {
|
||||
cv::imwrite(save_path, image);
|
||||
std::cout << save_path << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
int image_demo(PicoDet &detector, const char *imagepath) {
|
||||
std::vector<cv::String> filenames;
|
||||
cv::glob(imagepath, filenames, false);
|
||||
|
||||
for (auto img_name : filenames) {
|
||||
cv::Mat image = cv::imread(img_name, cv::IMREAD_COLOR);
|
||||
if (image.empty()) {
|
||||
fprintf(stderr, "cv::imread %s failed\n", img_name.c_str());
|
||||
return -1;
|
||||
}
|
||||
std::vector<BoxInfo> results;
|
||||
detector.detect(image, results, false);
|
||||
std::cout << "detect done." << std::endl;
|
||||
|
||||
#ifdef __SAVE_RESULT__
|
||||
std::string save_path = img_name;
|
||||
draw_bboxes(image, results, save_path.replace(3, 4, "results"));
|
||||
#else
|
||||
draw_bboxes(image, results);
|
||||
cv::waitKey(0);
|
||||
#endif
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int benchmark(PicoDet &detector, int width, int height) {
|
||||
int loop_num = 100;
|
||||
int warm_up = 8;
|
||||
|
||||
double time_min = DBL_MAX;
|
||||
double time_max = -DBL_MAX;
|
||||
double time_avg = 0;
|
||||
cv::Mat image(width, height, CV_8UC3, cv::Scalar(1, 1, 1));
|
||||
for (int i = 0; i < warm_up + loop_num; i++) {
|
||||
auto start = std::chrono::steady_clock::now();
|
||||
std::vector<BoxInfo> results;
|
||||
detector.detect(image, results, false);
|
||||
auto end = std::chrono::steady_clock::now();
|
||||
|
||||
std::chrono::duration<double> elapsed = end - start;
|
||||
double time = elapsed.count();
|
||||
if (i >= warm_up) {
|
||||
time_min = (std::min)(time_min, time);
|
||||
time_max = (std::max)(time_max, time);
|
||||
time_avg += time;
|
||||
}
|
||||
}
|
||||
time_avg /= loop_num;
|
||||
fprintf(stderr, "%20s min = %7.2f max = %7.2f avg = %7.2f\n", "picodet",
|
||||
time_min, time_max, time_avg);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
int mode = atoi(argv[1]);
|
||||
std::string model_path = argv[2];
|
||||
int height = 320;
|
||||
int width = 320;
|
||||
if (argc == 4) {
|
||||
height = atoi(argv[3]);
|
||||
width = atoi(argv[4]);
|
||||
}
|
||||
PicoDet detector = PicoDet(model_path, width, height, 4, 0.45, 0.3);
|
||||
if (mode == 1) {
|
||||
benchmark(detector, width, height);
|
||||
} else {
|
||||
if (argc != 5) {
|
||||
std::cout << "Must set image file, such as ./picodet-mnn 0 "
|
||||
"../picodet_s_320_lcnet.mnn 320 320 img.jpg"
|
||||
<< std::endl;
|
||||
}
|
||||
const char *images = argv[5];
|
||||
image_demo(detector, images);
|
||||
}
|
||||
}
|
||||
253
paddle_detection/deploy/third_engine/demo_mnn/picodet_mnn.cpp
Normal file
253
paddle_detection/deploy/third_engine/demo_mnn/picodet_mnn.cpp
Normal file
@@ -0,0 +1,253 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// reference from https://github.com/RangiLyu/nanodet/tree/main/demo_mnn
|
||||
|
||||
#include "picodet_mnn.hpp"
|
||||
|
||||
using namespace std;
|
||||
|
||||
PicoDet::PicoDet(const std::string &mnn_path, int input_width, int input_length,
|
||||
int num_thread_, float score_threshold_,
|
||||
float nms_threshold_) {
|
||||
num_thread = num_thread_;
|
||||
in_w = input_width;
|
||||
in_h = input_length;
|
||||
score_threshold = score_threshold_;
|
||||
nms_threshold = nms_threshold_;
|
||||
|
||||
PicoDet_interpreter = std::shared_ptr<MNN::Interpreter>(
|
||||
MNN::Interpreter::createFromFile(mnn_path.c_str()));
|
||||
MNN::ScheduleConfig config;
|
||||
config.numThread = num_thread;
|
||||
MNN::BackendConfig backendConfig;
|
||||
backendConfig.precision = (MNN::BackendConfig::PrecisionMode)2;
|
||||
config.backendConfig = &backendConfig;
|
||||
|
||||
PicoDet_session = PicoDet_interpreter->createSession(config);
|
||||
|
||||
input_tensor = PicoDet_interpreter->getSessionInput(PicoDet_session, nullptr);
|
||||
}
|
||||
|
||||
PicoDet::~PicoDet() {
|
||||
PicoDet_interpreter->releaseModel();
|
||||
PicoDet_interpreter->releaseSession(PicoDet_session);
|
||||
}
|
||||
|
||||
int PicoDet::detect(cv::Mat &raw_image, std::vector<BoxInfo> &result_list,
|
||||
bool has_postprocess) {
|
||||
if (raw_image.empty()) {
|
||||
std::cout << "image is empty ,please check!" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
image_h = raw_image.rows;
|
||||
image_w = raw_image.cols;
|
||||
cv::Mat image;
|
||||
cv::resize(raw_image, image, cv::Size(in_w, in_h));
|
||||
|
||||
PicoDet_interpreter->resizeTensor(input_tensor, {1, 3, in_h, in_w});
|
||||
PicoDet_interpreter->resizeSession(PicoDet_session);
|
||||
std::shared_ptr<MNN::CV::ImageProcess> pretreat(MNN::CV::ImageProcess::create(
|
||||
MNN::CV::BGR, MNN::CV::BGR, mean_vals, 3, norm_vals, 3));
|
||||
pretreat->convert(image.data, in_w, in_h, image.step[0], input_tensor);
|
||||
|
||||
auto start = chrono::steady_clock::now();
|
||||
|
||||
// run network
|
||||
PicoDet_interpreter->runSession(PicoDet_session);
|
||||
|
||||
// get output data
|
||||
std::vector<std::vector<BoxInfo>> results;
|
||||
results.resize(num_class);
|
||||
|
||||
if (has_postprocess) {
|
||||
auto bbox_out_tensor = PicoDet_interpreter->getSessionOutput(
|
||||
PicoDet_session, nms_heads_info[0].c_str());
|
||||
auto class_out_tensor = PicoDet_interpreter->getSessionOutput(
|
||||
PicoDet_session, nms_heads_info[1].c_str());
|
||||
// bbox branch
|
||||
auto tensor_bbox_host =
|
||||
new MNN::Tensor(bbox_out_tensor, MNN::Tensor::CAFFE);
|
||||
bbox_out_tensor->copyToHostTensor(tensor_bbox_host);
|
||||
auto bbox_output_shape = tensor_bbox_host->shape();
|
||||
int output_size = 1;
|
||||
for (int j = 0; j < bbox_output_shape.size(); ++j) {
|
||||
output_size *= bbox_output_shape[j];
|
||||
}
|
||||
std::cout << "output_size:" << output_size << std::endl;
|
||||
bbox_output_data_.resize(output_size);
|
||||
std::copy_n(tensor_bbox_host->host<float>(), output_size,
|
||||
bbox_output_data_.data());
|
||||
delete tensor_bbox_host;
|
||||
// class branch
|
||||
auto tensor_class_host =
|
||||
new MNN::Tensor(class_out_tensor, MNN::Tensor::CAFFE);
|
||||
class_out_tensor->copyToHostTensor(tensor_class_host);
|
||||
auto class_output_shape = tensor_class_host->shape();
|
||||
output_size = 1;
|
||||
for (int j = 0; j < class_output_shape.size(); ++j) {
|
||||
output_size *= class_output_shape[j];
|
||||
}
|
||||
std::cout << "output_size:" << output_size << std::endl;
|
||||
class_output_data_.resize(output_size);
|
||||
std::copy_n(tensor_class_host->host<float>(), output_size,
|
||||
class_output_data_.data());
|
||||
delete tensor_class_host;
|
||||
} else {
|
||||
for (const auto &head_info : non_postprocess_heads_info) {
|
||||
MNN::Tensor *tensor_scores = PicoDet_interpreter->getSessionOutput(
|
||||
PicoDet_session, head_info.cls_layer.c_str());
|
||||
MNN::Tensor *tensor_boxes = PicoDet_interpreter->getSessionOutput(
|
||||
PicoDet_session, head_info.dis_layer.c_str());
|
||||
|
||||
MNN::Tensor tensor_scores_host(tensor_scores,
|
||||
tensor_scores->getDimensionType());
|
||||
tensor_scores->copyToHostTensor(&tensor_scores_host);
|
||||
|
||||
MNN::Tensor tensor_boxes_host(tensor_boxes,
|
||||
tensor_boxes->getDimensionType());
|
||||
tensor_boxes->copyToHostTensor(&tensor_boxes_host);
|
||||
|
||||
decode_infer(&tensor_scores_host, &tensor_boxes_host, head_info.stride,
|
||||
score_threshold, results);
|
||||
}
|
||||
}
|
||||
|
||||
auto end = chrono::steady_clock::now();
|
||||
chrono::duration<double> elapsed = end - start;
|
||||
cout << "inference time:" << elapsed.count() << " s, ";
|
||||
|
||||
for (int i = 0; i < (int)results.size(); i++) {
|
||||
nms(results[i], nms_threshold);
|
||||
|
||||
for (auto box : results[i]) {
|
||||
box.x1 = box.x1 / in_w * image_w;
|
||||
box.x2 = box.x2 / in_w * image_w;
|
||||
box.y1 = box.y1 / in_h * image_h;
|
||||
box.y2 = box.y2 / in_h * image_h;
|
||||
result_list.push_back(box);
|
||||
}
|
||||
}
|
||||
cout << "detect " << result_list.size() << " objects" << endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void PicoDet::decode_infer(MNN::Tensor *cls_pred, MNN::Tensor *dis_pred,
|
||||
int stride, float threshold,
|
||||
std::vector<std::vector<BoxInfo>> &results) {
|
||||
int feature_h = ceil((float)in_h / stride);
|
||||
int feature_w = ceil((float)in_w / stride);
|
||||
|
||||
for (int idx = 0; idx < feature_h * feature_w; idx++) {
|
||||
const float *scores = cls_pred->host<float>() + (idx * num_class);
|
||||
int row = idx / feature_w;
|
||||
int col = idx % feature_w;
|
||||
float score = 0;
|
||||
int cur_label = 0;
|
||||
for (int label = 0; label < num_class; label++) {
|
||||
if (scores[label] > score) {
|
||||
score = scores[label];
|
||||
cur_label = label;
|
||||
}
|
||||
}
|
||||
if (score > threshold) {
|
||||
const float *bbox_pred =
|
||||
dis_pred->host<float>() + (idx * 4 * (reg_max + 1));
|
||||
results[cur_label].push_back(
|
||||
disPred2Bbox(bbox_pred, cur_label, score, col, row, stride));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BoxInfo PicoDet::disPred2Bbox(const float *&dfl_det, int label, float score,
|
||||
int x, int y, int stride) {
|
||||
float ct_x = (x + 0.5) * stride;
|
||||
float ct_y = (y + 0.5) * stride;
|
||||
std::vector<float> dis_pred;
|
||||
dis_pred.resize(4);
|
||||
for (int i = 0; i < 4; i++) {
|
||||
float dis = 0;
|
||||
float *dis_after_sm = new float[reg_max + 1];
|
||||
activation_function_softmax(dfl_det + i * (reg_max + 1), dis_after_sm,
|
||||
reg_max + 1);
|
||||
for (int j = 0; j < reg_max + 1; j++) {
|
||||
dis += j * dis_after_sm[j];
|
||||
}
|
||||
dis *= stride;
|
||||
dis_pred[i] = dis;
|
||||
delete[] dis_after_sm;
|
||||
}
|
||||
float xmin = (std::max)(ct_x - dis_pred[0], .0f);
|
||||
float ymin = (std::max)(ct_y - dis_pred[1], .0f);
|
||||
float xmax = (std::min)(ct_x + dis_pred[2], (float)in_w);
|
||||
float ymax = (std::min)(ct_y + dis_pred[3], (float)in_h);
|
||||
return BoxInfo{xmin, ymin, xmax, ymax, score, label};
|
||||
}
|
||||
|
||||
void PicoDet::nms(std::vector<BoxInfo> &input_boxes, float NMS_THRESH) {
|
||||
std::sort(input_boxes.begin(), input_boxes.end(),
|
||||
[](BoxInfo a, BoxInfo b) { return a.score > b.score; });
|
||||
std::vector<float> vArea(input_boxes.size());
|
||||
for (int i = 0; i < int(input_boxes.size()); ++i) {
|
||||
vArea[i] = (input_boxes.at(i).x2 - input_boxes.at(i).x1 + 1) *
|
||||
(input_boxes.at(i).y2 - input_boxes.at(i).y1 + 1);
|
||||
}
|
||||
for (int i = 0; i < int(input_boxes.size()); ++i) {
|
||||
for (int j = i + 1; j < int(input_boxes.size());) {
|
||||
float xx1 = (std::max)(input_boxes[i].x1, input_boxes[j].x1);
|
||||
float yy1 = (std::max)(input_boxes[i].y1, input_boxes[j].y1);
|
||||
float xx2 = (std::min)(input_boxes[i].x2, input_boxes[j].x2);
|
||||
float yy2 = (std::min)(input_boxes[i].y2, input_boxes[j].y2);
|
||||
float w = (std::max)(float(0), xx2 - xx1 + 1);
|
||||
float h = (std::max)(float(0), yy2 - yy1 + 1);
|
||||
float inter = w * h;
|
||||
float ovr = inter / (vArea[i] + vArea[j] - inter);
|
||||
if (ovr >= NMS_THRESH) {
|
||||
input_boxes.erase(input_boxes.begin() + j);
|
||||
vArea.erase(vArea.begin() + j);
|
||||
} else {
|
||||
j++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline float fast_exp(float x) {
|
||||
union {
|
||||
uint32_t i;
|
||||
float f;
|
||||
} v{};
|
||||
v.i = (1 << 23) * (1.4426950409 * x + 126.93490512f);
|
||||
return v.f;
|
||||
}
|
||||
|
||||
inline float sigmoid(float x) { return 1.0f / (1.0f + fast_exp(-x)); }
|
||||
|
||||
template <typename _Tp>
|
||||
int activation_function_softmax(const _Tp *src, _Tp *dst, int length) {
|
||||
const _Tp alpha = *std::max_element(src, src + length);
|
||||
_Tp denominator{0};
|
||||
|
||||
for (int i = 0; i < length; ++i) {
|
||||
dst[i] = fast_exp(src[i] - alpha);
|
||||
denominator += dst[i];
|
||||
}
|
||||
|
||||
for (int i = 0; i < length; ++i) {
|
||||
dst[i] /= denominator;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
108
paddle_detection/deploy/third_engine/demo_mnn/picodet_mnn.hpp
Normal file
108
paddle_detection/deploy/third_engine/demo_mnn/picodet_mnn.hpp
Normal file
@@ -0,0 +1,108 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef __PicoDet_H__
|
||||
#define __PicoDet_H__
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "Interpreter.hpp"
|
||||
|
||||
#include "ImageProcess.hpp"
|
||||
#include "MNNDefine.h"
|
||||
#include "Tensor.hpp"
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <opencv2/opencv.hpp>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
typedef struct NonPostProcessHeadInfo_ {
|
||||
std::string cls_layer;
|
||||
std::string dis_layer;
|
||||
int stride;
|
||||
} NonPostProcessHeadInfo;
|
||||
|
||||
typedef struct BoxInfo_ {
|
||||
float x1;
|
||||
float y1;
|
||||
float x2;
|
||||
float y2;
|
||||
float score;
|
||||
int label;
|
||||
} BoxInfo;
|
||||
|
||||
class PicoDet {
|
||||
public:
|
||||
PicoDet(const std::string &mnn_path, int input_width, int input_length,
|
||||
int num_thread_ = 4, float score_threshold_ = 0.5,
|
||||
float nms_threshold_ = 0.3);
|
||||
|
||||
~PicoDet();
|
||||
|
||||
int detect(cv::Mat &img, std::vector<BoxInfo> &result_list,
|
||||
bool has_postprocess);
|
||||
|
||||
private:
|
||||
void decode_infer(MNN::Tensor *cls_pred, MNN::Tensor *dis_pred, int stride,
|
||||
float threshold,
|
||||
std::vector<std::vector<BoxInfo>> &results);
|
||||
BoxInfo disPred2Bbox(const float *&dfl_det, int label, float score, int x,
|
||||
int y, int stride);
|
||||
void nms(std::vector<BoxInfo> &input_boxes, float NMS_THRESH);
|
||||
|
||||
private:
|
||||
std::shared_ptr<MNN::Interpreter> PicoDet_interpreter;
|
||||
MNN::Session *PicoDet_session = nullptr;
|
||||
MNN::Tensor *input_tensor = nullptr;
|
||||
|
||||
int num_thread;
|
||||
int image_w;
|
||||
int image_h;
|
||||
|
||||
int in_w = 320;
|
||||
int in_h = 320;
|
||||
|
||||
float score_threshold;
|
||||
float nms_threshold;
|
||||
|
||||
const float mean_vals[3] = {103.53f, 116.28f, 123.675f};
|
||||
const float norm_vals[3] = {0.017429f, 0.017507f, 0.017125f};
|
||||
|
||||
const int num_class = 80;
|
||||
const int reg_max = 7;
|
||||
|
||||
std::vector<float> bbox_output_data_;
|
||||
std::vector<float> class_output_data_;
|
||||
|
||||
std::vector<std::string> nms_heads_info{"tmp_16", "concat_4.tmp_0"};
|
||||
// If not export post-process, will use non_postprocess_heads_info
|
||||
std::vector<NonPostProcessHeadInfo> non_postprocess_heads_info{
|
||||
// cls_pred|dis_pred|stride
|
||||
{"transpose_0.tmp_0", "transpose_1.tmp_0", 8},
|
||||
{"transpose_2.tmp_0", "transpose_3.tmp_0", 16},
|
||||
{"transpose_4.tmp_0", "transpose_5.tmp_0", 32},
|
||||
{"transpose_6.tmp_0", "transpose_7.tmp_0", 64},
|
||||
};
|
||||
};
|
||||
|
||||
template <typename _Tp>
|
||||
int activation_function_softmax(const _Tp *src, _Tp *dst, int length);
|
||||
|
||||
inline float fast_exp(float x);
|
||||
inline float sigmoid(float x);
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user