更换文档检测模型
This commit is contained in:
@@ -0,0 +1,38 @@
|
||||
# Build script for the PicoDet NCNN demo executable.
# 3.9 is the first CMake release whose FindOpenMP module provides the
# OpenMP::OpenMP_<lang> imported targets used below.
cmake_minimum_required(VERSION 3.9)
set(CMAKE_CXX_STANDARD 17)

project(picodet_demo)

# OpenMP is mandatory: REQUIRED aborts the configure step when it is missing,
# so no extra "found" check is needed afterwards.
find_package(OpenMP REQUIRED)
message("OPENMP FOUND")

# find_package(OpenCV REQUIRED)
# Replace the placeholder with the directory of your OpenCV installation.
find_package(OpenCV REQUIRED PATHS "/path/to/opencv-3.4.16_gcc8.2_ffmpeg")

# find_package(ncnn REQUIRED)
# Replace the placeholder with the cmake package directory of your ncnn install.
find_package(ncnn REQUIRED PATHS "/path/to/ncnn/build/install/lib/cmake/ncnn")
if(NOT TARGET ncnn)
    message(WARNING "ncnn NOT FOUND! Please set ncnn_DIR environment variable")
else()
    message("ncnn FOUND ")
endif()

include_directories(
    ${OpenCV_INCLUDE_DIRS}
    ${CMAKE_CURRENT_SOURCE_DIR}
    ${CMAKE_CURRENT_BINARY_DIR}
)

add_executable(picodet_demo main.cpp picodet.cpp)

