Replace the document detection model

2024-08-27 14:42:45 +08:00
parent aea6f19951
commit 1514e09c40
2072 changed files with 254336 additions and 4967 deletions

View File

@@ -0,0 +1,38 @@
cmake_minimum_required(VERSION 3.9)
set(CMAKE_CXX_STANDARD 17)
project(picodet_demo)
find_package(OpenMP REQUIRED)
if(OPENMP_FOUND)
message("OPENMP FOUND")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
endif()
# find_package(OpenCV REQUIRED)
find_package(OpenCV REQUIRED PATHS "/path/to/opencv-3.4.16_gcc8.2_ffmpeg")
# find_package(ncnn REQUIRED)
find_package(ncnn REQUIRED PATHS "/path/to/ncnn/build/install/lib/cmake/ncnn")
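# Note: instead of hard-coding the two paths above, you can also pass
# -DOpenCV_DIR=<path> and -Dncnn_DIR=<path> on the cmake command line.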
if(NOT TARGET ncnn)
message(WARNING "ncnn NOT FOUND! Please set ncnn_DIR environment variable")
else()
message("ncnn FOUND ")
endif()
include_directories(
${OpenCV_INCLUDE_DIRS}
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_BINARY_DIR}
)
add_executable(picodet_demo main.cpp picodet.cpp)
target_link_libraries(
picodet_demo
ncnn
${OpenCV_LIBS}
)

View File

@@ -0,0 +1,129 @@
# PicoDet NCNN Demo
The inference code in this demo is based on [Tencent's NCNN framework](https://github.com/Tencent/ncnn).
# Step 1: Build
## Windows
### Step1.
Download and Install Visual Studio from https://visualstudio.microsoft.com/vs/community/
### Step2.
Download and install OpenCV from https://github.com/opencv/opencv/releases
For convenience, if your environment is gcc 8.2 on x86, you can download the prebuilt library directly:
```shell
wget https://paddledet.bj.bcebos.com/data/opencv-3.4.16_gcc8.2_ffmpeg.tar.gz
tar -xf opencv-3.4.16_gcc8.2_ffmpeg.tar.gz
```
### Step3 (Optional).
Download and install Vulkan SDK from https://vulkan.lunarg.com/sdk/home
### Step4: Build NCNN
``` shell script
git clone --recursive https://github.com/Tencent/ncnn.git
```
Build NCNN following this tutorial: [Build for Windows x64 using VS2017](https://github.com/Tencent/ncnn/wiki/how-to-build#build-for-windows-x64-using-visual-studio-community-2017)
### Step5.
Add `ncnn_DIR` = `YOUR_NCNN_PATH/build/install/lib/cmake/ncnn` to your system environment variables.
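One way to set it from a command prompt (the path below is only a placeholder for your actual ncnn install):
``` cmd
setx ncnn_DIR "D:\path\to\ncnn\build\install\lib\cmake\ncnn"
```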
Build project: Open x64 Native Tools Command Prompt for VS 2019 or 2017
``` cmd
cd <this-folder>
mkdir build
cd build
cmake ..
msbuild picodet_demo.vcxproj /p:configuration=release /p:platform=x64
```
## Linux
### Step1.
Build and install OpenCV from https://github.com/opencv/opencv
### Step2 (Optional).
Download Vulkan SDK from https://vulkan.lunarg.com/sdk/home
### Step3: Build NCNN
Clone NCNN repository
``` shell script
git clone --recursive https://github.com/Tencent/ncnn.git
```
Build NCNN following this tutorial: [Build for Linux / NVIDIA Jetson / Raspberry Pi](https://github.com/Tencent/ncnn/wiki/how-to-build#build-for-linux)
### Step4: Build the executable
``` shell script
cd <this-folder>
mkdir build
cd build
cmake ..
make
```
# Run demo
- Prepare the model
```shell
modelName=picodet_s_320_coco_lcnet
# Export the inference model
python tools/export_model.py \
-c configs/picodet/${modelName}.yml \
-o weights=${modelName}.pdparams \
--output_dir=inference_model
# Convert to ONNX
paddle2onnx --model_dir inference_model/${modelName} \
--model_filename model.pdmodel \
--params_filename model.pdiparams \
--opset_version 11 \
--save_file ${modelName}.onnx
# Simplify the model
python -m onnxsim ${modelName}.onnx ${modelName}_processed.onnx
# Convert the model to NCNN format:
# run onnx2ncnn from the ncnn tools to generate the ncnn .param and .bin files
```
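A minimal sketch of that last conversion step, assuming `onnx2ncnn` was built together with ncnn (for example under `ncnn/build/tools/onnx`) and is on your `PATH`; the output file names are illustrative:
```shell
onnx2ncnn ${modelName}_processed.onnx ${modelName}.param ${modelName}.bin
```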
The NCNN conversion can also be done with the online converter [https://convertmodel.com](https://convertmodel.com/).
For a quick test, you can download the converted files directly: [picodet_s_320_coco_lcnet-opt.bin](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_s_320_coco_lcnet-opt.bin) / [picodet_s_320_coco_lcnet-opt.param](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_s_320_coco_lcnet-opt.param) (without post-processing).
**Note:** Because NCNN inference currently produces NaN when post-processing is exported with the model, please use the demo without post-processing for now; a demo with post-processing is being upgraded and will be released soon.
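For example, fetching the two prebuilt files from the URLs above:
```shell
wget https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_s_320_coco_lcnet-opt.bin
wget https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_s_320_coco_lcnet-opt.param
```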
## Run
First, copy the test images and create a directory for the prediction results:
```shell
cp -r ../demo_onnxruntime/imgs .
cd build
mkdir ../results
```
- Run inference on a single image
``` shell
./picodet_demo 0 ../picodet_s_320_coco_lcnet.bin ../picodet_s_320_coco_lcnet.param 320 320 ../imgs/dog.jpg 0
```
See `main.cpp` for details on how the arguments are parsed; a summary of the positional arguments follows the benchmark example below.
- Speed benchmark
``` shell
./picodet_demo 1 ../picodet_s_320_lcnet.bin ../picodet_s_320_lcnet.param 320 320 0
```
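The positional arguments, as parsed in `main.cpp`, are:
```shell
# argv[1]           mode: 0 = image inference, 1 = benchmark
# argv[2] argv[3]   ncnn .bin and .param model paths
# argv[4] argv[5]   input height and width (default 320 320)
# mode 0: argv[6] = image path or directory, argv[7] = has_postprocess (0/1)
# mode 1: argv[6] = has_postprocess (0/1)
```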
## FAQ
- The prediction results are inaccurate:
First check that the model input shape matches and that the model output names match. For the enhanced PicoDet model without post-processing, the output names are as follows:
```shell
# classification branch | detection branch
{"transpose_0.tmp_0", "transpose_1.tmp_0"},
{"transpose_2.tmp_0", "transpose_3.tmp_0"},
{"transpose_4.tmp_0", "transpose_5.tmp_0"},
{"transpose_6.tmp_0", "transpose_7.tmp_0"},
```
You can inspect the exact names with [netron](https://netron.app) and modify the corresponding `non_postprocess_heads_info` array in `picodet.h`.

View File

@@ -0,0 +1,210 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// reference from https://github.com/RangiLyu/nanodet/tree/main/demo_ncnn
#include "picodet.h"
#include <benchmark.h>
#include <algorithm>
#include <cfloat>
#include <iostream>
#include <net.h>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#define __SAVE_RESULT__ // if defined, save drawn results to ../results;
// otherwise show them in a window
struct object_rect {
int x;
int y;
int width;
int height;
};
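// Build a PASCAL VOC-style color map: the bits of each class index are spread
// across the R, G and B channels so that nearby class ids get distinct colors.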
std::vector<int> GenerateColorMap(int num_class) {
auto colormap = std::vector<int>(3 * num_class, 0);
for (int i = 0; i < num_class; ++i) {
int j = 0;
int lab = i;
while (lab) {
colormap[i * 3] |= (((lab >> 0) & 1) << (7 - j));
colormap[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j));
colormap[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j));
++j;
lab >>= 3;
}
}
return colormap;
}
void draw_bboxes(const cv::Mat &im, const std::vector<BoxInfo> &bboxes,
std::string save_path = "None") {
static const char *class_names[] = {
"person", "bicycle", "car",
"motorcycle", "airplane", "bus",
"train", "truck", "boat",
"traffic light", "fire hydrant", "stop sign",
"parking meter", "bench", "bird",
"cat", "dog", "horse",
"sheep", "cow", "elephant",
"bear", "zebra", "giraffe",
"backpack", "umbrella", "handbag",
"tie", "suitcase", "frisbee",
"skis", "snowboard", "sports ball",
"kite", "baseball bat", "baseball glove",
"skateboard", "surfboard", "tennis racket",
"bottle", "wine glass", "cup",
"fork", "knife", "spoon",
"bowl", "banana", "apple",
"sandwich", "orange", "broccoli",
"carrot", "hot dog", "pizza",
"donut", "cake", "chair",
"couch", "potted plant", "bed",
"dining table", "toilet", "tv",
"laptop", "mouse", "remote",
"keyboard", "cell phone", "microwave",
"oven", "toaster", "sink",
"refrigerator", "book", "clock",
"vase", "scissors", "teddy bear",
"hair drier", "toothbrush"};
cv::Mat image = im.clone();
int src_w = image.cols;
int src_h = image.rows;
int thickness = 2;
auto colormap = GenerateColorMap(sizeof(class_names) / sizeof(class_names[0]));
for (size_t i = 0; i < bboxes.size(); i++) {
const BoxInfo &bbox = bboxes[i];
std::cout << bbox.x1 << ", " << bbox.y1 << ", " << bbox.x2 << ", "
<< bbox.y2 << std::endl;
int c1 = colormap[3 * bbox.label + 0];
int c2 = colormap[3 * bbox.label + 1];
int c3 = colormap[3 * bbox.label + 2];
cv::Scalar color = cv::Scalar(c1, c2, c3);
// cv::Scalar color = cv::Scalar(0, 0, 255);
cv::rectangle(image, cv::Rect(cv::Point(bbox.x1, bbox.y1),
cv::Point(bbox.x2, bbox.y2)),
color, 1);
char text[256];
snprintf(text, sizeof(text), "%s %.1f%%", class_names[bbox.label],
bbox.score * 100);
int baseLine = 0;
cv::Size label_size =
cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.4, 1, &baseLine);
int x = bbox.x1;
int y = bbox.y1 - label_size.height - baseLine;
if (y < 0)
y = 0;
if (x + label_size.width > image.cols)
x = image.cols - label_size.width;
cv::rectangle(image, cv::Rect(cv::Point(x, y),
cv::Size(label_size.width,
label_size.height + baseLine)),
color, -1);
cv::putText(image, text, cv::Point(x, y + label_size.height),
cv::FONT_HERSHEY_SIMPLEX, 0.4, cv::Scalar(255, 255, 255), 1);
}
if (save_path == "None") {
cv::imshow("image", image);
} else {
cv::imwrite(save_path, image);
std::cout << "Result save in: " << save_path << std::endl;
}
}
int image_demo(PicoDet &detector, const char *imagepath,
int has_postprocess = 0) {
std::vector<cv::String> filenames;
cv::glob(imagepath, filenames, false);
bool is_postprocess = has_postprocess > 0 ? true : false;
for (auto img_name : filenames) {
cv::Mat image = cv::imread(img_name, cv::IMREAD_COLOR);
if (image.empty()) {
fprintf(stderr, "cv::imread %s failed\n", img_name.c_str());
return -1;
}
std::vector<BoxInfo> results;
detector.detect(image, results, is_postprocess);
std::cout << "detect done." << std::endl;
#ifdef __SAVE_RESULT__
std::string save_path = img_name;
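// Replace the "imgs" path segment with "results" so outputs land in
// ../results (assumes images are passed as ../imgs/<name>).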
draw_bboxes(image, results, save_path.replace(3, 4, "results"));
#else
draw_bboxes(image, results);
cv::waitKey(0);
#endif
}
return 0;
}
int benchmark(PicoDet &detector, int width, int height,
int has_postprocess = 0) {
int loop_num = 100;
int warm_up = 8;
double time_min = DBL_MAX;
double time_max = -DBL_MAX;
double time_avg = 0;
cv::Mat image(height, width, CV_8UC3, cv::Scalar(1, 1, 1));
bool is_postprocess = has_postprocess > 0 ? true : false;
for (int i = 0; i < warm_up + loop_num; i++) {
double start = ncnn::get_current_time();
std::vector<BoxInfo> results;
detector.detect(image, results, is_postprocess);
double end = ncnn::get_current_time();
double time = end - start;
if (i >= warm_up) {
time_min = (std::min)(time_min, time);
time_max = (std::max)(time_max, time);
time_avg += time;
}
}
time_avg /= loop_num;
fprintf(stderr, "%20s min = %7.2f max = %7.2f avg = %7.2f\n", "picodet",
time_min, time_max, time_avg);
return 0;
}
int main(int argc, char **argv) {
if (argc < 4) {
fprintf(stderr,
"usage: %s <mode: 0=image, 1=benchmark> <bin_model> <param_model> "
"[height width] [image_path has_postprocess]\n",
argv[0]);
return -1;
}
int mode = atoi(argv[1]);
char *bin_model_path = argv[2];
char *param_model_path = argv[3];
int height = 320;
int width = 320;
if (argc >= 6) {
height = atoi(argv[4]);
width = atoi(argv[5]);
}
PicoDet detector =
PicoDet(param_model_path, bin_model_path, width, height, true, 0.45, 0.3);
if (mode == 1) {
benchmark(detector, width, height, atoi(argv[6]));
} else {
if (argc < 8) {
std::cout << "Must set image file and has_postprocess flag, such as "
"./picodet_demo 0 ../picodet_s_320_lcnet.bin "
"../picodet_s_320_lcnet.param 320 320 img.jpg 0"
<< std::endl;
return -1;
}
const char *images = argv[6];
image_demo(detector, images, atoi(argv[7]));
}
}

View File

@@ -0,0 +1,236 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// reference from https://github.com/RangiLyu/nanodet/tree/main/demo_ncnn
#include "picodet.h"
#include <benchmark.h>
#include <algorithm>
#include <cmath>
#include <iostream>
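// Schraudolph-style fast exp approximation: builds the IEEE-754 bit pattern
// directly (1.4426950409 = log2(e), 126.9349... = biased exponent correction),
// trading a little accuracy for speed in the sigmoid/softmax below.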
inline float fast_exp(float x) {
union {
uint32_t i;
float f;
} v{};
v.i = (1 << 23) * (1.4426950409 * x + 126.93490512f);
return v.f;
}
inline float sigmoid(float x) { return 1.0f / (1.0f + fast_exp(-x)); }
template <typename _Tp>
int activation_function_softmax(const _Tp *src, _Tp *dst, int length) {
const _Tp alpha = *std::max_element(src, src + length);
_Tp denominator{0};
for (int i = 0; i < length; ++i) {
dst[i] = fast_exp(src[i] - alpha);
denominator += dst[i];
}
for (int i = 0; i < length; ++i) {
dst[i] /= denominator;
}
return 0;
}
bool PicoDet::hasGPU = false;
PicoDet *PicoDet::detector = nullptr;
PicoDet::PicoDet(const char *param, const char *bin, int input_width,
int input_height, bool useGPU, float score_threshold_ = 0.5,
float nms_threshold_ = 0.3) {
this->Net = new ncnn::Net();
#if NCNN_VULKAN
this->hasGPU = ncnn::get_gpu_count() > 0;
#endif
this->Net->opt.use_vulkan_compute = this->hasGPU && useGPU;
this->Net->opt.use_fp16_arithmetic = true;
this->Net->load_param(param);
this->Net->load_model(bin);
this->in_w = input_width;
this->in_h = input_height;
this->score_threshold = score_threshold_;
this->nms_threshold = nms_threshold_;
}
PicoDet::~PicoDet() { delete this->Net; }
void PicoDet::preprocess(cv::Mat &image, ncnn::Mat &in) {
// cv::resize(image, image, cv::Size(this->in_w, this->in_h), 0.f, 0.f);
int img_w = image.cols;
int img_h = image.rows;
in = ncnn::Mat::from_pixels_resize(image.data, ncnn::Mat::PIXEL_BGR, img_w,
img_h, this->in_w, this->in_h);
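// Normalize with ImageNet statistics in BGR order:
// mean = 255 * {0.406, 0.456, 0.485}, scale = 1 / (255 * std).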
const float mean_vals[3] = {103.53f, 116.28f, 123.675f};
const float norm_vals[3] = {0.017429f, 0.017507f, 0.017125f};
in.substract_mean_normalize(mean_vals, norm_vals);
}
int PicoDet::detect(cv::Mat image, std::vector<BoxInfo> &result_list,
bool has_postprocess) {
ncnn::Mat input;
preprocess(image, input);
auto ex = this->Net->create_extractor();
ex.set_light_mode(false);
ex.set_num_threads(4);
#if NCNN_VULKAN
ex.set_vulkan_compute(this->hasGPU);
#endif
ex.input("image", input); // picodet
this->image_h = image.rows;
this->image_w = image.cols;
std::vector<std::vector<BoxInfo>> results;
results.resize(this->num_class);
if (has_postprocess) {
ncnn::Mat dis_pred;
ncnn::Mat cls_pred;
ex.extract(this->nms_heads_info[0].c_str(), dis_pred);
ex.extract(this->nms_heads_info[1].c_str(), cls_pred);
std::cout << dis_pred.h << " " << dis_pred.w << std::endl;
std::cout << cls_pred.h << " " << cls_pred.w << std::endl;
this->nms_boxes(cls_pred, dis_pred, this->score_threshold, results);
} else {
for (const auto &head_info : this->non_postprocess_heads_info) {
ncnn::Mat dis_pred;
ncnn::Mat cls_pred;
ex.extract(head_info.dis_layer.c_str(), dis_pred);
ex.extract(head_info.cls_layer.c_str(), cls_pred);
this->decode_infer(cls_pred, dis_pred, head_info.stride,
this->score_threshold, results);
}
}
for (int i = 0; i < (int)results.size(); i++) {
this->nms(results[i], this->nms_threshold);
for (auto box : results[i]) {
box.x1 = box.x1 / this->in_w * this->image_w;
box.x2 = box.x2 / this->in_w * this->image_w;
box.y1 = box.y1 / this->in_h * this->image_h;
box.y2 = box.y2 / this->in_h * this->image_h;
result_list.push_back(box);
}
}
return 0;
}
void PicoDet::nms_boxes(ncnn::Mat &cls_pred, ncnn::Mat &dis_pred,
float score_threshold,
std::vector<std::vector<BoxInfo>> &result_list) {
BoxInfo bbox;
int i, j;
for (i = 0; i < dis_pred.h; i++) {
bbox.x1 = dis_pred.row(i)[0];
bbox.y1 = dis_pred.row(i)[1];
bbox.x2 = dis_pred.row(i)[2];
bbox.y2 = dis_pred.row(i)[3];
const float *scores = cls_pred.row(i);
float score = 0;
int cur_label = 0;
for (int label = 0; label < this->num_class; label++) {
float score_ = cls_pred.row(label)[i];
if (score_ > score) {
score = score_;
cur_label = label;
}
}
bbox.score = score;
bbox.label = cur_label;
result_list[cur_label].push_back(bbox);
}
}
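// Decode the raw head outputs of one stride level: for every feature-map cell,
// pick the best-scoring class and, if it clears the threshold, decode its
// distribution-focal-loss (DFL) box prediction.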
void PicoDet::decode_infer(ncnn::Mat &cls_pred, ncnn::Mat &dis_pred, int stride,
float threshold,
std::vector<std::vector<BoxInfo>> &results) {
int feature_h = ceil((float)this->in_h / stride);
int feature_w = ceil((float)this->in_w / stride);
for (int idx = 0; idx < feature_h * feature_w; idx++) {
const float *scores = cls_pred.row(idx);
int row = idx / feature_w;
int col = idx % feature_w;
float score = 0;
int cur_label = 0;
for (int label = 0; label < this->num_class; label++) {
if (scores[label] > score) {
score = scores[label];
cur_label = label;
}
}
if (score > threshold) {
const float *bbox_pred = dis_pred.row(idx);
results[cur_label].push_back(
this->disPred2Bbox(bbox_pred, cur_label, score, col, row, stride));
}
}
}
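// Turn one DFL prediction into box edges: each side's distance from the cell
// center is the expectation over (reg_max + 1) softmax bins, scaled by the
// feature-map stride.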
BoxInfo PicoDet::disPred2Bbox(const float *&dfl_det, int label, float score,
int x, int y, int stride) {
float ct_x = (x + 0.5) * stride;
float ct_y = (y + 0.5) * stride;
std::vector<float> dis_pred;
dis_pred.resize(4);
for (int i = 0; i < 4; i++) {
float dis = 0;
float *dis_after_sm = new float[this->reg_max + 1];
activation_function_softmax(dfl_det + i * (this->reg_max + 1), dis_after_sm,
this->reg_max + 1);
for (int j = 0; j < this->reg_max + 1; j++) {
dis += j * dis_after_sm[j];
}
dis *= stride;
dis_pred[i] = dis;
delete[] dis_after_sm;
}
float xmin = (std::max)(ct_x - dis_pred[0], .0f);
float ymin = (std::max)(ct_y - dis_pred[1], .0f);
float xmax = (std::min)(ct_x + dis_pred[2], (float)this->in_w);
float ymax = (std::min)(ct_y + dis_pred[3], (float)this->in_h);
return BoxInfo{xmin, ymin, xmax, ymax, score, label};
}
void PicoDet::nms(std::vector<BoxInfo> &input_boxes, float NMS_THRESH) {
std::sort(input_boxes.begin(), input_boxes.end(),
[](BoxInfo a, BoxInfo b) { return a.score > b.score; });
std::vector<float> vArea(input_boxes.size());
for (int i = 0; i < int(input_boxes.size()); ++i) {
vArea[i] = (input_boxes.at(i).x2 - input_boxes.at(i).x1 + 1) *
(input_boxes.at(i).y2 - input_boxes.at(i).y1 + 1);
}
for (int i = 0; i < int(input_boxes.size()); ++i) {
for (int j = i + 1; j < int(input_boxes.size());) {
float xx1 = (std::max)(input_boxes[i].x1, input_boxes[j].x1);
float yy1 = (std::max)(input_boxes[i].y1, input_boxes[j].y1);
float xx2 = (std::min)(input_boxes[i].x2, input_boxes[j].x2);
float yy2 = (std::min)(input_boxes[i].y2, input_boxes[j].y2);
float w = (std::max)(float(0), xx2 - xx1 + 1);
float h = (std::max)(float(0), yy2 - yy1 + 1);
float inter = w * h;
float ovr = inter / (vArea[i] + vArea[j] - inter);
if (ovr >= NMS_THRESH) {
input_boxes.erase(input_boxes.begin() + j);
vArea.erase(vArea.begin() + j);
} else {
j++;
}
}
}
}

View File

@@ -0,0 +1,87 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// reference from https://github.com/RangiLyu/nanodet/tree/main/demo_ncnn
#ifndef PICODET_H
#define PICODET_H
#include <net.h>
#include <opencv2/core/core.hpp>
typedef struct NonPostProcessHeadInfo {
std::string cls_layer;
std::string dis_layer;
int stride;
} NonPostProcessHeadInfo;
typedef struct BoxInfo {
float x1;
float y1;
float x2;
float y2;
float score;
int label;
} BoxInfo;
class PicoDet {
public:
PicoDet(const char *param, const char *bin, int input_width, int input_height,
bool useGPU, float score_threshold_, float nms_threshold_);
~PicoDet();
static PicoDet *detector;
ncnn::Net *Net;
static bool hasGPU;
int detect(cv::Mat image, std::vector<BoxInfo> &result_list,
bool has_postprocess);
private:
void preprocess(cv::Mat &image, ncnn::Mat &in);
void decode_infer(ncnn::Mat &cls_pred, ncnn::Mat &dis_pred, int stride,
float threshold,
std::vector<std::vector<BoxInfo>> &results);
BoxInfo disPred2Bbox(const float *&dfl_det, int label, float score, int x,
int y, int stride);
static void nms(std::vector<BoxInfo> &result, float nms_threshold);
void nms_boxes(ncnn::Mat &cls_pred, ncnn::Mat &dis_pred,
float score_threshold,
std::vector<std::vector<BoxInfo>> &result_list);
int image_w;
int image_h;
int in_w = 320;
int in_h = 320;
int num_class = 80;
int reg_max = 7;
float score_threshold;
float nms_threshold;
std::vector<float> bbox_output_data_;
std::vector<float> class_output_data_;
std::vector<std::string> nms_heads_info{"tmp_16", "concat_4.tmp_0"};
// If not export post-process, will use non_postprocess_heads_info
std::vector<NonPostProcessHeadInfo> non_postprocess_heads_info{
// cls_pred|dis_pred|stride
{"transpose_0.tmp_0", "transpose_1.tmp_0", 8},
{"transpose_2.tmp_0", "transpose_3.tmp_0", 16},
{"transpose_4.tmp_0", "transpose_5.tmp_0", 32},
{"transpose_6.tmp_0", "transpose_7.tmp_0", 64},
};
};
#endif