更换文档检测模型
This commit is contained in:
35
paddle_detection/ppdet/ext_op/README.md
Normal file
35
paddle_detection/ppdet/ext_op/README.md
Normal file
@@ -0,0 +1,35 @@
|
||||
# 自定义OP编译
|
||||
旋转框IOU计算OP是参考[自定义外部算子](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/custom_op/new_cpp_op_cn.html) 。
|
||||
|
||||
## 1. 环境依赖
|
||||
- Paddle >= 2.0.1
|
||||
- gcc 8.2
|
||||
|
||||
## 2. 安装
|
||||
```
|
||||
python setup.py install
|
||||
```
|
||||
|
||||
编译完成后即可使用,以下为`rbox_iou`的使用示例
|
||||
```
|
||||
# 引入自定义op
|
||||
from ext_op import rbox_iou
|
||||
|
||||
paddle.set_device('gpu:0')
|
||||
paddle.disable_static()
|
||||
|
||||
rbox1 = np.random.rand(13000, 5)
|
||||
rbox2 = np.random.rand(7, 5)
|
||||
|
||||
pd_rbox1 = paddle.to_tensor(rbox1)
|
||||
pd_rbox2 = paddle.to_tensor(rbox2)
|
||||
|
||||
iou = rbox_iou(pd_rbox1, pd_rbox2)
|
||||
print('iou', iou)
|
||||
```
|
||||
|
||||
## 3. 单元测试
|
||||
可以通过执行单元测试来确认自定义算子功能的正确性,执行单元测试的示例如下所示:
|
||||
```
|
||||
python unittest/test_matched_rbox_iou.py
|
||||
```
|
||||
@@ -0,0 +1,91 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// The code is based on
|
||||
// https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/csrc/box_iou_rotated/
|
||||
|
||||
#include "../rbox_iou/rbox_iou_utils.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
template <typename T>
|
||||
void matched_rbox_iou_cpu_kernel(const int rbox_num, const T *rbox1_data_ptr,
|
||||
const T *rbox2_data_ptr, T *output_data_ptr) {
|
||||
|
||||
int i;
|
||||
for (i = 0; i < rbox_num; i++) {
|
||||
output_data_ptr[i] =
|
||||
rbox_iou_single<T>(rbox1_data_ptr + i * 5, rbox2_data_ptr + i * 5);
|
||||
}
|
||||
}
|
||||
|
||||
#define CHECK_INPUT_CPU(x) \
|
||||
PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.")
|
||||
|
||||
std::vector<paddle::Tensor>
|
||||
MatchedRboxIouCPUForward(const paddle::Tensor &rbox1,
|
||||
const paddle::Tensor &rbox2) {
|
||||
CHECK_INPUT_CPU(rbox1);
|
||||
CHECK_INPUT_CPU(rbox2);
|
||||
PD_CHECK(rbox1.shape()[0] == rbox2.shape()[0], "inputs must be same dim");
|
||||
|
||||
auto rbox_num = rbox1.shape()[0];
|
||||
auto output = paddle::empty({rbox_num}, rbox1.dtype(), paddle::CPUPlace());
|
||||
|
||||
PD_DISPATCH_FLOATING_TYPES(rbox1.type(), "matched_rbox_iou_cpu_kernel", ([&] {
|
||||
matched_rbox_iou_cpu_kernel<data_t>(
|
||||
rbox_num, rbox1.data<data_t>(),
|
||||
rbox2.data<data_t>(), output.data<data_t>());
|
||||
}));
|
||||
|
||||
return {output};
|
||||
}
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
std::vector<paddle::Tensor>
|
||||
MatchedRboxIouCUDAForward(const paddle::Tensor &rbox1,
|
||||
const paddle::Tensor &rbox2);
|
||||
#endif
|
||||
|
||||
#define CHECK_INPUT_SAME(x1, x2) \
|
||||
PD_CHECK(x1.place() == x2.place(), "input must be smae pacle.")
|
||||
|
||||
std::vector<paddle::Tensor> MatchedRboxIouForward(const paddle::Tensor &rbox1,
|
||||
const paddle::Tensor &rbox2) {
|
||||
CHECK_INPUT_SAME(rbox1, rbox2);
|
||||
if (rbox1.is_cpu()) {
|
||||
return MatchedRboxIouCPUForward(rbox1, rbox2);
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
} else if (rbox1.is_gpu()) {
|
||||
return MatchedRboxIouCUDAForward(rbox1, rbox2);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>>
|
||||
MatchedRboxIouInferShape(std::vector<int64_t> rbox1_shape,
|
||||
std::vector<int64_t> rbox2_shape) {
|
||||
return {{rbox1_shape[0]}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> MatchedRboxIouInferDtype(paddle::DataType t1,
|
||||
paddle::DataType t2) {
|
||||
return {t1};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(matched_rbox_iou)
|
||||
.Inputs({"RBOX1", "RBOX2"})
|
||||
.Outputs({"Output"})
|
||||
.SetKernelFn(PD_KERNEL(MatchedRboxIouForward))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MatchedRboxIouInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(MatchedRboxIouInferDtype));
|
||||
@@ -0,0 +1,58 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// The code is based on
|
||||
// https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/csrc/box_iou_rotated/
|
||||
|
||||
#include "../rbox_iou/rbox_iou_utils.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
template <typename T>
|
||||
__global__ void
|
||||
matched_rbox_iou_cuda_kernel(const int rbox_num, const T *rbox1_data_ptr,
|
||||
const T *rbox2_data_ptr, T *output_data_ptr) {
|
||||
for (int tid = blockIdx.x * blockDim.x + threadIdx.x; tid < rbox_num;
|
||||
tid += blockDim.x * gridDim.x) {
|
||||
output_data_ptr[tid] =
|
||||
rbox_iou_single<T>(rbox1_data_ptr + tid * 5, rbox2_data_ptr + tid * 5);
|
||||
}
|
||||
}
|
||||
|
||||
#define CHECK_INPUT_GPU(x) \
|
||||
PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
|
||||
|
||||
std::vector<paddle::Tensor>
|
||||
MatchedRboxIouCUDAForward(const paddle::Tensor &rbox1,
|
||||
const paddle::Tensor &rbox2) {
|
||||
CHECK_INPUT_GPU(rbox1);
|
||||
CHECK_INPUT_GPU(rbox2);
|
||||
PD_CHECK(rbox1.shape()[0] == rbox2.shape()[0], "inputs must be same dim");
|
||||
|
||||
auto rbox_num = rbox1.shape()[0];
|
||||
|
||||
auto output = paddle::empty({rbox_num}, rbox1.dtype(), paddle::GPUPlace());
|
||||
|
||||
const int thread_per_block = 512;
|
||||
const int block_per_grid = CeilDiv(rbox_num, thread_per_block);
|
||||
|
||||
PD_DISPATCH_FLOATING_TYPES(
|
||||
rbox1.type(), "matched_rbox_iou_cuda_kernel", ([&] {
|
||||
matched_rbox_iou_cuda_kernel<
|
||||
data_t><<<block_per_grid, thread_per_block, 0, rbox1.stream()>>>(
|
||||
rbox_num, rbox1.data<data_t>(), rbox2.data<data_t>(),
|
||||
output.data<data_t>());
|
||||
}));
|
||||
|
||||
return {output};
|
||||
}
|
||||
121
paddle_detection/ppdet/ext_op/csrc/nms_rotated/nms_rotated.cc
Normal file
121
paddle_detection/ppdet/ext_op/csrc/nms_rotated/nms_rotated.cc
Normal file
@@ -0,0 +1,121 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "../rbox_iou/rbox_iou_utils.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
template <typename T>
|
||||
void nms_rotated_cpu_kernel(const T *boxes_data, const float threshold,
|
||||
const int64_t num_boxes, int64_t *num_keep_boxes,
|
||||
int64_t *output_data) {
|
||||
|
||||
int num_masks = CeilDiv(num_boxes, 64);
|
||||
std::vector<int64_t> masks(num_masks, 0);
|
||||
for (int64_t i = 0; i < num_boxes; ++i) {
|
||||
if (masks[i / 64] & 1ULL << (i % 64))
|
||||
continue;
|
||||
T box_1[5];
|
||||
for (int k = 0; k < 5; ++k) {
|
||||
box_1[k] = boxes_data[i * 5 + k];
|
||||
}
|
||||
for (int64_t j = i + 1; j < num_boxes; ++j) {
|
||||
if (masks[j / 64] & 1ULL << (j % 64))
|
||||
continue;
|
||||
T box_2[5];
|
||||
for (int k = 0; k < 5; ++k) {
|
||||
box_2[k] = boxes_data[j * 5 + k];
|
||||
}
|
||||
if (rbox_iou_single<T>(box_1, box_2) > threshold) {
|
||||
masks[j / 64] |= 1ULL << (j % 64);
|
||||
}
|
||||
}
|
||||
}
|
||||
int64_t output_data_idx = 0;
|
||||
for (int64_t i = 0; i < num_boxes; ++i) {
|
||||
if (masks[i / 64] & 1ULL << (i % 64))
|
||||
continue;
|
||||
output_data[output_data_idx++] = i;
|
||||
}
|
||||
*num_keep_boxes = output_data_idx;
|
||||
for (; output_data_idx < num_boxes; ++output_data_idx) {
|
||||
output_data[output_data_idx] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
#define CHECK_INPUT_CPU(x) \
|
||||
PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.")
|
||||
|
||||
std::vector<paddle::Tensor> NMSRotatedCPUForward(const paddle::Tensor &boxes,
|
||||
const paddle::Tensor &scores,
|
||||
float threshold) {
|
||||
CHECK_INPUT_CPU(boxes);
|
||||
CHECK_INPUT_CPU(scores);
|
||||
|
||||
auto num_boxes = boxes.shape()[0];
|
||||
|
||||
auto order_t =
|
||||
std::get<1>(paddle::argsort(scores, /* axis=*/0, /* descending=*/true));
|
||||
auto boxes_sorted = paddle::gather(boxes, order_t, /* axis=*/0);
|
||||
|
||||
auto keep =
|
||||
paddle::empty({num_boxes}, paddle::DataType::INT64, paddle::CPUPlace());
|
||||
int64_t num_keep_boxes = 0;
|
||||
|
||||
PD_DISPATCH_FLOATING_TYPES(boxes.type(), "nms_rotated_cpu_kernel", ([&] {
|
||||
nms_rotated_cpu_kernel<data_t>(
|
||||
boxes_sorted.data<data_t>(), threshold,
|
||||
num_boxes, &num_keep_boxes,
|
||||
keep.data<int64_t>());
|
||||
}));
|
||||
|
||||
keep = keep.slice(0, num_keep_boxes);
|
||||
return {paddle::gather(order_t, keep, /* axis=*/0)};
|
||||
}
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
std::vector<paddle::Tensor> NMSRotatedCUDAForward(const paddle::Tensor &boxes,
|
||||
const paddle::Tensor &scores,
|
||||
float threshold);
|
||||
#endif
|
||||
|
||||
std::vector<paddle::Tensor> NMSRotatedForward(const paddle::Tensor &boxes,
|
||||
const paddle::Tensor &scores,
|
||||
float threshold) {
|
||||
if (boxes.is_cpu()) {
|
||||
return NMSRotatedCPUForward(boxes, scores, threshold);
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
} else if (boxes.is_gpu()) {
|
||||
return NMSRotatedCUDAForward(boxes, scores, threshold);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>>
|
||||
NMSRotatedInferShape(std::vector<int64_t> boxes_shape,
|
||||
std::vector<int64_t> scores_shape) {
|
||||
return {{-1}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> NMSRotatedInferDtype(paddle::DataType t1,
|
||||
paddle::DataType t2) {
|
||||
return {paddle::DataType::INT64};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(nms_rotated)
|
||||
.Inputs({"Boxes", "Scores"})
|
||||
.Outputs({"Output"})
|
||||
.Attrs({"threshold: float"})
|
||||
.SetKernelFn(PD_KERNEL(NMSRotatedForward))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(NMSRotatedInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(NMSRotatedInferDtype));
|
||||
@@ -0,0 +1,96 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "../rbox_iou/rbox_iou_utils.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
static const int64_t threadsPerBlock = sizeof(int64_t) * 8;
|
||||
|
||||
template <typename T>
|
||||
__global__ void
|
||||
nms_rotated_cuda_kernel(const T *boxes_data, const float threshold,
|
||||
const int64_t num_boxes, int64_t *masks) {
|
||||
auto raw_start = blockIdx.y;
|
||||
auto col_start = blockIdx.x;
|
||||
if (raw_start > col_start)
|
||||
return;
|
||||
const int raw_last_storage =
|
||||
min(num_boxes - raw_start * threadsPerBlock, threadsPerBlock);
|
||||
const int col_last_storage =
|
||||
min(num_boxes - col_start * threadsPerBlock, threadsPerBlock);
|
||||
if (threadIdx.x < raw_last_storage) {
|
||||
int64_t mask = 0;
|
||||
auto current_box_idx = raw_start * threadsPerBlock + threadIdx.x;
|
||||
const T *current_box = boxes_data + current_box_idx * 5;
|
||||
for (int i = 0; i < col_last_storage; ++i) {
|
||||
const T *target_box = boxes_data + (col_start * threadsPerBlock + i) * 5;
|
||||
if (rbox_iou_single<T>(current_box, target_box) > threshold) {
|
||||
mask |= 1ULL << i;
|
||||
}
|
||||
}
|
||||
const int blocks_per_line = CeilDiv(num_boxes, threadsPerBlock);
|
||||
masks[current_box_idx * blocks_per_line + col_start] = mask;
|
||||
}
|
||||
}
|
||||
|
||||
#define CHECK_INPUT_GPU(x) \
|
||||
PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
|
||||
|
||||
std::vector<paddle::Tensor> NMSRotatedCUDAForward(const paddle::Tensor &boxes,
|
||||
const paddle::Tensor &scores,
|
||||
float threshold) {
|
||||
CHECK_INPUT_GPU(boxes);
|
||||
CHECK_INPUT_GPU(scores);
|
||||
|
||||
auto num_boxes = boxes.shape()[0];
|
||||
auto order_t =
|
||||
std::get<1>(paddle::argsort(scores, /* axis=*/0, /* descending=*/true));
|
||||
auto boxes_sorted = paddle::gather(boxes, order_t, /* axis=*/0);
|
||||
|
||||
const auto blocks_per_line = CeilDiv(num_boxes, threadsPerBlock);
|
||||
dim3 block(threadsPerBlock);
|
||||
dim3 grid(blocks_per_line, blocks_per_line);
|
||||
auto mask_dev = paddle::empty({num_boxes * blocks_per_line},
|
||||
paddle::DataType::INT64, paddle::GPUPlace());
|
||||
|
||||
PD_DISPATCH_FLOATING_TYPES(
|
||||
boxes.type(), "nms_rotated_cuda_kernel", ([&] {
|
||||
nms_rotated_cuda_kernel<data_t><<<grid, block, 0, boxes.stream()>>>(
|
||||
boxes_sorted.data<data_t>(), threshold, num_boxes,
|
||||
mask_dev.data<int64_t>());
|
||||
}));
|
||||
|
||||
auto mask_host = mask_dev.copy_to(paddle::CPUPlace(), true);
|
||||
auto keep_host =
|
||||
paddle::empty({num_boxes}, paddle::DataType::INT64, paddle::CPUPlace());
|
||||
int64_t *keep_host_ptr = keep_host.data<int64_t>();
|
||||
int64_t *mask_host_ptr = mask_host.data<int64_t>();
|
||||
std::vector<int64_t> remv(blocks_per_line);
|
||||
int64_t last_box_num = 0;
|
||||
for (int64_t i = 0; i < num_boxes; ++i) {
|
||||
auto remv_element_id = i / threadsPerBlock;
|
||||
auto remv_bit_id = i % threadsPerBlock;
|
||||
if (!(remv[remv_element_id] & 1ULL << remv_bit_id)) {
|
||||
keep_host_ptr[last_box_num++] = i;
|
||||
int64_t *current_mask = mask_host_ptr + i * blocks_per_line;
|
||||
for (auto j = remv_element_id; j < blocks_per_line; ++j) {
|
||||
remv[j] |= current_mask[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
keep_host = keep_host.slice(0, last_box_num);
|
||||
auto keep_dev = keep_host.copy_to(paddle::GPUPlace(), true);
|
||||
return {paddle::gather(order_t, keep_dev, /* axis=*/0)};
|
||||
}
|
||||
95
paddle_detection/ppdet/ext_op/csrc/rbox_iou/rbox_iou.cc
Normal file
95
paddle_detection/ppdet/ext_op/csrc/rbox_iou/rbox_iou.cc
Normal file
@@ -0,0 +1,95 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// The code is based on
|
||||
// https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/csrc/box_iou_rotated/
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "rbox_iou_utils.h"
|
||||
|
||||
template <typename T>
|
||||
void rbox_iou_cpu_kernel(const int rbox1_num, const int rbox2_num,
|
||||
const T *rbox1_data_ptr, const T *rbox2_data_ptr,
|
||||
T *output_data_ptr) {
|
||||
|
||||
int i, j;
|
||||
for (i = 0; i < rbox1_num; i++) {
|
||||
for (j = 0; j < rbox2_num; j++) {
|
||||
int offset = i * rbox2_num + j;
|
||||
output_data_ptr[offset] =
|
||||
rbox_iou_single<T>(rbox1_data_ptr + i * 5, rbox2_data_ptr + j * 5);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define CHECK_INPUT_CPU(x) \
|
||||
PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.")
|
||||
|
||||
std::vector<paddle::Tensor> RboxIouCPUForward(const paddle::Tensor &rbox1,
|
||||
const paddle::Tensor &rbox2) {
|
||||
CHECK_INPUT_CPU(rbox1);
|
||||
CHECK_INPUT_CPU(rbox2);
|
||||
|
||||
auto rbox1_num = rbox1.shape()[0];
|
||||
auto rbox2_num = rbox2.shape()[0];
|
||||
|
||||
auto output =
|
||||
paddle::empty({rbox1_num, rbox2_num}, rbox1.dtype(), paddle::CPUPlace());
|
||||
|
||||
PD_DISPATCH_FLOATING_TYPES(rbox1.type(), "rbox_iou_cpu_kernel", ([&] {
|
||||
rbox_iou_cpu_kernel<data_t>(
|
||||
rbox1_num, rbox2_num, rbox1.data<data_t>(),
|
||||
rbox2.data<data_t>(), output.data<data_t>());
|
||||
}));
|
||||
|
||||
return {output};
|
||||
}
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
std::vector<paddle::Tensor> RboxIouCUDAForward(const paddle::Tensor &rbox1,
|
||||
const paddle::Tensor &rbox2);
|
||||
#endif
|
||||
|
||||
#define CHECK_INPUT_SAME(x1, x2) \
|
||||
PD_CHECK(x1.place() == x2.place(), "input must be smae pacle.")
|
||||
|
||||
std::vector<paddle::Tensor> RboxIouForward(const paddle::Tensor &rbox1,
|
||||
const paddle::Tensor &rbox2) {
|
||||
CHECK_INPUT_SAME(rbox1, rbox2);
|
||||
if (rbox1.is_cpu()) {
|
||||
return RboxIouCPUForward(rbox1, rbox2);
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
} else if (rbox1.is_gpu()) {
|
||||
return RboxIouCUDAForward(rbox1, rbox2);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>>
|
||||
RboxIouInferShape(std::vector<int64_t> rbox1_shape,
|
||||
std::vector<int64_t> rbox2_shape) {
|
||||
return {{rbox1_shape[0], rbox2_shape[0]}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> RboxIouInferDtype(paddle::DataType t1,
|
||||
paddle::DataType t2) {
|
||||
return {t1};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(rbox_iou)
|
||||
.Inputs({"RBox1", "RBox2"})
|
||||
.Outputs({"Output"})
|
||||
.SetKernelFn(PD_KERNEL(RboxIouForward))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(RboxIouInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(RboxIouInferDtype));
|
||||
109
paddle_detection/ppdet/ext_op/csrc/rbox_iou/rbox_iou.cu
Normal file
109
paddle_detection/ppdet/ext_op/csrc/rbox_iou/rbox_iou.cu
Normal file
@@ -0,0 +1,109 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// The code is based on
|
||||
// https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/csrc/box_iou_rotated/
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "rbox_iou_utils.h"
|
||||
|
||||
// 2D block with 32 * 16 = 512 threads per block
|
||||
const int BLOCK_DIM_X = 32;
|
||||
const int BLOCK_DIM_Y = 16;
|
||||
|
||||
template <typename T>
|
||||
__global__ void rbox_iou_cuda_kernel(const int rbox1_num, const int rbox2_num,
|
||||
const T *rbox1_data_ptr,
|
||||
const T *rbox2_data_ptr,
|
||||
T *output_data_ptr) {
|
||||
|
||||
// get row_start and col_start
|
||||
const int rbox1_block_idx = blockIdx.x * blockDim.x;
|
||||
const int rbox2_block_idx = blockIdx.y * blockDim.y;
|
||||
|
||||
const int rbox1_thread_num = min(rbox1_num - rbox1_block_idx, blockDim.x);
|
||||
const int rbox2_thread_num = min(rbox2_num - rbox2_block_idx, blockDim.y);
|
||||
|
||||
__shared__ T block_boxes1[BLOCK_DIM_X * 5];
|
||||
__shared__ T block_boxes2[BLOCK_DIM_Y * 5];
|
||||
|
||||
// It's safe to copy using threadIdx.x since BLOCK_DIM_X >= BLOCK_DIM_Y
|
||||
if (threadIdx.x < rbox1_thread_num && threadIdx.y == 0) {
|
||||
block_boxes1[threadIdx.x * 5 + 0] =
|
||||
rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 0];
|
||||
block_boxes1[threadIdx.x * 5 + 1] =
|
||||
rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 1];
|
||||
block_boxes1[threadIdx.x * 5 + 2] =
|
||||
rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 2];
|
||||
block_boxes1[threadIdx.x * 5 + 3] =
|
||||
rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 3];
|
||||
block_boxes1[threadIdx.x * 5 + 4] =
|
||||
rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 4];
|
||||
}
|
||||
|
||||
// threadIdx.x < BLOCK_DIM_Y=rbox2_thread_num, just use same condition as
|
||||
// above: threadIdx.y == 0
|
||||
if (threadIdx.x < rbox2_thread_num && threadIdx.y == 0) {
|
||||
block_boxes2[threadIdx.x * 5 + 0] =
|
||||
rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 0];
|
||||
block_boxes2[threadIdx.x * 5 + 1] =
|
||||
rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 1];
|
||||
block_boxes2[threadIdx.x * 5 + 2] =
|
||||
rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 2];
|
||||
block_boxes2[threadIdx.x * 5 + 3] =
|
||||
rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 3];
|
||||
block_boxes2[threadIdx.x * 5 + 4] =
|
||||
rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 4];
|
||||
}
|
||||
|
||||
// sync
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < rbox1_thread_num && threadIdx.y < rbox2_thread_num) {
|
||||
int offset = (rbox1_block_idx + threadIdx.x) * rbox2_num + rbox2_block_idx +
|
||||
threadIdx.y;
|
||||
output_data_ptr[offset] = rbox_iou_single<T>(
|
||||
block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5);
|
||||
}
|
||||
}
|
||||
|
||||
#define CHECK_INPUT_GPU(x) \
|
||||
PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
|
||||
|
||||
std::vector<paddle::Tensor> RboxIouCUDAForward(const paddle::Tensor &rbox1,
|
||||
const paddle::Tensor &rbox2) {
|
||||
CHECK_INPUT_GPU(rbox1);
|
||||
CHECK_INPUT_GPU(rbox2);
|
||||
|
||||
auto rbox1_num = rbox1.shape()[0];
|
||||
auto rbox2_num = rbox2.shape()[0];
|
||||
|
||||
auto output =
|
||||
paddle::empty({rbox1_num, rbox2_num}, rbox1.dtype(), paddle::GPUPlace());
|
||||
|
||||
const int blocks_x = CeilDiv(rbox1_num, BLOCK_DIM_X);
|
||||
const int blocks_y = CeilDiv(rbox2_num, BLOCK_DIM_Y);
|
||||
|
||||
dim3 blocks(blocks_x, blocks_y);
|
||||
dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y);
|
||||
|
||||
PD_DISPATCH_FLOATING_TYPES(
|
||||
rbox1.type(), "rbox_iou_cuda_kernel", ([&] {
|
||||
rbox_iou_cuda_kernel<data_t><<<blocks, threads, 0, rbox1.stream()>>>(
|
||||
rbox1_num, rbox2_num, rbox1.data<data_t>(), rbox2.data<data_t>(),
|
||||
output.data<data_t>());
|
||||
}));
|
||||
|
||||
return {output};
|
||||
}
|
||||
356
paddle_detection/ppdet/ext_op/csrc/rbox_iou/rbox_iou_utils.h
Normal file
356
paddle_detection/ppdet/ext_op/csrc/rbox_iou/rbox_iou_utils.h
Normal file
@@ -0,0 +1,356 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// The code is based on
|
||||
// https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/csrc/box_iou_rotated/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#ifdef __CUDACC__
|
||||
// Designates functions callable from the host (CPU) and the device (GPU)
|
||||
#define HOST_DEVICE __host__ __device__
|
||||
#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
|
||||
#else
|
||||
#include <algorithm>
|
||||
#define HOST_DEVICE
|
||||
#define HOST_DEVICE_INLINE HOST_DEVICE inline
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T> struct RotatedBox { T x_ctr, y_ctr, w, h, a; };
|
||||
|
||||
template <typename T> struct Point {
|
||||
T x, y;
|
||||
HOST_DEVICE_INLINE Point(const T &px = 0, const T &py = 0) : x(px), y(py) {}
|
||||
HOST_DEVICE_INLINE Point operator+(const Point &p) const {
|
||||
return Point(x + p.x, y + p.y);
|
||||
}
|
||||
HOST_DEVICE_INLINE Point &operator+=(const Point &p) {
|
||||
x += p.x;
|
||||
y += p.y;
|
||||
return *this;
|
||||
}
|
||||
HOST_DEVICE_INLINE Point operator-(const Point &p) const {
|
||||
return Point(x - p.x, y - p.y);
|
||||
}
|
||||
HOST_DEVICE_INLINE Point operator*(const T coeff) const {
|
||||
return Point(x * coeff, y * coeff);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE_INLINE T dot_2d(const Point<T> &A, const Point<T> &B) {
|
||||
return A.x * B.x + A.y * B.y;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE_INLINE T cross_2d(const Point<T> &A, const Point<T> &B) {
|
||||
return A.x * B.y - B.x * A.y;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE_INLINE void get_rotated_vertices(const RotatedBox<T> &box,
|
||||
Point<T> (&pts)[4]) {
|
||||
// M_PI / 180. == 0.01745329251
|
||||
// double theta = box.a * 0.01745329251;
|
||||
// MODIFIED
|
||||
double theta = box.a;
|
||||
T cosTheta2 = (T)cos(theta) * 0.5f;
|
||||
T sinTheta2 = (T)sin(theta) * 0.5f;
|
||||
|
||||
// y: top --> down; x: left --> right
|
||||
pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w;
|
||||
pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w;
|
||||
pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w;
|
||||
pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w;
|
||||
pts[2].x = 2 * box.x_ctr - pts[0].x;
|
||||
pts[2].y = 2 * box.y_ctr - pts[0].y;
|
||||
pts[3].x = 2 * box.x_ctr - pts[1].x;
|
||||
pts[3].y = 2 * box.y_ctr - pts[1].y;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE_INLINE int get_intersection_points(const Point<T> (&pts1)[4],
|
||||
const Point<T> (&pts2)[4],
|
||||
Point<T> (&intersections)[24]) {
|
||||
// Line vector
|
||||
// A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
|
||||
Point<T> vec1[4], vec2[4];
|
||||
for (int i = 0; i < 4; i++) {
|
||||
vec1[i] = pts1[(i + 1) % 4] - pts1[i];
|
||||
vec2[i] = pts2[(i + 1) % 4] - pts2[i];
|
||||
}
|
||||
|
||||
// Line test - test all line combos for intersection
|
||||
int num = 0; // number of intersections
|
||||
for (int i = 0; i < 4; i++) {
|
||||
for (int j = 0; j < 4; j++) {
|
||||
// Solve for 2x2 Ax=b
|
||||
T det = cross_2d<T>(vec2[j], vec1[i]);
|
||||
|
||||
// This takes care of parallel lines
|
||||
if (fabs(det) <= 1e-14) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto vec12 = pts2[j] - pts1[i];
|
||||
|
||||
T t1 = cross_2d<T>(vec2[j], vec12) / det;
|
||||
T t2 = cross_2d<T>(vec1[i], vec12) / det;
|
||||
|
||||
if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) {
|
||||
intersections[num++] = pts1[i] + vec1[i] * t1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for vertices of rect1 inside rect2
|
||||
{
|
||||
const auto &AB = vec2[0];
|
||||
const auto &DA = vec2[3];
|
||||
auto ABdotAB = dot_2d<T>(AB, AB);
|
||||
auto ADdotAD = dot_2d<T>(DA, DA);
|
||||
for (int i = 0; i < 4; i++) {
|
||||
// assume ABCD is the rectangle, and P is the point to be judged
|
||||
// P is inside ABCD iff. P's projection on AB lies within AB
|
||||
// and P's projection on AD lies within AD
|
||||
|
||||
auto AP = pts1[i] - pts2[0];
|
||||
|
||||
auto APdotAB = dot_2d<T>(AP, AB);
|
||||
auto APdotAD = -dot_2d<T>(AP, DA);
|
||||
|
||||
if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
|
||||
(APdotAD <= ADdotAD)) {
|
||||
intersections[num++] = pts1[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reverse the check - check for vertices of rect2 inside rect1
|
||||
{
|
||||
const auto &AB = vec1[0];
|
||||
const auto &DA = vec1[3];
|
||||
auto ABdotAB = dot_2d<T>(AB, AB);
|
||||
auto ADdotAD = dot_2d<T>(DA, DA);
|
||||
for (int i = 0; i < 4; i++) {
|
||||
auto AP = pts2[i] - pts1[0];
|
||||
|
||||
auto APdotAB = dot_2d<T>(AP, AB);
|
||||
auto APdotAD = -dot_2d<T>(AP, DA);
|
||||
|
||||
if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
|
||||
(APdotAD <= ADdotAD)) {
|
||||
intersections[num++] = pts2[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return num;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE_INLINE int convex_hull_graham(const Point<T> (&p)[24],
|
||||
const int &num_in, Point<T> (&q)[24],
|
||||
bool shift_to_zero = false) {
|
||||
assert(num_in >= 2);
|
||||
|
||||
// Step 1:
|
||||
// Find point with minimum y
|
||||
// if more than 1 points have the same minimum y,
|
||||
// pick the one with the minimum x.
|
||||
int t = 0;
|
||||
for (int i = 1; i < num_in; i++) {
|
||||
if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) {
|
||||
t = i;
|
||||
}
|
||||
}
|
||||
auto &start = p[t]; // starting point
|
||||
|
||||
// Step 2:
|
||||
// Subtract starting point from every points (for sorting in the next step)
|
||||
for (int i = 0; i < num_in; i++) {
|
||||
q[i] = p[i] - start;
|
||||
}
|
||||
|
||||
// Swap the starting point to position 0
|
||||
auto tmp = q[0];
|
||||
q[0] = q[t];
|
||||
q[t] = tmp;
|
||||
|
||||
// Step 3:
|
||||
// Sort point 1 ~ num_in according to their relative cross-product values
|
||||
// (essentially sorting according to angles)
|
||||
// If the angles are the same, sort according to their distance to origin
|
||||
T dist[24];
|
||||
for (int i = 0; i < num_in; i++) {
|
||||
dist[i] = dot_2d<T>(q[i], q[i]);
|
||||
}
|
||||
|
||||
#ifdef __CUDACC__
|
||||
// CUDA version
|
||||
// In the future, we can potentially use thrust
|
||||
// for sorting here to improve speed (though not guaranteed)
|
||||
for (int i = 1; i < num_in - 1; i++) {
|
||||
for (int j = i + 1; j < num_in; j++) {
|
||||
T crossProduct = cross_2d<T>(q[i], q[j]);
|
||||
if ((crossProduct < -1e-6) ||
|
||||
(fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) {
|
||||
auto q_tmp = q[i];
|
||||
q[i] = q[j];
|
||||
q[j] = q_tmp;
|
||||
auto dist_tmp = dist[i];
|
||||
dist[i] = dist[j];
|
||||
dist[j] = dist_tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
// CPU version
|
||||
std::sort(q + 1, q + num_in,
|
||||
[](const Point<T> &A, const Point<T> &B) -> bool {
|
||||
T temp = cross_2d<T>(A, B);
|
||||
if (fabs(temp) < 1e-6) {
|
||||
return dot_2d<T>(A, A) < dot_2d<T>(B, B);
|
||||
} else {
|
||||
return temp > 0;
|
||||
}
|
||||
});
|
||||
#endif
|
||||
|
||||
// Step 4:
|
||||
// Make sure there are at least 2 points (that don't overlap with each other)
|
||||
// in the stack
|
||||
int k; // index of the non-overlapped second point
|
||||
for (k = 1; k < num_in; k++) {
|
||||
if (dist[k] > 1e-8) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (k == num_in) {
|
||||
// We reach the end, which means the convex hull is just one point
|
||||
q[0] = p[t];
|
||||
return 1;
|
||||
}
|
||||
q[1] = q[k];
|
||||
int m = 2; // 2 points in the stack
|
||||
// Step 5:
|
||||
// Finally we can start the scanning process.
|
||||
// When a non-convex relationship between the 3 points is found
|
||||
// (either concave shape or duplicated points),
|
||||
// we pop the previous point from the stack
|
||||
// until the 3-point relationship is convex again, or
|
||||
// until the stack only contains two points
|
||||
for (int i = k + 1; i < num_in; i++) {
|
||||
while (m > 1 && cross_2d<T>(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) {
|
||||
m--;
|
||||
}
|
||||
q[m++] = q[i];
|
||||
}
|
||||
|
||||
// Step 6 (Optional):
|
||||
// In general sense we need the original coordinates, so we
|
||||
// need to shift the points back (reverting Step 2)
|
||||
// But if we're only interested in getting the area/perimeter of the shape
|
||||
// We can simply return.
|
||||
if (!shift_to_zero) {
|
||||
for (int i = 0; i < m; i++) {
|
||||
q[i] += start;
|
||||
}
|
||||
}
|
||||
|
||||
return m;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int &m) {
|
||||
if (m <= 2) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
T area = 0;
|
||||
for (int i = 1; i < m - 1; i++) {
|
||||
area += fabs(cross_2d<T>(q[i] - q[0], q[i + 1] - q[0]));
|
||||
}
|
||||
|
||||
return area / 2.0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE_INLINE T rboxes_intersection(const RotatedBox<T> &box1,
|
||||
const RotatedBox<T> &box2) {
|
||||
// There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
|
||||
// from rotated_rect_intersection_pts
|
||||
Point<T> intersectPts[24], orderedPts[24];
|
||||
|
||||
Point<T> pts1[4];
|
||||
Point<T> pts2[4];
|
||||
get_rotated_vertices<T>(box1, pts1);
|
||||
get_rotated_vertices<T>(box2, pts2);
|
||||
|
||||
int num = get_intersection_points<T>(pts1, pts2, intersectPts);
|
||||
|
||||
if (num <= 2) {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Convex Hull to order the intersection points in clockwise order and find
|
||||
// the contour area.
|
||||
int num_convex = convex_hull_graham<T>(intersectPts, num, orderedPts, true);
|
||||
return polygon_area<T>(orderedPts, num_convex);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE_INLINE T rbox_iou_single(T const *const box1_raw,
|
||||
T const *const box2_raw) {
|
||||
// shift center to the middle point to achieve higher precision in result
|
||||
RotatedBox<T> box1, box2;
|
||||
auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0;
|
||||
auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0;
|
||||
box1.x_ctr = box1_raw[0] - center_shift_x;
|
||||
box1.y_ctr = box1_raw[1] - center_shift_y;
|
||||
box1.w = box1_raw[2];
|
||||
box1.h = box1_raw[3];
|
||||
box1.a = box1_raw[4];
|
||||
box2.x_ctr = box2_raw[0] - center_shift_x;
|
||||
box2.y_ctr = box2_raw[1] - center_shift_y;
|
||||
box2.w = box2_raw[2];
|
||||
box2.h = box2_raw[3];
|
||||
box2.a = box2_raw[4];
|
||||
|
||||
if (box1.w < 1e-2 || box1.h < 1e-2 || box2.w < 1e-2 || box2.h < 1e-2) {
|
||||
return 0.f;
|
||||
}
|
||||
const T area1 = box1.w * box1.h;
|
||||
const T area2 = box2.w * box2.h;
|
||||
|
||||
const T intersection = rboxes_intersection<T>(box1, box2);
|
||||
const T iou = intersection / (area1 + area2 - intersection);
|
||||
return iou;
|
||||
}
|
||||
|
||||
/**
|
||||
Computes ceil(a / b)
|
||||
*/
|
||||
|
||||
HOST_DEVICE inline int CeilDiv(const int a, const int b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
33
paddle_detection/ppdet/ext_op/setup.py
Normal file
33
paddle_detection/ppdet/ext_op/setup.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import os
|
||||
import glob
|
||||
import paddle
|
||||
from paddle.utils.cpp_extension import CppExtension, CUDAExtension, setup
|
||||
|
||||
|
||||
def get_extensions():
|
||||
root_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
ext_root_dir = os.path.join(root_dir, 'csrc')
|
||||
sources = []
|
||||
for ext_name in os.listdir(ext_root_dir):
|
||||
ext_dir = os.path.join(ext_root_dir, ext_name)
|
||||
source = glob.glob(os.path.join(ext_dir, '*.cc'))
|
||||
kwargs = dict()
|
||||
if paddle.device.is_compiled_with_cuda():
|
||||
source += glob.glob(os.path.join(ext_dir, '*.cu'))
|
||||
|
||||
if not source:
|
||||
continue
|
||||
|
||||
sources += source
|
||||
|
||||
if paddle.device.is_compiled_with_cuda():
|
||||
extension = CUDAExtension(
|
||||
sources, extra_compile_args={'cxx': ['-DPADDLE_WITH_CUDA']})
|
||||
else:
|
||||
extension = CppExtension(sources)
|
||||
|
||||
return extension
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
setup(name='ext_op', ext_modules=get_extensions())
|
||||
149
paddle_detection/ppdet/ext_op/unittest/test_matched_rbox_iou.py
Normal file
149
paddle_detection/ppdet/ext_op/unittest/test_matched_rbox_iou.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import numpy as np
|
||||
import sys
|
||||
import time
|
||||
from shapely.geometry import Polygon
|
||||
import paddle
|
||||
import unittest
|
||||
|
||||
from ext_op import matched_rbox_iou
|
||||
|
||||
|
||||
def rbox2poly_single(rrect, get_best_begin_point=False):
|
||||
"""
|
||||
rrect:[x_ctr,y_ctr,w,h,angle]
|
||||
to
|
||||
poly:[x0,y0,x1,y1,x2,y2,x3,y3]
|
||||
"""
|
||||
x_ctr, y_ctr, width, height, angle = rrect[:5]
|
||||
tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2
|
||||
# rect 2x4
|
||||
rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]])
|
||||
R = np.array([[np.cos(angle), -np.sin(angle)],
|
||||
[np.sin(angle), np.cos(angle)]])
|
||||
# poly
|
||||
poly = R.dot(rect)
|
||||
x0, x1, x2, x3 = poly[0, :4] + x_ctr
|
||||
y0, y1, y2, y3 = poly[1, :4] + y_ctr
|
||||
poly = np.array([x0, y0, x1, y1, x2, y2, x3, y3], dtype=np.float64)
|
||||
return poly
|
||||
|
||||
|
||||
def intersection(g, p):
|
||||
"""
|
||||
Intersection.
|
||||
"""
|
||||
|
||||
g = g[:8].reshape((4, 2))
|
||||
p = p[:8].reshape((4, 2))
|
||||
|
||||
a = g
|
||||
b = p
|
||||
|
||||
use_filter = True
|
||||
if use_filter:
|
||||
# step1:
|
||||
inter_x1 = np.maximum(np.min(a[:, 0]), np.min(b[:, 0]))
|
||||
inter_x2 = np.minimum(np.max(a[:, 0]), np.max(b[:, 0]))
|
||||
inter_y1 = np.maximum(np.min(a[:, 1]), np.min(b[:, 1]))
|
||||
inter_y2 = np.minimum(np.max(a[:, 1]), np.max(b[:, 1]))
|
||||
if inter_x1 >= inter_x2 or inter_y1 >= inter_y2:
|
||||
return 0.
|
||||
x1 = np.minimum(np.min(a[:, 0]), np.min(b[:, 0]))
|
||||
x2 = np.maximum(np.max(a[:, 0]), np.max(b[:, 0]))
|
||||
y1 = np.minimum(np.min(a[:, 1]), np.min(b[:, 1]))
|
||||
y2 = np.maximum(np.max(a[:, 1]), np.max(b[:, 1]))
|
||||
if x1 >= x2 or y1 >= y2 or (x2 - x1) < 2 or (y2 - y1) < 2:
|
||||
return 0.
|
||||
|
||||
g = Polygon(g)
|
||||
p = Polygon(p)
|
||||
if not g.is_valid or not p.is_valid:
|
||||
return 0
|
||||
|
||||
inter = Polygon(g).intersection(Polygon(p)).area
|
||||
union = g.area + p.area - inter
|
||||
if union == 0:
|
||||
return 0
|
||||
else:
|
||||
return inter / union
|
||||
|
||||
|
||||
def matched_rbox_overlaps(anchors, gt_bboxes, use_cv2=False):
|
||||
"""
|
||||
|
||||
Args:
|
||||
anchors: [M, 5] x1,y1,x2,y2,angle
|
||||
gt_bboxes: [M, 5] x1,y1,x2,y2,angle
|
||||
|
||||
Returns:
|
||||
macthed_iou: [M]
|
||||
"""
|
||||
assert anchors.shape[1] == 5
|
||||
assert gt_bboxes.shape[1] == 5
|
||||
|
||||
gt_bboxes_ploy = [rbox2poly_single(e) for e in gt_bboxes]
|
||||
anchors_ploy = [rbox2poly_single(e) for e in anchors]
|
||||
|
||||
num = len(anchors_ploy)
|
||||
iou = np.zeros((num, ), dtype=np.float64)
|
||||
|
||||
start_time = time.time()
|
||||
for i in range(num):
|
||||
try:
|
||||
iou[i] = intersection(gt_bboxes_ploy[i], anchors_ploy[i])
|
||||
except Exception as e:
|
||||
print('cur gt_bboxes_ploy[i]', gt_bboxes_ploy[i],
|
||||
'anchors_ploy[j]', anchors_ploy[i], e)
|
||||
return iou
|
||||
|
||||
|
||||
def gen_sample(n):
|
||||
rbox = np.random.rand(n, 5)
|
||||
rbox[:, 0:4] = rbox[:, 0:4] * 0.45 + 0.001
|
||||
rbox[:, 4] = rbox[:, 4] - 0.5
|
||||
return rbox
|
||||
|
||||
|
||||
class MatchedRBoxIoUTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.initTestCase()
|
||||
self.rbox1 = gen_sample(self.n)
|
||||
self.rbox2 = gen_sample(self.n)
|
||||
|
||||
def initTestCase(self):
|
||||
self.n = 1000
|
||||
|
||||
def assertAllClose(self, x, y, msg, atol=5e-1, rtol=1e-2):
|
||||
self.assertTrue(np.allclose(x, y, atol=atol, rtol=rtol), msg=msg)
|
||||
|
||||
def get_places(self):
|
||||
places = [paddle.CPUPlace()]
|
||||
if paddle.device.is_compiled_with_cuda():
|
||||
places.append(paddle.CUDAPlace(0))
|
||||
|
||||
return places
|
||||
|
||||
def check_output(self, place):
|
||||
paddle.disable_static()
|
||||
pd_rbox1 = paddle.to_tensor(self.rbox1, place=place)
|
||||
pd_rbox2 = paddle.to_tensor(self.rbox2, place=place)
|
||||
actual_t = matched_rbox_iou(pd_rbox1, pd_rbox2).numpy()
|
||||
poly_rbox1 = self.rbox1
|
||||
poly_rbox2 = self.rbox2
|
||||
poly_rbox1[:, 0:4] = self.rbox1[:, 0:4] * 1024
|
||||
poly_rbox2[:, 0:4] = self.rbox2[:, 0:4] * 1024
|
||||
expect_t = matched_rbox_overlaps(poly_rbox1, poly_rbox2, use_cv2=False)
|
||||
self.assertAllClose(
|
||||
actual_t,
|
||||
expect_t,
|
||||
msg="rbox_iou has diff at {} \nExpect {}\nBut got {}".format(
|
||||
str(place), str(expect_t), str(actual_t)))
|
||||
|
||||
def test_output(self):
|
||||
places = self.get_places()
|
||||
for place in places:
|
||||
self.check_output(place)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
151
paddle_detection/ppdet/ext_op/unittest/test_rbox_iou.py
Normal file
151
paddle_detection/ppdet/ext_op/unittest/test_rbox_iou.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import numpy as np
|
||||
import sys
|
||||
import time
|
||||
from shapely.geometry import Polygon
|
||||
import paddle
|
||||
import unittest
|
||||
|
||||
from ext_op import rbox_iou
|
||||
|
||||
|
||||
def rbox2poly_single(rrect, get_best_begin_point=False):
|
||||
"""
|
||||
rrect:[x_ctr,y_ctr,w,h,angle]
|
||||
to
|
||||
poly:[x0,y0,x1,y1,x2,y2,x3,y3]
|
||||
"""
|
||||
x_ctr, y_ctr, width, height, angle = rrect[:5]
|
||||
tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2
|
||||
# rect 2x4
|
||||
rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]])
|
||||
R = np.array([[np.cos(angle), -np.sin(angle)],
|
||||
[np.sin(angle), np.cos(angle)]])
|
||||
# poly
|
||||
poly = R.dot(rect)
|
||||
x0, x1, x2, x3 = poly[0, :4] + x_ctr
|
||||
y0, y1, y2, y3 = poly[1, :4] + y_ctr
|
||||
poly = np.array([x0, y0, x1, y1, x2, y2, x3, y3], dtype=np.float64)
|
||||
return poly
|
||||
|
||||
|
||||
def intersection(g, p):
|
||||
"""
|
||||
Intersection.
|
||||
"""
|
||||
|
||||
g = g[:8].reshape((4, 2))
|
||||
p = p[:8].reshape((4, 2))
|
||||
|
||||
a = g
|
||||
b = p
|
||||
|
||||
use_filter = True
|
||||
if use_filter:
|
||||
# step1:
|
||||
inter_x1 = np.maximum(np.min(a[:, 0]), np.min(b[:, 0]))
|
||||
inter_x2 = np.minimum(np.max(a[:, 0]), np.max(b[:, 0]))
|
||||
inter_y1 = np.maximum(np.min(a[:, 1]), np.min(b[:, 1]))
|
||||
inter_y2 = np.minimum(np.max(a[:, 1]), np.max(b[:, 1]))
|
||||
if inter_x1 >= inter_x2 or inter_y1 >= inter_y2:
|
||||
return 0.
|
||||
x1 = np.minimum(np.min(a[:, 0]), np.min(b[:, 0]))
|
||||
x2 = np.maximum(np.max(a[:, 0]), np.max(b[:, 0]))
|
||||
y1 = np.minimum(np.min(a[:, 1]), np.min(b[:, 1]))
|
||||
y2 = np.maximum(np.max(a[:, 1]), np.max(b[:, 1]))
|
||||
if x1 >= x2 or y1 >= y2 or (x2 - x1) < 2 or (y2 - y1) < 2:
|
||||
return 0.
|
||||
|
||||
g = Polygon(g)
|
||||
p = Polygon(p)
|
||||
if not g.is_valid or not p.is_valid:
|
||||
return 0
|
||||
|
||||
inter = Polygon(g).intersection(Polygon(p)).area
|
||||
union = g.area + p.area - inter
|
||||
if union == 0:
|
||||
return 0
|
||||
else:
|
||||
return inter / union
|
||||
|
||||
|
||||
def rbox_overlaps(anchors, gt_bboxes, use_cv2=False):
|
||||
"""
|
||||
|
||||
Args:
|
||||
anchors: [NA, 5] x1,y1,x2,y2,angle
|
||||
gt_bboxes: [M, 5] x1,y1,x2,y2,angle
|
||||
|
||||
Returns:
|
||||
iou: [NA, M]
|
||||
"""
|
||||
assert anchors.shape[1] == 5
|
||||
assert gt_bboxes.shape[1] == 5
|
||||
|
||||
gt_bboxes_ploy = [rbox2poly_single(e) for e in gt_bboxes]
|
||||
anchors_ploy = [rbox2poly_single(e) for e in anchors]
|
||||
|
||||
num_gt, num_anchors = len(gt_bboxes_ploy), len(anchors_ploy)
|
||||
iou = np.zeros((num_anchors, num_gt), dtype=np.float64)
|
||||
|
||||
start_time = time.time()
|
||||
for i in range(num_anchors):
|
||||
for j in range(num_gt):
|
||||
try:
|
||||
iou[i, j] = intersection(anchors_ploy[i], gt_bboxes_ploy[j])
|
||||
except Exception as e:
|
||||
print('cur anchors_ploy[i]', anchors_ploy[i],
|
||||
'gt_bboxes_ploy[j]', gt_bboxes_ploy[j], e)
|
||||
return iou
|
||||
|
||||
|
||||
def gen_sample(n):
|
||||
rbox = np.random.rand(n, 5)
|
||||
rbox[:, 0:4] = rbox[:, 0:4] * 0.45 + 0.001
|
||||
rbox[:, 4] = rbox[:, 4] - 0.5
|
||||
return rbox
|
||||
|
||||
|
||||
class RBoxIoUTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.initTestCase()
|
||||
self.rbox1 = gen_sample(self.n)
|
||||
self.rbox2 = gen_sample(self.m)
|
||||
|
||||
def initTestCase(self):
|
||||
self.n = 13000
|
||||
self.m = 7
|
||||
|
||||
def assertAllClose(self, x, y, msg, atol=5e-1, rtol=1e-2):
|
||||
self.assertTrue(np.allclose(x, y, atol=atol, rtol=rtol), msg=msg)
|
||||
|
||||
def get_places(self):
|
||||
places = [paddle.CPUPlace()]
|
||||
if paddle.device.is_compiled_with_cuda():
|
||||
places.append(paddle.CUDAPlace(0))
|
||||
|
||||
return places
|
||||
|
||||
def check_output(self, place):
|
||||
paddle.disable_static()
|
||||
pd_rbox1 = paddle.to_tensor(self.rbox1, place=place)
|
||||
pd_rbox2 = paddle.to_tensor(self.rbox2, place=place)
|
||||
actual_t = rbox_iou(pd_rbox1, pd_rbox2).numpy()
|
||||
poly_rbox1 = self.rbox1
|
||||
poly_rbox2 = self.rbox2
|
||||
poly_rbox1[:, 0:4] = self.rbox1[:, 0:4] * 1024
|
||||
poly_rbox2[:, 0:4] = self.rbox2[:, 0:4] * 1024
|
||||
expect_t = rbox_overlaps(poly_rbox1, poly_rbox2, use_cv2=False)
|
||||
self.assertAllClose(
|
||||
actual_t,
|
||||
expect_t,
|
||||
msg="rbox_iou has diff at {} \nExpect {}\nBut got {}".format(
|
||||
str(place), str(expect_t), str(actual_t)))
|
||||
|
||||
def test_output(self):
|
||||
places = self.get_places()
|
||||
for place in places:
|
||||
self.check_output(place)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user