# The imported OpenMP target carries both the compile options and the link
# libraries, replacing the manual edits of the global CMAKE_*_FLAGS variables.
target_link_libraries(
    picodet_demo
    ncnn
    ${OpenCV_LIBS}
    OpenMP::OpenMP_CXX
)
|
||||
129
paddle_detection/deploy/third_engine/demo_ncnn/README.md
Normal file
129
paddle_detection/deploy/third_engine/demo_ncnn/README.md
Normal file
@@ -0,0 +1,129 @@
|
||||
# PicoDet NCNN Demo
|
||||
|
||||
该Demo提供的预测代码是根据[Tencent's NCNN framework](https://github.com/Tencent/ncnn)推理库预测的。
|
||||
|
||||
# 第一步:编译
|
||||
## Windows
|
||||
### Step1.
|
||||
Download and Install Visual Studio from https://visualstudio.microsoft.com/vs/community/
|
||||
|
||||
### Step2.
|
||||
Download and install OpenCV from https://github.com/opencv/opencv/releases
|
||||
|
||||
为了方便,如果环境是gcc8.2 x86环境,可直接下载以下库:
|
||||
```shell
|
||||
wget https://paddledet.bj.bcebos.com/data/opencv-3.4.16_gcc8.2_ffmpeg.tar.gz
|
||||
tar -xf opencv-3.4.16_gcc8.2_ffmpeg.tar.gz
|
||||
```
|
||||
|
||||
### Step3(可选).
|
||||
Download and install Vulkan SDK from https://vulkan.lunarg.com/sdk/home
|
||||
|
||||
### Step4:编译NCNN
|
||||
|
||||
``` shell script
|
||||
git clone --recursive https://github.com/Tencent/ncnn.git
|
||||
```
|
||||
Build NCNN following this tutorial: [Build for Windows x64 using VS2017](https://github.com/Tencent/ncnn/wiki/how-to-build#build-for-windows-x64-using-visual-studio-community-2017)
|
||||
|
||||
### Step5.
|
||||
|
||||
增加 `ncnn_DIR` = `YOUR_NCNN_PATH/build/install/lib/cmake/ncnn` 到系统变量中
|
||||
|
||||
Build project: Open x64 Native Tools Command Prompt for VS 2019 or 2017
|
||||
|
||||
``` cmd
|
||||
cd <this-folder>
|
||||
mkdir build
|
||||
cd build
|
||||
cmake ..
|
||||
msbuild picodet_demo.vcxproj /p:configuration=release /p:platform=x64
|
||||
```
|
||||
|
||||
## Linux
|
||||
|
||||
### Step1.
|
||||
Build and install OpenCV from https://github.com/opencv/opencv
|
||||
|
||||
### Step2(可选).
|
||||
Download Vulkan SDK from https://vulkan.lunarg.com/sdk/home
|
||||
|
||||
### Step3:编译NCNN
|
||||
Clone NCNN repository
|
||||
|
||||
``` shell script
|
||||
git clone --recursive https://github.com/Tencent/ncnn.git
|
||||
```
|
||||
|
||||
Build NCNN following this tutorial: [Build for Linux / NVIDIA Jetson / Raspberry Pi](https://github.com/Tencent/ncnn/wiki/how-to-build#build-for-linux)
|
||||
|
||||
### Step4:编译可执行文件
|
||||
|
||||
``` shell script
|
||||
cd <this-folder>
|
||||
mkdir build
|
||||
cd build
|
||||
cmake ..
|
||||
make
|
||||
```
|
||||
# Run demo
|
||||
|
||||
- 准备模型
|
||||
```shell
|
||||
modelName=picodet_s_320_coco_lcnet
|
||||
# 导出Inference model
|
||||
python tools/export_model.py \
|
||||
-c configs/picodet/${modelName}.yml \
|
||||
-o weights=${modelName}.pdparams \
|
||||
--output_dir=inference_model
|
||||
# 转换到ONNX
|
||||
paddle2onnx --model_dir inference_model/${modelName} \
|
||||
--model_filename model.pdmodel \
|
||||
--params_filename model.pdiparams \
|
||||
--opset_version 11 \
|
||||
--save_file ${modelName}.onnx
|
||||
# 简化模型
|
||||
python -m onnxsim ${modelName}.onnx ${modelName}_processed.onnx
|
||||
# 将模型转换至NCNN格式
|
||||
# Run onnx2ncnn in ncnn tools to generate the ncnn .param and .bin files.
|
||||
```
|
||||
转NCNN模型可以利用在线转换工具 [https://convertmodel.com](https://convertmodel.com/)
|
||||
|
||||
为了快速测试,可直接下载:[picodet_s_320_coco_lcnet-opt.bin](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_s_320_coco_lcnet-opt.bin)/ [picodet_s_320_coco_lcnet-opt.param](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_s_320_coco_lcnet-opt.param)(不带后处理)。
|
||||
|
||||
**注意:** 由于带后处理的模型在NCNN预测时会输出NAN,暂时请使用不带后处理的Demo;带后处理的Demo正在升级中,很快发布。
|
||||
|
||||
|
||||
## 开始运行
|
||||
|
||||
首先新建预测结果存放目录:
|
||||
```shell
|
||||
cp -r ../demo_onnxruntime/imgs .
|
||||
cd build
|
||||
mkdir ../results
|
||||
```
|
||||
|
||||
- 预测一张图片
|
||||
``` shell
|
||||
./picodet_demo 0 ../picodet_s_320_coco_lcnet.bin ../picodet_s_320_coco_lcnet.param 320 320 ../imgs/dog.jpg 0
|
||||
```
|
||||
具体参数解析可参考`main.cpp`。
|
||||
|
||||
- 测试速度Benchmark
|
||||
|
||||
``` shell
|
||||
./picodet_demo 1 ../picodet_s_320_coco_lcnet.bin ../picodet_s_320_coco_lcnet.param 320 320 0
|
||||
```
|
||||
|
||||
## FAQ
|
||||
|
||||
- 预测结果精度不对:
|
||||
请先确认模型输入shape是否对齐,并且模型输出name是否对齐,不带后处理的PicoDet增强版模型输出name如下:
|
||||
```shell
|
||||
# 分类分支 | 检测分支
|
||||
{"transpose_0.tmp_0", "transpose_1.tmp_0"},
|
||||
{"transpose_2.tmp_0", "transpose_3.tmp_0"},
|
||||
{"transpose_4.tmp_0", "transpose_5.tmp_0"},
|
||||
{"transpose_6.tmp_0", "transpose_7.tmp_0"},
|
||||
```
|
||||
可使用[netron](https://netron.app)查看具体name,并修改`picodet.h`中相应`non_postprocess_heads_info`数组。
|
||||
210
paddle_detection/deploy/third_engine/demo_ncnn/main.cpp
Normal file
210
paddle_detection/deploy/third_engine/demo_ncnn/main.cpp
Normal file
@@ -0,0 +1,210 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// reference from https://github.com/RangiLyu/nanodet/tree/main/demo_ncnn
|
||||
|
||||
#include "picodet.h"
|
||||
#include <benchmark.h>
|
||||
#include <iostream>
|
||||
#include <net.h>
|
||||
#include <opencv2/core/core.hpp>
|
||||
#include <opencv2/highgui/highgui.hpp>
|
||||
#include <opencv2/imgproc/imgproc.hpp>
|
||||
|
||||
#define __SAVE_RESULT__ // If defined, annotated results are written to ../results; otherwise
|
||||
// they are displayed in a window.
|
||||
// Integer rectangle described by an origin (x, y) and a size (width, height).
// NOTE(review): nothing in the visible sources references this type.
struct object_rect {
  int x;       // horizontal origin
  int y;       // vertical origin
  int width;   // horizontal extent
  int height;  // vertical extent
};
|
||||
|
||||
// Builds a flat colour table with three consecutive ints per class
// (entries 3*i, 3*i+1 and 3*i+2 belong to class i).
//
// The class index is consumed three bits at a time; within each group, bit 0
// goes to the first channel, bit 1 to the second and bit 2 to the third. The
// first group lands on channel bit 7, the next on bit 6, and so on, so that
// neighbouring class indices get clearly different colours.
std::vector<int> GenerateColorMap(int num_class) {
  std::vector<int> table(3 * num_class, 0);
  for (int cls = 0; cls < num_class; ++cls) {
    const int base = cls * 3;
    int shift = 7;
    for (int bits = cls; bits != 0; bits >>= 3, --shift) {
      table[base] |= ((bits & 1) << shift);
      table[base + 1] |= (((bits >> 1) & 1) << shift);
      table[base + 2] |= (((bits >> 2) & 1) << shift);
    }
  }
  return table;
}
|
||||
|
||||
// Draws every detection in `bboxes` on a copy of `im` and either shows the
// result or writes it to disk.
//
// For each box the coordinates are printed to stdout, a 1-pixel rectangle is
// drawn in the class colour, and a filled label "<class name> <score>%" is
// placed above the box (clamped so it stays inside the image).
//
// @param im         Source image; it is cloned, never modified.
// @param bboxes     Detections to draw. Boxes whose label is outside the
//                   class-name table are reported on stderr and skipped.
// @param save_path  Output file. The literal "None" (the default) means
//                   "display with cv::imshow instead of saving".
void draw_bboxes(const cv::Mat &im, const std::vector<BoxInfo> &bboxes,
                 std::string save_path = "None") {
  static const char *class_names[] = {
      "person",        "bicycle",      "car",
      "motorcycle",    "airplane",     "bus",
      "train",         "truck",        "boat",
      "traffic light", "fire hydrant", "stop sign",
      "parking meter", "bench",        "bird",
      "cat",           "dog",          "horse",
      "sheep",         "cow",          "elephant",
      "bear",          "zebra",        "giraffe",
      "backpack",      "umbrella",     "handbag",
      "tie",           "suitcase",     "frisbee",
      "skis",          "snowboard",    "sports ball",
      "kite",          "baseball bat", "baseball glove",
      "skateboard",    "surfboard",    "tennis racket",
      "bottle",        "wine glass",   "cup",
      "fork",          "knife",        "spoon",
      "bowl",          "banana",       "apple",
      "sandwich",      "orange",       "broccoli",
      "carrot",        "hot dog",      "pizza",
      "donut",         "cake",         "chair",
      "couch",         "potted plant", "bed",
      "dining table",  "toilet",       "tv",
      "laptop",        "mouse",        "remote",
      "keyboard",      "cell phone",   "microwave",
      "oven",          "toaster",      "sink",
      "refrigerator",  "book",         "clock",
      "vase",          "scissors",     "teddy bear",
      "hair drier",    "toothbrush"};
  // Element count, not byte count: sizeof(class_names) alone is
  // count * sizeof(char *), which made the colour table several times too big.
  static const int num_classes =
      static_cast<int>(sizeof(class_names) / sizeof(class_names[0]));

  cv::Mat image = im.clone();
  const std::vector<int> colormap = GenerateColorMap(num_classes);

  for (const BoxInfo &bbox : bboxes) {
    std::cout << bbox.x1 << ". " << bbox.y1 << ". " << bbox.x2 << ". "
              << bbox.y2 << ". " << std::endl;

    // Guard the table lookups below against labels the model should never
    // produce; indexing with them would read out of bounds.
    if (bbox.label < 0 || bbox.label >= num_classes) {
      std::cerr << "draw_bboxes: skipping box with invalid label "
                << bbox.label << std::endl;
      continue;
    }

    const int c1 = colormap[3 * bbox.label + 0];
    const int c2 = colormap[3 * bbox.label + 1];
    const int c3 = colormap[3 * bbox.label + 2];
    const cv::Scalar color = cv::Scalar(c1, c2, c3);
    cv::rectangle(image, cv::Rect(cv::Point(bbox.x1, bbox.y1),
                                  cv::Point(bbox.x2, bbox.y2)),
                  color, 1);

    char text[256];
    // snprintf bounds the write to the buffer regardless of the name length.
    snprintf(text, sizeof(text), "%s %.1f%%", class_names[bbox.label],
             bbox.score * 100);

    int baseLine = 0;
    const cv::Size label_size =
        cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.4, 1, &baseLine);

    // Put the label just above the box, clamped to the top/right image edges.
    int x = bbox.x1;
    int y = bbox.y1 - label_size.height - baseLine;
    if (y < 0)
      y = 0;
    if (x + label_size.width > image.cols)
      x = image.cols - label_size.width;

    cv::rectangle(image, cv::Rect(cv::Point(x, y),
                                  cv::Size(label_size.width,
                                           label_size.height + baseLine)),
                  color, -1);

    cv::putText(image, text, cv::Point(x, y + label_size.height),
                cv::FONT_HERSHEY_SIMPLEX, 0.4, cv::Scalar(255, 255, 255), 1);
  }

  if (save_path == "None") {
    cv::imshow("image", image);
  } else {
    cv::imwrite(save_path, image);
    std::cout << "Result save in: " << save_path << std::endl;
  }
}
|
||||
|
||||
// Runs the detector on every file matched by `imagepath` (expanded with
// cv::glob, non-recursive) and draws the detections.
//
// When __SAVE_RESULT__ is defined each annotated image is written to
// "../results/<file name>"; otherwise it is shown in a window and the function
// waits for a key press before continuing.
//
// @param detector         Detector used for inference.
// @param imagepath        File name or glob pattern of the input image(s).
// @param has_postprocess  Values > 0 select the post-processed model outputs.
// @return 0 on success, -1 as soon as one image cannot be read.
int image_demo(PicoDet &detector, const char *imagepath,
               int has_postprocess = 0) {
  std::vector<cv::String> filenames;
  cv::glob(imagepath, filenames, false);
  const bool is_postprocess = has_postprocess > 0;

  for (const auto &img_name : filenames) {
    cv::Mat image = cv::imread(img_name, cv::IMREAD_COLOR);
    if (image.empty()) {
      fprintf(stderr, "cv::imread %s failed\n", img_name.c_str());
      return -1;
    }
    std::vector<BoxInfo> results;
    detector.detect(image, results, is_postprocess);
    std::cout << "detect done." << std::endl;

#ifdef __SAVE_RESULT__
    // Derive the output name from the file's base name. The previous
    // in-place replace(3, 4, "results") only produced a valid path for inputs
    // of the exact form "../imgs/<file>" and threw std::out_of_range for paths
    // shorter than three characters.
    const std::string src_path = img_name;
    const std::string::size_type sep = src_path.find_last_of("/\\");
    const std::string base_name =
        (sep == std::string::npos) ? src_path : src_path.substr(sep + 1);
    draw_bboxes(image, results, "../results/" + base_name);
#else
    draw_bboxes(image, results);
    cv::waitKey(0);
#endif
  }
  return 0;
}
|
||||
|
||||
// Measures inference time on a synthetic image.
//
// The detector is run 8 warm-up iterations followed by 100 timed iterations on
// a constant-valued image; minimum, maximum and average of the timed
// iterations (as reported by ncnn::get_current_time) are printed to stderr.
//
// @param detector         Detector under test.
// @param width            Width of the synthetic image, in pixels.
// @param height           Height of the synthetic image, in pixels.
// @param has_postprocess  Values > 0 select the post-processed model outputs.
// @return Always 0.
int benchmark(PicoDet &detector, int width, int height,
              int has_postprocess = 0) {
  const int loop_num = 100;
  const int warm_up = 8;

  double time_min = DBL_MAX;
  double time_max = -DBL_MAX;
  double time_avg = 0;

  // cv::Mat takes (rows, cols), i.e. (height, width); the arguments used to
  // be passed the other way round.
  cv::Mat image(height, width, CV_8UC3, cv::Scalar(1, 1, 1));
  const bool is_postprocess = has_postprocess > 0;

  for (int i = 0; i < warm_up + loop_num; i++) {
    const double start = ncnn::get_current_time();
    std::vector<BoxInfo> results;
    detector.detect(image, results, is_postprocess);
    const double end = ncnn::get_current_time();

    const double time = end - start;
    // Warm-up iterations are executed but excluded from the statistics.
    if (i >= warm_up) {
      time_min = (std::min)(time_min, time);
      time_max = (std::max)(time_max, time);
      time_avg += time;
    }
  }
  time_avg /= loop_num;
  fprintf(stderr, "%20s  min = %7.2f  max = %7.2f  avg = %7.2f\n", "picodet",
          time_min, time_max, time_avg);
  return 0;
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
int mode = atoi(argv[1]);
|
||||
char *bin_model_path = argv[2];
|
||||
char *param_model_path = argv[3];
|
||||
int height = 320;
|
||||
int width = 320;
|
||||
if (argc == 5) {
|
||||
height = atoi(argv[4]);
|
||||
width = atoi(argv[5]);
|
||||
}
|
||||
PicoDet detector =
|
||||
PicoDet(param_model_path, bin_model_path, width, height, true, 0.45, 0.3);
|
||||
if (mode == 1) {
|
||||
|
||||
benchmark(detector, width, height, atoi(argv[6]));
|
||||
} else {
|
||||
if (argc != 6) {
|
||||
std::cout << "Must set image file, such as ./picodet_demo 0 "
|
||||
"../picodet_s_320_lcnet.bin ../picodet_s_320_lcnet.param "
|
||||
"320 320 img.jpg"
|
||||
<< std::endl;
|
||||
}
|
||||
const char *images = argv[6];
|
||||
image_demo(detector, images, atoi(argv[7]));
|
||||
}
|
||||
}
|
||||
236
paddle_detection/deploy/third_engine/demo_ncnn/picodet.cpp
Normal file
236
paddle_detection/deploy/third_engine/demo_ncnn/picodet.cpp
Normal file
@@ -0,0 +1,236 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// reference from https://github.com/RangiLyu/nanodet/tree/main/demo_ncnn
|
||||
|
||||
#include "picodet.h"
|
||||
#include <benchmark.h>
|
||||
#include <iostream>
|
||||
|
||||
// Cheap approximation of exp(x).
//
// The argument is scaled by 1.4426950409 (log2(e) to ten decimals), offset by
// 126.93490512 and multiplied by 2^23; the resulting integer is then read
// back as the bit pattern of a float through a union.
inline float fast_exp(float x) {
  union {
    uint32_t i;
    float f;
  } bits{};
  const double scaled = 1.4426950409 * x + 126.93490512f;
  bits.i = (1 << 23) * scaled;
  return bits.f;
}
|
||||
|
||||
// Logistic function 1 / (1 + exp(-x)), evaluated with the fast_exp
// approximation.
inline float sigmoid(float x) {
  const float decay = fast_exp(-x);
  return 1.0f / (1.0f + decay);
}
|
||||
|
||||
// Writes the softmax of src[0..length) into dst[0..length).
//
// The largest input is subtracted from every element before exponentiation
// (with fast_exp), then the results are divided by their sum.
//
// @param src     Input values; must contain `length` elements.
// @param dst     Output buffer with room for `length` elements.
// @param length  Number of elements to process.
// @return Always 0.
template <typename _Tp>
int activation_function_softmax(const _Tp *src, _Tp *dst, int length) {
  const _Tp peak = *std::max_element(src, src + length);

  // First pass: shifted exponentials and their running sum.
  _Tp total{0};
  int k = 0;
  while (k < length) {
    dst[k] = fast_exp(src[k] - peak);
    total += dst[k];
    ++k;
  }

  // Second pass: normalise by the sum.
  k = 0;
  while (k < length) {
    dst[k] /= total;
    ++k;
  }

  return 0;
}
|
||||
|
||||
// Definition of the static PicoDet::hasGPU flag. It starts as false; the
// constructor overwrites it with `ncnn::get_gpu_count() > 0` when NCNN_VULKAN
// is enabled.
bool PicoDet::hasGPU = false;
|
||||
// Definition of the static PicoDet::detector pointer. It starts as nullptr and
// nothing in the visible sources assigns it.
PicoDet *PicoDet::detector = nullptr;
|
||||
|
||||
// Creates the ncnn network, configures it and loads the model files.
//
// @param param             Path of the ncnn .param file.
// @param bin               Path of the ncnn .bin file.
// @param input_width       Width the network input is resized to.
// @param input_hight       Height the network input is resized to.
// @param useGPU            Request Vulkan compute; only honoured when the
//                          build has NCNN_VULKAN and a GPU is reported.
// @param score_threshold_  Minimum class score for a box to be kept.
// @param nms_threshold_    IoU threshold used by non-maximum suppression.
//
// A failure to load either model file is reported on stderr instead of being
// ignored; the object is still constructed, as before.
PicoDet::PicoDet(const char *param, const char *bin, int input_width,
                 int input_hight, bool useGPU, float score_threshold_ = 0.5,
                 float nms_threshold_ = 0.3) {
  this->Net = new ncnn::Net();
#if NCNN_VULKAN
  this->hasGPU = ncnn::get_gpu_count() > 0;
#endif
  // Options must be set before the model is loaded.
  this->Net->opt.use_vulkan_compute = this->hasGPU && useGPU;
  this->Net->opt.use_fp16_arithmetic = true;

  // ncnn reports success with 0.
  if (this->Net->load_param(param) != 0) {
    std::cerr << "PicoDet: failed to load param file " << param << std::endl;
  }
  if (this->Net->load_model(bin) != 0) {
    std::cerr << "PicoDet: failed to load model file " << bin << std::endl;
  }

  this->in_w = input_width;
  this->in_h = input_hight;
  this->score_threshold = score_threshold_;
  this->nms_threshold = nms_threshold_;
}
|
||||
|
||||
// Releases the network allocated by the constructor.
PicoDet::~PicoDet() {
  delete Net;
}
|
||||
|
||||
// Converts an OpenCV image into the normalised ncnn input blob.
//
// The pixels are interpreted as BGR, resized to in_w x in_h by ncnn, and each
// channel is normalised as (value - mean) * norm.
//
// @param image  Source image; only its data pointer and dimensions are read.
// @param in     Receives the resized, normalised blob.
void PicoDet::preprocess(cv::Mat &image, ncnn::Mat &in) {
  // cv::resize(image, image, cv::Size(this->in_w, this->in_h), 0.f, 0.f);
  const float channel_mean[3] = {103.53f, 116.28f, 123.675f};
  const float channel_norm[3] = {0.017429f, 0.017507f, 0.017125f};

  const int src_w = image.cols;
  const int src_h = image.rows;
  in = ncnn::Mat::from_pixels_resize(image.data, ncnn::Mat::PIXEL_BGR, src_w,
                                     src_h, in_w, in_h);
  in.substract_mean_normalize(channel_mean, channel_norm);
}
|
||||
|
||||
// Runs the network on `image` and appends the final detections to
// `result_list`.
//
// Steps: preprocess -> forward pass -> decode the raw heads (or read the
// post-processed outputs) into per-class candidate lists -> per-class NMS ->
// rescale the boxes from network-input size to the original image size.
//
// @param image            Input image (BGR, as expected by preprocess).
// @param result_list      Output; detections are appended, existing entries
//                         are kept.
// @param has_postprocess  true: read the two blobs named in nms_heads_info;
//                         false: decode every head in
//                         non_postprocess_heads_info.
// @return Always 0.
int PicoDet::detect(cv::Mat image, std::vector<BoxInfo> &result_list,
                    bool has_postprocess) {
  ncnn::Mat input;
  preprocess(image, input);

  auto ex = this->Net->create_extractor();
  ex.set_light_mode(false);
  ex.set_num_threads(4);
#if NCNN_VULKAN
  // Follow the option chosen in the constructor (hasGPU && useGPU). Passing
  // hasGPU alone re-enabled Vulkan even when the caller asked for CPU.
  ex.set_vulkan_compute(this->Net->opt.use_vulkan_compute);
#endif
  ex.input("image", input); // picodet

  this->image_h = image.rows;
  this->image_w = image.cols;

  // One candidate list per class so that NMS runs class by class.
  std::vector<std::vector<BoxInfo>> results;
  results.resize(this->num_class);

  if (has_postprocess) {
    ncnn::Mat dis_pred;
    ncnn::Mat cls_pred;
    ex.extract(this->nms_heads_info[0].c_str(), dis_pred);
    ex.extract(this->nms_heads_info[1].c_str(), cls_pred);
    std::cout << dis_pred.h << "  " << dis_pred.w << std::endl;
    std::cout << cls_pred.h << "  " << cls_pred.w << std::endl;
    this->nms_boxes(cls_pred, dis_pred, this->score_threshold, results);
  } else {
    for (const auto &head_info : this->non_postprocess_heads_info) {
      ncnn::Mat dis_pred;
      ncnn::Mat cls_pred;
      ex.extract(head_info.dis_layer.c_str(), dis_pred);
      ex.extract(head_info.cls_layer.c_str(), cls_pred);
      this->decode_infer(cls_pred, dis_pred, head_info.stride,
                         this->score_threshold, results);
    }
  }

  for (auto &class_boxes : results) {
    this->nms(class_boxes, this->nms_threshold);

    // The copy is intentional: the scaled box goes to the caller while the
    // per-class list keeps network-input coordinates.
    for (auto box : class_boxes) {
      box.x1 = box.x1 / this->in_w * this->image_w;
      box.x2 = box.x2 / this->in_w * this->image_w;
      box.y1 = box.y1 / this->in_h * this->image_h;
      box.y2 = box.y2 / this->in_h * this->image_h;
      result_list.push_back(box);
    }
  }
  return 0;
}
|
||||
|
||||
// Collects candidate boxes from the outputs of a model exported with
// post-processing.
//
// Row i of `dis_pred` holds the four box coordinates (x1, y1, x2, y2) of
// candidate i; the score of class `label` for that candidate is read from
// cls_pred.row(label)[i]. Each candidate is assigned its best-scoring class
// and, when that score exceeds `score_threshold`, appended to the list of
// that class.
//
// @param cls_pred         Class scores, indexed as row(label)[candidate].
// @param dis_pred         Box coordinates, one row per candidate.
// @param score_threshold  Candidates whose best score is not above this value
//                         are dropped. The parameter used to be ignored, so
//                         every candidate reached NMS.
// @param result_list      Per-class output lists; must have at least
//                         num_class entries.
void PicoDet::nms_boxes(ncnn::Mat &cls_pred, ncnn::Mat &dis_pred,
                        float score_threshold,
                        std::vector<std::vector<BoxInfo>> &result_list) {
  for (int i = 0; i < dis_pred.h; i++) {
    // Pick the class with the highest score for this candidate.
    float score = 0;
    int cur_label = 0;
    for (int label = 0; label < this->num_class; label++) {
      const float score_ = cls_pred.row(label)[i];
      if (score_ > score) {
        score = score_;
        cur_label = label;
      }
    }
    // Same filter as decode_infer applies on the raw-head path.
    if (!(score > score_threshold)) {
      continue;
    }

    const float *coords = dis_pred.row(i);
    BoxInfo bbox;
    bbox.x1 = coords[0];
    bbox.y1 = coords[1];
    bbox.x2 = coords[2];
    bbox.y2 = coords[3];
    bbox.score = score;
    bbox.label = cur_label;
    result_list[cur_label].push_back(bbox);
  }
}
|
||||
|
||||
// Decodes one raw detection head into per-class candidate boxes.
//
// The head is treated as a feature map of ceil(in_h / stride) rows by
// ceil(in_w / stride) columns, flattened row-major. For every cell the
// best-scoring class is selected and, when its score exceeds `threshold`, the
// cell's distance distribution is converted to a box with disPred2Bbox.
//
// @param cls_pred   Class scores, one row per feature-map cell.
// @param dis_pred   Distance distributions, one row per feature-map cell.
// @param stride     Down-sampling factor of this head.
// @param threshold  Minimum (exclusive) score for a cell to yield a box.
// @param results    Per-class output lists; must have at least num_class
//                   entries.
void PicoDet::decode_infer(ncnn::Mat &cls_pred, ncnn::Mat &dis_pred, int stride,
                           float threshold,
                           std::vector<std::vector<BoxInfo>> &results) {
  // Rows follow the input height and columns the input width; the two used
  // to be swapped, which only went unnoticed for square inputs.
  const int feature_h = ceil(static_cast<float>(this->in_h) / stride);
  const int feature_w = ceil(static_cast<float>(this->in_w) / stride);

  for (int idx = 0; idx < feature_h * feature_w; idx++) {
    const float *scores = cls_pred.row(idx);
    const int row = idx / feature_w;
    const int col = idx % feature_w;

    float score = 0;
    int cur_label = 0;
    for (int label = 0; label < this->num_class; label++) {
      if (scores[label] > score) {
        score = scores[label];
        cur_label = label;
      }
    }

    if (score > threshold) {
      const float *bbox_pred = dis_pred.row(idx);
      results[cur_label].push_back(
          this->disPred2Bbox(bbox_pred, cur_label, score, col, row, stride));
    }
  }
}
|
||||
|
||||
// Converts the distance distributions of one feature-map cell into a box in
// network-input coordinates.
//
// `dfl_det` holds four consecutive groups of (reg_max + 1) values, used in
// the order left, top, right, bottom. Each group is turned into a probability
// distribution with softmax; its expected value times `stride` is the
// distance from the cell centre to that side of the box.
//
// @param dfl_det  Pointer to 4 * (reg_max + 1) values for this cell.
// @param label    Class index stored in the result.
// @param score    Class score stored in the result.
// @param x        Column of the cell in the feature map.
// @param y        Row of the cell in the feature map.
// @param stride   Down-sampling factor of the head.
// @return The box, clamped to [0, in_w] horizontally and [0, in_h]
//         vertically.
BoxInfo PicoDet::disPred2Bbox(const float *&dfl_det, int label, float score,
                              int x, int y, int stride) {
  const float ct_x = (x + 0.5) * stride;
  const float ct_y = (y + 0.5) * stride;

  const int bins = this->reg_max + 1;
  // One scratch buffer reused for all four sides instead of a new[]/delete[]
  // pair per side.
  std::vector<float> dis_after_sm(bins);
  float dis_pred[4];

  for (int i = 0; i < 4; i++) {
    activation_function_softmax(dfl_det + i * bins, dis_after_sm.data(), bins);
    // Expected bin index of the distribution.
    float dis = 0;
    for (int j = 0; j < bins; j++) {
      dis += j * dis_after_sm[j];
    }
    dis_pred[i] = dis * stride;
  }

  const float xmin = (std::max)(ct_x - dis_pred[0], .0f);
  const float ymin = (std::max)(ct_y - dis_pred[1], .0f);
  const float xmax =
      (std::min)(ct_x + dis_pred[2], static_cast<float>(this->in_w));
  // The vertical limit is the input height; it used to be clamped with in_w.
  const float ymax =
      (std::min)(ct_y + dis_pred[3], static_cast<float>(this->in_h));
  return BoxInfo{xmin, ymin, xmax, ymax, score, label};
}
|
||||
|
||||
void PicoDet::nms(std::vector<BoxInfo> &input_boxes, float NMS_THRESH) {
|
||||
std::sort(input_boxes.begin(), input_boxes.end(),
|
||||
[](BoxInfo a, BoxInfo b) { return a.score > b.score; });
|
||||
std::vector<float> vArea(input_boxes.size());
|
||||
for (int i = 0; i < int(input_boxes.size()); ++i) {
|
||||
vArea[i] = (input_boxes.at(i).x2 - input_boxes.at(i).x1 + 1) *
|
||||
(input_boxes.at(i).y2 - input_boxes.at(i).y1 + 1);
|
||||
}
|
||||
for (int i = 0; i < int(input_boxes.size()); ++i) {
|
||||
for (int j = i + 1; j < int(input_boxes.size());) {
|
||||
float xx1 = (std::max)(input_boxes[i].x1, input_boxes[j].x1);
|
||||
float yy1 = (std::max)(input_boxes[i].y1, input_boxes[j].y1);
|
||||
float xx2 = (std::min)(input_boxes[i].x2, input_boxes[j].x2);
|
||||
float yy2 = (std::min)(input_boxes[i].y2, input_boxes[j].y2);
|
||||
float w = (std::max)(float(0), xx2 - xx1 + 1);
|
||||
float h = (std::max)(float(0), yy2 - yy1 + 1);
|
||||
float inter = w * h;
|
||||
float ovr = inter / (vArea[i] + vArea[j] - inter);
|
||||
if (ovr >= NMS_THRESH) {
|
||||
input_boxes.erase(input_boxes.begin() + j);
|
||||
vArea.erase(vArea.begin() + j);
|
||||
} else {
|
||||
j++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
87
paddle_detection/deploy/third_engine/demo_ncnn/picodet.h
Normal file
87
paddle_detection/deploy/third_engine/demo_ncnn/picodet.h
Normal file
@@ -0,0 +1,87 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// reference from https://github.com/RangiLyu/nanodet/tree/main/demo_ncnn
|
||||
|
||||
#ifndef PICODET_H
|
||||
#define PICODET_H
|
||||
|
||||
#include <net.h>
|
||||
#include <opencv2/core/core.hpp>
|
||||
|
||||
// Describes one raw (not post-processed) detection head of the network.
struct NonPostProcessHeadInfo {
  std::string cls_layer;  // name of the output blob with the class scores
  std::string dis_layer;  // name of the output blob with the box distances
  int stride;             // down-sampling factor of this head
};
|
||||
|
||||
// One detection: two box corners (x1, y1) and (x2, y2), the class score and
// the class index. It stays a plain aggregate because it is built with brace
// initialisation in the order x1, y1, x2, y2, score, label.
struct BoxInfo {
  float x1;
  float y1;
  float x2;
  float y2;
  float score;
  int label;
};
|
||||
|
||||
class PicoDet {
|
||||
public:
|
||||
PicoDet(const char *param, const char *bin, int input_width, int input_hight,
|
||||
bool useGPU, float score_threshold_, float nms_threshold_);
|
||||
|
||||
~PicoDet();
|
||||
|
||||
static PicoDet *detector;
|
||||
ncnn::Net *Net;
|
||||
static bool hasGPU;
|
||||
|
||||
int detect(cv::Mat image, std::vector<BoxInfo> &result_list,
|
||||
bool has_postprocess);
|
||||
|
||||
private:
|
||||
void preprocess(cv::Mat &image, ncnn::Mat &in);
|
||||
void decode_infer(ncnn::Mat &cls_pred, ncnn::Mat &dis_pred, int stride,
|
||||
float threshold,
|
||||
std::vector<std::vector<BoxInfo>> &results);
|
||||
BoxInfo disPred2Bbox(const float *&dfl_det, int label, float score, int x,
|
||||
int y, int stride);
|
||||
static void nms(std::vector<BoxInfo> &result, float nms_threshold);
|
||||
void nms_boxes(ncnn::Mat &cls_pred, ncnn::Mat &dis_pred,
|
||||
float score_threshold,
|
||||
std::vector<std::vector<BoxInfo>> &result_list);
|
||||
|
||||
int image_w;
|
||||
int image_h;
|
||||
int in_w = 320;
|
||||
int in_h = 320;
|
||||
int num_class = 80;
|
||||
int reg_max = 7;
|
||||
|
||||
float score_threshold;
|
||||
float nms_threshold;
|
||||
|
||||
std::vector<float> bbox_output_data_;
|
||||
std::vector<float> class_output_data_;
|
||||
|
||||
std::vector<std::string> nms_heads_info{"tmp_16", "concat_4.tmp_0"};
|
||||
// If not export post-process, will use non_postprocess_heads_info
|
||||
std::vector<NonPostProcessHeadInfo> non_postprocess_heads_info{
|
||||
// cls_pred|dis_pred|stride
|
||||
{"transpose_0.tmp_0", "transpose_1.tmp_0", 8},
|
||||
{"transpose_2.tmp_0", "transpose_3.tmp_0", 16},
|
||||
{"transpose_4.tmp_0", "transpose_5.tmp_0", 32},
|
||||
{"transpose_6.tmp_0", "transpose_7.tmp_0", 64},
|
||||
};
|
||||
};
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user