更换文档检测模型
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
from . import rpn_head
|
||||
from . import embedding_rpn_head
|
||||
|
||||
from .rpn_head import *
|
||||
from .embedding_rpn_head import *
|
||||
@@ -0,0 +1,266 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# The code is based on
|
||||
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/anchor_generator.py
|
||||
|
||||
import math
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from ppdet.core.workspace import register
|
||||
|
||||
__all__ = ['AnchorGenerator', 'RetinaAnchorGenerator', 'S2ANetAnchorGenerator']
|
||||
|
||||
|
||||
@register
|
||||
class AnchorGenerator(nn.Layer):
|
||||
"""
|
||||
Generate anchors according to the feature maps
|
||||
|
||||
Args:
|
||||
anchor_sizes (list[float] | list[list[float]]): The anchor sizes at
|
||||
each feature point. list[float] means all feature levels share the
|
||||
same sizes. list[list[float]] means the anchor sizes for
|
||||
each level. The sizes stand for the scale of input size.
|
||||
aspect_ratios (list[float] | list[list[float]]): The aspect ratios at
|
||||
each feature point. list[float] means all feature levels share the
|
||||
same ratios. list[list[float]] means the aspect ratios for
|
||||
each level.
|
||||
strides (list[float]): The strides of feature maps which generate
|
||||
anchors
|
||||
offset (float): The offset of the coordinate of anchors, default 0.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
anchor_sizes=[32, 64, 128, 256, 512],
|
||||
aspect_ratios=[0.5, 1.0, 2.0],
|
||||
strides=[16.0],
|
||||
variance=[1.0, 1.0, 1.0, 1.0],
|
||||
offset=0.):
|
||||
super(AnchorGenerator, self).__init__()
|
||||
self.anchor_sizes = anchor_sizes
|
||||
self.aspect_ratios = aspect_ratios
|
||||
self.strides = strides
|
||||
self.variance = variance
|
||||
self.cell_anchors = self._calculate_anchors(len(strides))
|
||||
self.offset = offset
|
||||
|
||||
def _broadcast_params(self, params, num_features):
|
||||
if not isinstance(params[0], (list, tuple)): # list[float]
|
||||
return [params] * num_features
|
||||
if len(params) == 1:
|
||||
return list(params) * num_features
|
||||
return params
|
||||
|
||||
def generate_cell_anchors(self, sizes, aspect_ratios):
|
||||
anchors = []
|
||||
for size in sizes:
|
||||
area = size**2.0
|
||||
for aspect_ratio in aspect_ratios:
|
||||
w = math.sqrt(area / aspect_ratio)
|
||||
h = aspect_ratio * w
|
||||
x0, y0, x1, y1 = -w / 2.0, -h / 2.0, w / 2.0, h / 2.0
|
||||
anchors.append([x0, y0, x1, y1])
|
||||
return paddle.to_tensor(anchors, dtype='float32')
|
||||
|
||||
def _calculate_anchors(self, num_features):
|
||||
sizes = self._broadcast_params(self.anchor_sizes, num_features)
|
||||
aspect_ratios = self._broadcast_params(self.aspect_ratios, num_features)
|
||||
cell_anchors = [
|
||||
self.generate_cell_anchors(s, a)
|
||||
for s, a in zip(sizes, aspect_ratios)
|
||||
]
|
||||
[
|
||||
self.register_buffer(
|
||||
t.name, t, persistable=False) for t in cell_anchors
|
||||
]
|
||||
return cell_anchors
|
||||
|
||||
def _create_grid_offsets(self, size, stride, offset):
|
||||
grid_height, grid_width = size[0], size[1]
|
||||
shifts_x = paddle.arange(
|
||||
offset * stride, grid_width * stride, step=stride, dtype='float32')
|
||||
shifts_y = paddle.arange(
|
||||
offset * stride, grid_height * stride, step=stride, dtype='float32')
|
||||
shift_y, shift_x = paddle.meshgrid(shifts_y, shifts_x)
|
||||
shift_x = paddle.reshape(shift_x, [-1])
|
||||
shift_y = paddle.reshape(shift_y, [-1])
|
||||
return shift_x, shift_y
|
||||
|
||||
def _grid_anchors(self, grid_sizes):
|
||||
anchors = []
|
||||
for size, stride, base_anchors in zip(grid_sizes, self.strides,
|
||||
self.cell_anchors):
|
||||
shift_x, shift_y = self._create_grid_offsets(size, stride,
|
||||
self.offset)
|
||||
shifts = paddle.stack((shift_x, shift_y, shift_x, shift_y), axis=1)
|
||||
shifts = paddle.reshape(shifts, [-1, 1, 4])
|
||||
base_anchors = paddle.reshape(base_anchors, [1, -1, 4])
|
||||
|
||||
anchors.append(paddle.reshape(shifts + base_anchors, [-1, 4]))
|
||||
|
||||
return anchors
|
||||
|
||||
def forward(self, input):
|
||||
grid_sizes = [paddle.shape(feature_map)[-2:] for feature_map in input]
|
||||
anchors_over_all_feature_maps = self._grid_anchors(grid_sizes)
|
||||
return anchors_over_all_feature_maps
|
||||
|
||||
@property
|
||||
def num_anchors(self):
|
||||
"""
|
||||
Returns:
|
||||
int: number of anchors at every pixel
|
||||
location, on that feature map.
|
||||
For example, if at every pixel we use anchors of 3 aspect
|
||||
ratios and 5 sizes, the number of anchors is 15.
|
||||
For FPN models, `num_anchors` on every feature map is the same.
|
||||
"""
|
||||
return len(self.cell_anchors[0])
|
||||
|
||||
|
||||
@register
|
||||
class RetinaAnchorGenerator(AnchorGenerator):
|
||||
def __init__(self,
|
||||
octave_base_scale=4,
|
||||
scales_per_octave=3,
|
||||
aspect_ratios=[0.5, 1.0, 2.0],
|
||||
strides=[8.0, 16.0, 32.0, 64.0, 128.0],
|
||||
variance=[1.0, 1.0, 1.0, 1.0],
|
||||
offset=0.0):
|
||||
anchor_sizes = []
|
||||
for s in strides:
|
||||
anchor_sizes.append([
|
||||
s * octave_base_scale * 2**(i/scales_per_octave) \
|
||||
for i in range(scales_per_octave)])
|
||||
super(RetinaAnchorGenerator, self).__init__(
|
||||
anchor_sizes=anchor_sizes,
|
||||
aspect_ratios=aspect_ratios,
|
||||
strides=strides,
|
||||
variance=variance,
|
||||
offset=offset)
|
||||
|
||||
|
||||
@register
|
||||
class S2ANetAnchorGenerator(nn.Layer):
|
||||
"""
|
||||
AnchorGenerator by paddle
|
||||
"""
|
||||
|
||||
def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
|
||||
super(S2ANetAnchorGenerator, self).__init__()
|
||||
self.base_size = base_size
|
||||
self.scales = paddle.to_tensor(scales)
|
||||
self.ratios = paddle.to_tensor(ratios)
|
||||
self.scale_major = scale_major
|
||||
self.ctr = ctr
|
||||
self.base_anchors = self.gen_base_anchors()
|
||||
|
||||
@property
|
||||
def num_base_anchors(self):
|
||||
return self.base_anchors.shape[0]
|
||||
|
||||
def gen_base_anchors(self):
|
||||
w = self.base_size
|
||||
h = self.base_size
|
||||
if self.ctr is None:
|
||||
x_ctr = 0.5 * (w - 1)
|
||||
y_ctr = 0.5 * (h - 1)
|
||||
else:
|
||||
x_ctr, y_ctr = self.ctr
|
||||
|
||||
h_ratios = paddle.sqrt(self.ratios)
|
||||
w_ratios = 1 / h_ratios
|
||||
if self.scale_major:
|
||||
ws = (w * w_ratios[:] * self.scales[:]).reshape([-1])
|
||||
hs = (h * h_ratios[:] * self.scales[:]).reshape([-1])
|
||||
else:
|
||||
ws = (w * self.scales[:] * w_ratios[:]).reshape([-1])
|
||||
hs = (h * self.scales[:] * h_ratios[:]).reshape([-1])
|
||||
|
||||
base_anchors = paddle.stack(
|
||||
[
|
||||
x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
|
||||
x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
|
||||
],
|
||||
axis=-1)
|
||||
base_anchors = paddle.round(base_anchors)
|
||||
return base_anchors
|
||||
|
||||
def _meshgrid(self, x, y, row_major=True):
|
||||
yy, xx = paddle.meshgrid(y, x)
|
||||
yy = yy.reshape([-1])
|
||||
xx = xx.reshape([-1])
|
||||
if row_major:
|
||||
return xx, yy
|
||||
else:
|
||||
return yy, xx
|
||||
|
||||
def forward(self, featmap_size, stride=16):
|
||||
# featmap_size*stride project it to original area
|
||||
|
||||
feat_h = featmap_size[0]
|
||||
feat_w = featmap_size[1]
|
||||
shift_x = paddle.arange(0, feat_w, 1, 'int32') * stride
|
||||
shift_y = paddle.arange(0, feat_h, 1, 'int32') * stride
|
||||
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
|
||||
shifts = paddle.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
|
||||
|
||||
all_anchors = self.base_anchors[:, :] + shifts[:, :]
|
||||
all_anchors = all_anchors.cast(paddle.float32).reshape(
|
||||
[feat_h * feat_w, 4])
|
||||
all_anchors = self.rect2rbox(all_anchors)
|
||||
return all_anchors
|
||||
|
||||
def valid_flags(self, featmap_size, valid_size):
|
||||
feat_h, feat_w = featmap_size
|
||||
valid_h, valid_w = valid_size
|
||||
assert valid_h <= feat_h and valid_w <= feat_w
|
||||
valid_x = paddle.zeros([feat_w], dtype='int32')
|
||||
valid_y = paddle.zeros([feat_h], dtype='int32')
|
||||
valid_x[:valid_w] = 1
|
||||
valid_y[:valid_h] = 1
|
||||
valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
|
||||
valid = valid_xx & valid_yy
|
||||
valid = paddle.reshape(valid, [-1, 1])
|
||||
valid = paddle.expand(valid, [-1, self.num_base_anchors]).reshape([-1])
|
||||
return valid
|
||||
|
||||
def rect2rbox(self, bboxes):
|
||||
"""
|
||||
:param bboxes: shape (L, 4) (xmin, ymin, xmax, ymax)
|
||||
:return: dbboxes: shape (L, 5) (x_ctr, y_ctr, w, h, angle)
|
||||
"""
|
||||
x1, y1, x2, y2 = paddle.split(bboxes, 4, axis=-1)
|
||||
|
||||
x_ctr = (x1 + x2) / 2.0
|
||||
y_ctr = (y1 + y2) / 2.0
|
||||
edges1 = paddle.abs(x2 - x1)
|
||||
edges2 = paddle.abs(y2 - y1)
|
||||
|
||||
rbox_w = paddle.maximum(edges1, edges2)
|
||||
rbox_h = paddle.minimum(edges1, edges2)
|
||||
|
||||
# set angle
|
||||
inds = edges1 < edges2
|
||||
inds = paddle.cast(inds, paddle.float32)
|
||||
rboxes_angle = inds * np.pi / 2.0
|
||||
|
||||
rboxes = paddle.concat(
|
||||
(x_ctr, y_ctr, rbox_w, rbox_h, rboxes_angle), axis=-1)
|
||||
return rboxes
|
||||
@@ -0,0 +1,63 @@
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This code is referenced from: https://github.com/open-mmlab/mmdetection
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
from ppdet.core.workspace import register
|
||||
|
||||
__all__ = ['EmbeddingRPNHead']
|
||||
|
||||
|
||||
@register
|
||||
class EmbeddingRPNHead(nn.Layer):
|
||||
__shared__ = ['proposal_embedding_dim']
|
||||
|
||||
def __init__(self, num_proposals, proposal_embedding_dim=256):
|
||||
super(EmbeddingRPNHead, self).__init__()
|
||||
|
||||
self.num_proposals = num_proposals
|
||||
self.proposal_embedding_dim = proposal_embedding_dim
|
||||
|
||||
self._init_layers()
|
||||
self._init_weights()
|
||||
|
||||
def _init_layers(self):
|
||||
self.init_proposal_bboxes = nn.Embedding(self.num_proposals, 4)
|
||||
self.init_proposal_features = nn.Embedding(self.num_proposals,
|
||||
self.proposal_embedding_dim)
|
||||
|
||||
def _init_weights(self):
|
||||
init_bboxes = paddle.empty_like(self.init_proposal_bboxes.weight)
|
||||
init_bboxes[:, :2] = 0.5
|
||||
init_bboxes[:, 2:] = 1.0
|
||||
self.init_proposal_bboxes.weight.set_value(init_bboxes)
|
||||
|
||||
@staticmethod
|
||||
def bbox_cxcywh_to_xyxy(x):
|
||||
cxcy, wh = paddle.split(x, 2, axis=-1)
|
||||
return paddle.concat([cxcy - 0.5 * wh, cxcy + 0.5 * wh], axis=-1)
|
||||
|
||||
def forward(self, img_whwh):
|
||||
proposal_bboxes = self.init_proposal_bboxes.weight.clone()
|
||||
proposal_bboxes = self.bbox_cxcywh_to_xyxy(proposal_bboxes)
|
||||
proposal_bboxes = proposal_bboxes.unsqueeze(0) * img_whwh.unsqueeze(1)
|
||||
|
||||
proposal_features = self.init_proposal_features.weight.clone()
|
||||
proposal_features = proposal_features.unsqueeze(0).tile(
|
||||
[img_whwh.shape[0], 1, 1])
|
||||
|
||||
return proposal_bboxes, proposal_features
|
||||
@@ -0,0 +1,83 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
|
||||
from ppdet.core.workspace import register, serializable
|
||||
from .. import ops
|
||||
|
||||
|
||||
@register
|
||||
@serializable
|
||||
class ProposalGenerator(object):
|
||||
"""
|
||||
Proposal generation module
|
||||
|
||||
For more details, please refer to the document of generate_proposals
|
||||
in ppdet/modeing/ops.py
|
||||
|
||||
Args:
|
||||
pre_nms_top_n (int): Number of total bboxes to be kept per
|
||||
image before NMS. default 6000
|
||||
post_nms_top_n (int): Number of total bboxes to be kept per
|
||||
image after NMS. default 1000
|
||||
nms_thresh (float): Threshold in NMS. default 0.5
|
||||
min_size (flaot): Remove predicted boxes with either height or
|
||||
width < min_size. default 0.1
|
||||
eta (float): Apply in adaptive NMS, if adaptive `threshold > 0.5`,
|
||||
`adaptive_threshold = adaptive_threshold * eta` in each iteration.
|
||||
default 1.
|
||||
topk_after_collect (bool): whether to adopt topk after batch
|
||||
collection. If topk_after_collect is true, box filter will not be
|
||||
used after NMS at each image in proposal generation. default false
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pre_nms_top_n=12000,
|
||||
post_nms_top_n=2000,
|
||||
nms_thresh=.5,
|
||||
min_size=.1,
|
||||
eta=1.,
|
||||
topk_after_collect=False):
|
||||
super(ProposalGenerator, self).__init__()
|
||||
self.pre_nms_top_n = pre_nms_top_n
|
||||
self.post_nms_top_n = post_nms_top_n
|
||||
self.nms_thresh = nms_thresh
|
||||
self.min_size = min_size
|
||||
self.eta = eta
|
||||
self.topk_after_collect = topk_after_collect
|
||||
|
||||
def __call__(self, scores, bbox_deltas, anchors, im_shape):
|
||||
|
||||
top_n = self.pre_nms_top_n if self.topk_after_collect else self.post_nms_top_n
|
||||
variances = paddle.ones_like(anchors)
|
||||
if hasattr(paddle.vision.ops, "generate_proposals"):
|
||||
generate_proposals = getattr(paddle.vision.ops,
|
||||
"generate_proposals")
|
||||
else:
|
||||
generate_proposals = ops.generate_proposals
|
||||
rpn_rois, rpn_rois_prob, rpn_rois_num = generate_proposals(
|
||||
scores,
|
||||
bbox_deltas,
|
||||
im_shape,
|
||||
anchors,
|
||||
variances,
|
||||
pre_nms_top_n=self.pre_nms_top_n,
|
||||
post_nms_top_n=top_n,
|
||||
nms_thresh=self.nms_thresh,
|
||||
min_size=self.min_size,
|
||||
eta=self.eta,
|
||||
return_rois_num=True)
|
||||
|
||||
return rpn_rois, rpn_rois_prob, rpn_rois_num, self.post_nms_top_n
|
||||
313
paddle_detection/ppdet/modeling/proposal_generator/rpn_head.py
Normal file
313
paddle_detection/ppdet/modeling/proposal_generator/rpn_head.py
Normal file
@@ -0,0 +1,313 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn.initializer import Normal
|
||||
|
||||
from ppdet.core.workspace import register
|
||||
from .anchor_generator import AnchorGenerator
|
||||
from .target_layer import RPNTargetAssign
|
||||
from .proposal_generator import ProposalGenerator
|
||||
from ..cls_utils import _get_class_default_kwargs
|
||||
|
||||
|
||||
class RPNFeat(nn.Layer):
|
||||
"""
|
||||
Feature extraction in RPN head
|
||||
|
||||
Args:
|
||||
in_channel (int): Input channel
|
||||
out_channel (int): Output channel
|
||||
"""
|
||||
|
||||
def __init__(self, in_channel=1024, out_channel=1024):
|
||||
super(RPNFeat, self).__init__()
|
||||
# rpn feat is shared with each level
|
||||
self.rpn_conv = nn.Conv2D(
|
||||
in_channels=in_channel,
|
||||
out_channels=out_channel,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=paddle.ParamAttr(initializer=Normal(
|
||||
mean=0., std=0.01)))
|
||||
self.rpn_conv.skip_quant = True
|
||||
|
||||
def forward(self, feats):
|
||||
rpn_feats = []
|
||||
for feat in feats:
|
||||
rpn_feats.append(F.relu(self.rpn_conv(feat)))
|
||||
return rpn_feats
|
||||
|
||||
|
||||
@register
|
||||
class RPNHead(nn.Layer):
|
||||
"""
|
||||
Region Proposal Network
|
||||
|
||||
Args:
|
||||
anchor_generator (dict): configure of anchor generation
|
||||
rpn_target_assign (dict): configure of rpn targets assignment
|
||||
train_proposal (dict): configure of proposals generation
|
||||
at the stage of training
|
||||
test_proposal (dict): configure of proposals generation
|
||||
at the stage of prediction
|
||||
in_channel (int): channel of input feature maps which can be
|
||||
derived by from_config
|
||||
"""
|
||||
__shared__ = ['export_onnx']
|
||||
__inject__ = ['loss_rpn_bbox']
|
||||
|
||||
def __init__(self,
|
||||
anchor_generator=_get_class_default_kwargs(AnchorGenerator),
|
||||
rpn_target_assign=_get_class_default_kwargs(RPNTargetAssign),
|
||||
train_proposal=_get_class_default_kwargs(ProposalGenerator,
|
||||
12000, 2000),
|
||||
test_proposal=_get_class_default_kwargs(ProposalGenerator),
|
||||
in_channel=1024,
|
||||
export_onnx=False,
|
||||
loss_rpn_bbox=None):
|
||||
super(RPNHead, self).__init__()
|
||||
self.anchor_generator = anchor_generator
|
||||
self.rpn_target_assign = rpn_target_assign
|
||||
self.train_proposal = train_proposal
|
||||
self.test_proposal = test_proposal
|
||||
self.export_onnx = export_onnx
|
||||
if isinstance(anchor_generator, dict):
|
||||
self.anchor_generator = AnchorGenerator(**anchor_generator)
|
||||
if isinstance(rpn_target_assign, dict):
|
||||
self.rpn_target_assign = RPNTargetAssign(**rpn_target_assign)
|
||||
if isinstance(train_proposal, dict):
|
||||
self.train_proposal = ProposalGenerator(**train_proposal)
|
||||
if isinstance(test_proposal, dict):
|
||||
self.test_proposal = ProposalGenerator(**test_proposal)
|
||||
self.loss_rpn_bbox = loss_rpn_bbox
|
||||
|
||||
num_anchors = self.anchor_generator.num_anchors
|
||||
self.rpn_feat = RPNFeat(in_channel, in_channel)
|
||||
# rpn head is shared with each level
|
||||
# rpn roi classification scores
|
||||
self.rpn_rois_score = nn.Conv2D(
|
||||
in_channels=in_channel,
|
||||
out_channels=num_anchors,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
weight_attr=paddle.ParamAttr(initializer=Normal(
|
||||
mean=0., std=0.01)))
|
||||
self.rpn_rois_score.skip_quant = True
|
||||
|
||||
# rpn roi bbox regression deltas
|
||||
self.rpn_rois_delta = nn.Conv2D(
|
||||
in_channels=in_channel,
|
||||
out_channels=4 * num_anchors,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
weight_attr=paddle.ParamAttr(initializer=Normal(
|
||||
mean=0., std=0.01)))
|
||||
self.rpn_rois_delta.skip_quant = True
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg, input_shape):
|
||||
# FPN share same rpn head
|
||||
if isinstance(input_shape, (list, tuple)):
|
||||
input_shape = input_shape[0]
|
||||
return {'in_channel': input_shape.channels}
|
||||
|
||||
def forward(self, feats, inputs):
|
||||
rpn_feats = self.rpn_feat(feats)
|
||||
scores = []
|
||||
deltas = []
|
||||
|
||||
for rpn_feat in rpn_feats:
|
||||
rrs = self.rpn_rois_score(rpn_feat)
|
||||
rrd = self.rpn_rois_delta(rpn_feat)
|
||||
scores.append(rrs)
|
||||
deltas.append(rrd)
|
||||
|
||||
anchors = self.anchor_generator(rpn_feats)
|
||||
|
||||
rois, rois_num = self._gen_proposal(scores, deltas, anchors, inputs)
|
||||
if self.training:
|
||||
loss = self.get_loss(scores, deltas, anchors, inputs)
|
||||
return rois, rois_num, loss
|
||||
else:
|
||||
return rois, rois_num, None
|
||||
|
||||
def _gen_proposal(self, scores, bbox_deltas, anchors, inputs):
|
||||
"""
|
||||
scores (list[Tensor]): Multi-level scores prediction
|
||||
bbox_deltas (list[Tensor]): Multi-level deltas prediction
|
||||
anchors (list[Tensor]): Multi-level anchors
|
||||
inputs (dict): ground truth info
|
||||
"""
|
||||
prop_gen = self.train_proposal if self.training else self.test_proposal
|
||||
im_shape = inputs['im_shape']
|
||||
|
||||
# Collect multi-level proposals for each batch
|
||||
# Get 'topk' of them as final output
|
||||
|
||||
if self.export_onnx:
|
||||
# bs = 1 when exporting onnx
|
||||
onnx_rpn_rois_list = []
|
||||
onnx_rpn_prob_list = []
|
||||
onnx_rpn_rois_num_list = []
|
||||
|
||||
for rpn_score, rpn_delta, anchor in zip(scores, bbox_deltas,
|
||||
anchors):
|
||||
onnx_rpn_rois, onnx_rpn_rois_prob, onnx_rpn_rois_num, onnx_post_nms_top_n = prop_gen(
|
||||
scores=rpn_score[0:1],
|
||||
bbox_deltas=rpn_delta[0:1],
|
||||
anchors=anchor,
|
||||
im_shape=im_shape[0:1])
|
||||
onnx_rpn_rois_list.append(onnx_rpn_rois)
|
||||
onnx_rpn_prob_list.append(onnx_rpn_rois_prob)
|
||||
onnx_rpn_rois_num_list.append(onnx_rpn_rois_num)
|
||||
|
||||
onnx_rpn_rois = paddle.concat(onnx_rpn_rois_list)
|
||||
onnx_rpn_prob = paddle.concat(onnx_rpn_prob_list).flatten()
|
||||
|
||||
onnx_top_n = paddle.to_tensor(onnx_post_nms_top_n).cast('int32')
|
||||
onnx_num_rois = paddle.shape(onnx_rpn_prob)[0].cast('int32')
|
||||
k = paddle.minimum(onnx_top_n, onnx_num_rois)
|
||||
onnx_topk_prob, onnx_topk_inds = paddle.topk(onnx_rpn_prob, k)
|
||||
onnx_topk_rois = paddle.gather(onnx_rpn_rois, onnx_topk_inds)
|
||||
# TODO(wangguanzhong): Now bs_rois_collect in export_onnx is moved outside conditional branch
|
||||
# due to problems in dy2static of paddle. Will fix it when updating paddle framework.
|
||||
# bs_rois_collect = [onnx_topk_rois]
|
||||
# bs_rois_num_collect = paddle.shape(onnx_topk_rois)[0]
|
||||
|
||||
else:
|
||||
bs_rois_collect = []
|
||||
bs_rois_num_collect = []
|
||||
|
||||
batch_size = paddle.slice(paddle.shape(im_shape), [0], [0], [1])
|
||||
|
||||
# Generate proposals for each level and each batch.
|
||||
# Discard batch-computing to avoid sorting bbox cross different batches.
|
||||
for i in range(batch_size):
|
||||
rpn_rois_list = []
|
||||
rpn_prob_list = []
|
||||
rpn_rois_num_list = []
|
||||
|
||||
for rpn_score, rpn_delta, anchor in zip(scores, bbox_deltas,
|
||||
anchors):
|
||||
rpn_rois, rpn_rois_prob, rpn_rois_num, post_nms_top_n = prop_gen(
|
||||
scores=rpn_score[i:i + 1],
|
||||
bbox_deltas=rpn_delta[i:i + 1],
|
||||
anchors=anchor,
|
||||
im_shape=im_shape[i:i + 1])
|
||||
rpn_rois_list.append(rpn_rois)
|
||||
rpn_prob_list.append(rpn_rois_prob)
|
||||
rpn_rois_num_list.append(rpn_rois_num)
|
||||
|
||||
if len(scores) > 1:
|
||||
rpn_rois = paddle.concat(rpn_rois_list)
|
||||
rpn_prob = paddle.concat(rpn_prob_list).flatten()
|
||||
|
||||
num_rois = paddle.shape(rpn_prob)[0].cast('int32')
|
||||
if num_rois > post_nms_top_n:
|
||||
topk_prob, topk_inds = paddle.topk(rpn_prob,
|
||||
post_nms_top_n)
|
||||
topk_rois = paddle.gather(rpn_rois, topk_inds)
|
||||
else:
|
||||
topk_rois = rpn_rois
|
||||
topk_prob = rpn_prob
|
||||
else:
|
||||
topk_rois = rpn_rois_list[0]
|
||||
topk_prob = rpn_prob_list[0].flatten()
|
||||
|
||||
bs_rois_collect.append(topk_rois)
|
||||
bs_rois_num_collect.append(paddle.shape(topk_rois)[0:1])
|
||||
|
||||
bs_rois_num_collect = paddle.concat(bs_rois_num_collect)
|
||||
|
||||
if self.export_onnx:
|
||||
output_rois = [onnx_topk_rois]
|
||||
output_rois_num = paddle.shape(onnx_topk_rois)[0]
|
||||
else:
|
||||
output_rois = bs_rois_collect
|
||||
output_rois_num = bs_rois_num_collect
|
||||
|
||||
return output_rois, output_rois_num
|
||||
|
||||
def get_loss(self, pred_scores, pred_deltas, anchors, inputs):
|
||||
"""
|
||||
pred_scores (list[Tensor]): Multi-level scores prediction
|
||||
pred_deltas (list[Tensor]): Multi-level deltas prediction
|
||||
anchors (list[Tensor]): Multi-level anchors
|
||||
inputs (dict): ground truth info, including im, gt_bbox, gt_score
|
||||
"""
|
||||
anchors = [paddle.reshape(a, shape=(-1, 4)) for a in anchors]
|
||||
anchors = paddle.concat(anchors)
|
||||
|
||||
scores = [
|
||||
paddle.reshape(
|
||||
paddle.transpose(
|
||||
v, perm=[0, 2, 3, 1]),
|
||||
shape=(v.shape[0], -1, 1)) for v in pred_scores
|
||||
]
|
||||
scores = paddle.concat(scores, axis=1)
|
||||
|
||||
deltas = [
|
||||
paddle.reshape(
|
||||
paddle.transpose(
|
||||
v, perm=[0, 2, 3, 1]),
|
||||
shape=(v.shape[0], -1, 4)) for v in pred_deltas
|
||||
]
|
||||
deltas = paddle.concat(deltas, axis=1)
|
||||
|
||||
score_tgt, bbox_tgt, loc_tgt, norm = self.rpn_target_assign(inputs,
|
||||
anchors)
|
||||
|
||||
scores = paddle.reshape(x=scores, shape=(-1, ))
|
||||
deltas = paddle.reshape(x=deltas, shape=(-1, 4))
|
||||
|
||||
score_tgt = paddle.concat(score_tgt)
|
||||
score_tgt.stop_gradient = True
|
||||
|
||||
pos_mask = score_tgt == 1
|
||||
pos_ind = paddle.nonzero(pos_mask)
|
||||
|
||||
valid_mask = score_tgt >= 0
|
||||
valid_ind = paddle.nonzero(valid_mask)
|
||||
|
||||
# cls loss
|
||||
if valid_ind.shape[0] == 0:
|
||||
loss_rpn_cls = paddle.zeros([1], dtype='float32')
|
||||
else:
|
||||
score_pred = paddle.gather(scores, valid_ind)
|
||||
score_label = paddle.gather(score_tgt, valid_ind).cast('float32')
|
||||
score_label.stop_gradient = True
|
||||
loss_rpn_cls = F.binary_cross_entropy_with_logits(
|
||||
logit=score_pred, label=score_label, reduction="sum")
|
||||
|
||||
# reg loss
|
||||
if pos_ind.shape[0] == 0:
|
||||
loss_rpn_reg = paddle.zeros([1], dtype='float32')
|
||||
else:
|
||||
loc_pred = paddle.gather(deltas, pos_ind)
|
||||
loc_tgt = paddle.concat(loc_tgt)
|
||||
loc_tgt = paddle.gather(loc_tgt, pos_ind)
|
||||
loc_tgt.stop_gradient = True
|
||||
|
||||
if self.loss_rpn_bbox is None:
|
||||
loss_rpn_reg = paddle.abs(loc_pred - loc_tgt).sum()
|
||||
else:
|
||||
loss_rpn_reg = self.loss_rpn_bbox(loc_pred, loc_tgt).sum()
|
||||
|
||||
return {
|
||||
'loss_rpn_cls': loss_rpn_cls / norm,
|
||||
'loss_rpn_reg': loss_rpn_reg / norm
|
||||
}
|
||||
678
paddle_detection/ppdet/modeling/proposal_generator/target.py
Normal file
678
paddle_detection/ppdet/modeling/proposal_generator/target.py
Normal file
@@ -0,0 +1,678 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from ..bbox_utils import bbox2delta, bbox_overlaps
|
||||
|
||||
|
||||
def rpn_anchor_target(anchors,
|
||||
gt_boxes,
|
||||
rpn_batch_size_per_im,
|
||||
rpn_positive_overlap,
|
||||
rpn_negative_overlap,
|
||||
rpn_fg_fraction,
|
||||
use_random=True,
|
||||
batch_size=1,
|
||||
ignore_thresh=-1,
|
||||
is_crowd=None,
|
||||
weights=[1., 1., 1., 1.],
|
||||
assign_on_cpu=False):
|
||||
tgt_labels = []
|
||||
tgt_bboxes = []
|
||||
tgt_deltas = []
|
||||
for i in range(batch_size):
|
||||
gt_bbox = gt_boxes[i]
|
||||
is_crowd_i = is_crowd[i] if is_crowd else None
|
||||
# Step1: match anchor and gt_bbox
|
||||
matches, match_labels = label_box(
|
||||
anchors, gt_bbox, rpn_positive_overlap, rpn_negative_overlap, True,
|
||||
ignore_thresh, is_crowd_i, assign_on_cpu)
|
||||
# Step2: sample anchor
|
||||
fg_inds, bg_inds = subsample_labels(match_labels, rpn_batch_size_per_im,
|
||||
rpn_fg_fraction, 0, use_random)
|
||||
# Fill with the ignore label (-1), then set positive and negative labels
|
||||
labels = paddle.full(match_labels.shape, -1, dtype='int32')
|
||||
if bg_inds.shape[0] > 0:
|
||||
labels = paddle.scatter(labels, bg_inds, paddle.zeros_like(bg_inds))
|
||||
if fg_inds.shape[0] > 0:
|
||||
labels = paddle.scatter(labels, fg_inds, paddle.ones_like(fg_inds))
|
||||
# Step3: make output
|
||||
if gt_bbox.shape[0] == 0:
|
||||
matched_gt_boxes = paddle.zeros([matches.shape[0], 4])
|
||||
tgt_delta = paddle.zeros([matches.shape[0], 4])
|
||||
else:
|
||||
matched_gt_boxes = paddle.gather(gt_bbox, matches)
|
||||
tgt_delta = bbox2delta(anchors, matched_gt_boxes, weights)
|
||||
matched_gt_boxes.stop_gradient = True
|
||||
tgt_delta.stop_gradient = True
|
||||
labels.stop_gradient = True
|
||||
tgt_labels.append(labels)
|
||||
tgt_bboxes.append(matched_gt_boxes)
|
||||
tgt_deltas.append(tgt_delta)
|
||||
|
||||
return tgt_labels, tgt_bboxes, tgt_deltas
|
||||
|
||||
|
||||
def label_box(anchors,
|
||||
gt_boxes,
|
||||
positive_overlap,
|
||||
negative_overlap,
|
||||
allow_low_quality,
|
||||
ignore_thresh,
|
||||
is_crowd=None,
|
||||
assign_on_cpu=False):
|
||||
if assign_on_cpu:
|
||||
device = paddle.device.get_device()
|
||||
paddle.set_device("cpu")
|
||||
iou = bbox_overlaps(gt_boxes, anchors)
|
||||
paddle.set_device(device)
|
||||
|
||||
else:
|
||||
iou = bbox_overlaps(gt_boxes, anchors)
|
||||
n_gt = gt_boxes.shape[0]
|
||||
if n_gt == 0 or is_crowd is None:
|
||||
n_gt_crowd = 0
|
||||
else:
|
||||
n_gt_crowd = paddle.nonzero(is_crowd).shape[0]
|
||||
if iou.shape[0] == 0 or n_gt_crowd == n_gt:
|
||||
# No truth, assign everything to background
|
||||
default_matches = paddle.full((iou.shape[1], ), 0, dtype='int64')
|
||||
default_match_labels = paddle.full((iou.shape[1], ), 0, dtype='int32')
|
||||
return default_matches, default_match_labels
|
||||
# if ignore_thresh > 0, remove anchor if it is closed to
|
||||
# one of the crowded ground-truth
|
||||
if n_gt_crowd > 0:
|
||||
N_a = anchors.shape[0]
|
||||
ones = paddle.ones([N_a])
|
||||
mask = is_crowd * ones
|
||||
|
||||
if ignore_thresh > 0:
|
||||
crowd_iou = iou * mask
|
||||
valid = (paddle.sum((crowd_iou > ignore_thresh).cast('int32'),
|
||||
axis=0) > 0).cast('float32')
|
||||
iou = iou * (1 - valid) - valid
|
||||
|
||||
# ignore the iou between anchor and crowded ground-truth
|
||||
iou = iou * (1 - mask) - mask
|
||||
|
||||
matched_vals, matches = paddle.topk(iou, k=1, axis=0)
|
||||
match_labels = paddle.full(matches.shape, -1, dtype='int32')
|
||||
# set ignored anchor with iou = -1
|
||||
neg_cond = paddle.logical_and(matched_vals > -1,
|
||||
matched_vals < negative_overlap)
|
||||
match_labels = paddle.where(neg_cond,
|
||||
paddle.zeros_like(match_labels), match_labels)
|
||||
match_labels = paddle.where(matched_vals >= positive_overlap,
|
||||
paddle.ones_like(match_labels), match_labels)
|
||||
if allow_low_quality:
|
||||
highest_quality_foreach_gt = iou.max(axis=1, keepdim=True)
|
||||
pred_inds_with_highest_quality = paddle.logical_and(
|
||||
iou > 0, iou == highest_quality_foreach_gt).cast('int32').sum(
|
||||
0, keepdim=True)
|
||||
match_labels = paddle.where(pred_inds_with_highest_quality > 0,
|
||||
paddle.ones_like(match_labels),
|
||||
match_labels)
|
||||
|
||||
matches = matches.flatten()
|
||||
match_labels = match_labels.flatten()
|
||||
|
||||
return matches, match_labels
|
||||
|
||||
|
||||
def subsample_labels(labels,
|
||||
num_samples,
|
||||
fg_fraction,
|
||||
bg_label=0,
|
||||
use_random=True):
|
||||
positive = paddle.nonzero(
|
||||
paddle.logical_and(labels != -1, labels != bg_label))
|
||||
negative = paddle.nonzero(labels == bg_label)
|
||||
|
||||
fg_num = int(num_samples * fg_fraction)
|
||||
fg_num = min(positive.numel(), fg_num)
|
||||
bg_num = num_samples - fg_num
|
||||
bg_num = min(negative.numel(), bg_num)
|
||||
if fg_num == 0 and bg_num == 0:
|
||||
fg_inds = paddle.zeros([0], dtype='int32')
|
||||
bg_inds = paddle.zeros([0], dtype='int32')
|
||||
return fg_inds, bg_inds
|
||||
|
||||
# randomly select positive and negative examples
|
||||
|
||||
negative = negative.cast('int32').flatten()
|
||||
bg_perm = paddle.randperm(negative.numel(), dtype='int32')
|
||||
bg_perm = paddle.slice(bg_perm, axes=[0], starts=[0], ends=[bg_num])
|
||||
if use_random:
|
||||
bg_inds = paddle.gather(negative, bg_perm)
|
||||
else:
|
||||
bg_inds = paddle.slice(negative, axes=[0], starts=[0], ends=[bg_num])
|
||||
if fg_num == 0:
|
||||
fg_inds = paddle.zeros([0], dtype='int32')
|
||||
return fg_inds, bg_inds
|
||||
|
||||
positive = positive.cast('int32').flatten()
|
||||
fg_perm = paddle.randperm(positive.numel(), dtype='int32')
|
||||
fg_perm = paddle.slice(fg_perm, axes=[0], starts=[0], ends=[fg_num])
|
||||
if use_random:
|
||||
fg_inds = paddle.gather(positive, fg_perm)
|
||||
else:
|
||||
fg_inds = paddle.slice(positive, axes=[0], starts=[0], ends=[fg_num])
|
||||
|
||||
return fg_inds, bg_inds
|
||||
|
||||
|
||||
def generate_proposal_target(rpn_rois,
|
||||
gt_classes,
|
||||
gt_boxes,
|
||||
batch_size_per_im,
|
||||
fg_fraction,
|
||||
fg_thresh,
|
||||
bg_thresh,
|
||||
num_classes,
|
||||
ignore_thresh=-1.,
|
||||
is_crowd=None,
|
||||
use_random=True,
|
||||
is_cascade=False,
|
||||
cascade_iou=0.5,
|
||||
assign_on_cpu=False,
|
||||
add_gt_as_proposals=True):
|
||||
|
||||
rois_with_gt = []
|
||||
tgt_labels = []
|
||||
tgt_bboxes = []
|
||||
tgt_gt_inds = []
|
||||
new_rois_num = []
|
||||
|
||||
# In cascade rcnn, the threshold for foreground and background
|
||||
# is used from cascade_iou
|
||||
fg_thresh = cascade_iou if is_cascade else fg_thresh
|
||||
bg_thresh = cascade_iou if is_cascade else bg_thresh
|
||||
for i, rpn_roi in enumerate(rpn_rois):
|
||||
gt_bbox = gt_boxes[i]
|
||||
is_crowd_i = is_crowd[i] if is_crowd else None
|
||||
gt_class = paddle.squeeze(gt_classes[i], axis=-1)
|
||||
|
||||
# Concat RoIs and gt boxes except cascade rcnn or none gt
|
||||
if add_gt_as_proposals and gt_bbox.shape[0] > 0:
|
||||
bbox = paddle.concat([rpn_roi, gt_bbox])
|
||||
else:
|
||||
bbox = rpn_roi
|
||||
|
||||
# Step1: label bbox
|
||||
matches, match_labels = label_box(bbox, gt_bbox, fg_thresh, bg_thresh,
|
||||
False, ignore_thresh, is_crowd_i,
|
||||
assign_on_cpu)
|
||||
# Step2: sample bbox
|
||||
sampled_inds, sampled_gt_classes = sample_bbox(
|
||||
matches, match_labels, gt_class, batch_size_per_im, fg_fraction,
|
||||
num_classes, use_random, is_cascade)
|
||||
|
||||
# Step3: make output
|
||||
rois_per_image = bbox if is_cascade else paddle.gather(bbox,
|
||||
sampled_inds)
|
||||
sampled_gt_ind = matches if is_cascade else paddle.gather(matches,
|
||||
sampled_inds)
|
||||
if gt_bbox.shape[0] > 0:
|
||||
sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
|
||||
else:
|
||||
num = rois_per_image.shape[0]
|
||||
sampled_bbox = paddle.zeros([num, 4], dtype='float32')
|
||||
|
||||
rois_per_image.stop_gradient = True
|
||||
sampled_gt_ind.stop_gradient = True
|
||||
sampled_bbox.stop_gradient = True
|
||||
tgt_labels.append(sampled_gt_classes)
|
||||
tgt_bboxes.append(sampled_bbox)
|
||||
rois_with_gt.append(rois_per_image)
|
||||
tgt_gt_inds.append(sampled_gt_ind)
|
||||
new_rois_num.append(paddle.shape(sampled_inds)[0:1])
|
||||
new_rois_num = paddle.concat(new_rois_num)
|
||||
return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num
|
||||
|
||||
|
||||
def sample_bbox(matches,
|
||||
match_labels,
|
||||
gt_classes,
|
||||
batch_size_per_im,
|
||||
fg_fraction,
|
||||
num_classes,
|
||||
use_random=True,
|
||||
is_cascade=False):
|
||||
|
||||
n_gt = gt_classes.shape[0]
|
||||
if n_gt == 0:
|
||||
# No truth, assign everything to background
|
||||
gt_classes = paddle.ones(matches.shape, dtype='int32') * num_classes
|
||||
#return matches, match_labels + num_classes
|
||||
else:
|
||||
gt_classes = paddle.gather(gt_classes, matches)
|
||||
gt_classes = paddle.where(match_labels == 0,
|
||||
paddle.ones_like(gt_classes) * num_classes,
|
||||
gt_classes)
|
||||
gt_classes = paddle.where(match_labels == -1,
|
||||
paddle.ones_like(gt_classes) * -1, gt_classes)
|
||||
if is_cascade:
|
||||
index = paddle.arange(matches.shape[0])
|
||||
return index, gt_classes
|
||||
rois_per_image = int(batch_size_per_im)
|
||||
|
||||
fg_inds, bg_inds = subsample_labels(gt_classes, rois_per_image, fg_fraction,
|
||||
num_classes, use_random)
|
||||
if fg_inds.shape[0] == 0 and bg_inds.shape[0] == 0:
|
||||
# fake output labeled with -1 when all boxes are neither
|
||||
# foreground nor background
|
||||
sampled_inds = paddle.zeros([1], dtype='int32')
|
||||
else:
|
||||
sampled_inds = paddle.concat([fg_inds, bg_inds])
|
||||
sampled_gt_classes = paddle.gather(gt_classes, sampled_inds)
|
||||
return sampled_inds, sampled_gt_classes
|
||||
|
||||
|
||||
def polygons_to_mask(polygons, height, width):
|
||||
"""
|
||||
Convert the polygons to mask format
|
||||
|
||||
Args:
|
||||
polygons (list[ndarray]): each array has shape (Nx2,)
|
||||
height (int): mask height
|
||||
width (int): mask width
|
||||
Returns:
|
||||
ndarray: a bool mask of shape (height, width)
|
||||
"""
|
||||
import pycocotools.mask as mask_util
|
||||
assert len(polygons) > 0, "COCOAPI does not support empty polygons"
|
||||
rles = mask_util.frPyObjects(polygons, height, width)
|
||||
rle = mask_util.merge(rles)
|
||||
return mask_util.decode(rle).astype(np.bool_)
|
||||
|
||||
|
||||
def rasterize_polygons_within_box(poly, box, resolution):
|
||||
w, h = box[2] - box[0], box[3] - box[1]
|
||||
polygons = [np.asarray(p, dtype=np.float64) for p in poly]
|
||||
for p in polygons:
|
||||
p[0::2] = p[0::2] - box[0]
|
||||
p[1::2] = p[1::2] - box[1]
|
||||
|
||||
ratio_h = resolution / max(h, 0.1)
|
||||
ratio_w = resolution / max(w, 0.1)
|
||||
|
||||
if ratio_h == ratio_w:
|
||||
for p in polygons:
|
||||
p *= ratio_h
|
||||
else:
|
||||
for p in polygons:
|
||||
p[0::2] *= ratio_w
|
||||
p[1::2] *= ratio_h
|
||||
|
||||
# 3. Rasterize the polygons with coco api
|
||||
mask = polygons_to_mask(polygons, resolution, resolution)
|
||||
mask = paddle.to_tensor(mask, dtype='int32')
|
||||
return mask
|
||||
|
||||
|
||||
def generate_mask_target(gt_segms, rois, labels_int32, sampled_gt_inds,
|
||||
num_classes, resolution):
|
||||
mask_rois = []
|
||||
mask_rois_num = []
|
||||
tgt_masks = []
|
||||
tgt_classes = []
|
||||
mask_index = []
|
||||
tgt_weights = []
|
||||
for k in range(len(rois)):
|
||||
labels_per_im = labels_int32[k]
|
||||
# select rois labeled with foreground
|
||||
fg_inds = paddle.nonzero(
|
||||
paddle.logical_and(labels_per_im != -1, labels_per_im !=
|
||||
num_classes))
|
||||
has_fg = True
|
||||
# generate fake roi if foreground is empty
|
||||
if fg_inds.numel() == 0:
|
||||
has_fg = False
|
||||
fg_inds = paddle.ones([1, 1], dtype='int64')
|
||||
inds_per_im = sampled_gt_inds[k]
|
||||
inds_per_im = paddle.gather(inds_per_im, fg_inds)
|
||||
|
||||
rois_per_im = rois[k]
|
||||
fg_rois = paddle.gather(rois_per_im, fg_inds)
|
||||
# Copy the foreground roi to cpu
|
||||
# to generate mask target with ground-truth
|
||||
boxes = fg_rois.numpy()
|
||||
gt_segms_per_im = gt_segms[k]
|
||||
|
||||
new_segm = []
|
||||
inds_per_im = inds_per_im.numpy()
|
||||
if len(gt_segms_per_im) > 0:
|
||||
for i in inds_per_im:
|
||||
new_segm.append(gt_segms_per_im[i])
|
||||
fg_inds_new = fg_inds.reshape([-1]).numpy()
|
||||
results = []
|
||||
if len(gt_segms_per_im) > 0:
|
||||
for j in range(fg_inds_new.shape[0]):
|
||||
results.append(
|
||||
rasterize_polygons_within_box(new_segm[j], boxes[j],
|
||||
resolution))
|
||||
else:
|
||||
results.append(paddle.ones([resolution, resolution], dtype='int32'))
|
||||
|
||||
fg_classes = paddle.gather(labels_per_im, fg_inds)
|
||||
weight = paddle.ones([fg_rois.shape[0]], dtype='float32')
|
||||
if not has_fg:
|
||||
# now all sampled classes are background
|
||||
# which will cause error in loss calculation,
|
||||
# make fake classes with weight of 0.
|
||||
fg_classes = paddle.zeros([1], dtype='int32')
|
||||
weight = weight - 1
|
||||
tgt_mask = paddle.stack(results)
|
||||
tgt_mask.stop_gradient = True
|
||||
fg_rois.stop_gradient = True
|
||||
|
||||
mask_index.append(fg_inds)
|
||||
mask_rois.append(fg_rois)
|
||||
mask_rois_num.append(paddle.shape(fg_rois)[0:1])
|
||||
tgt_classes.append(fg_classes)
|
||||
tgt_masks.append(tgt_mask)
|
||||
tgt_weights.append(weight)
|
||||
|
||||
mask_index = paddle.concat(mask_index)
|
||||
mask_rois_num = paddle.concat(mask_rois_num)
|
||||
tgt_classes = paddle.concat(tgt_classes, axis=0)
|
||||
tgt_masks = paddle.concat(tgt_masks, axis=0)
|
||||
tgt_weights = paddle.concat(tgt_weights, axis=0)
|
||||
|
||||
return mask_rois, mask_rois_num, tgt_classes, tgt_masks, mask_index, tgt_weights
|
||||
|
||||
|
||||
def libra_sample_pos(max_overlaps, max_classes, pos_inds, num_expected):
|
||||
if len(pos_inds) <= num_expected:
|
||||
return pos_inds
|
||||
else:
|
||||
unique_gt_inds = np.unique(max_classes[pos_inds])
|
||||
num_gts = len(unique_gt_inds)
|
||||
num_per_gt = int(round(num_expected / float(num_gts)) + 1)
|
||||
|
||||
sampled_inds = []
|
||||
for i in unique_gt_inds:
|
||||
inds = np.nonzero(max_classes == i)[0]
|
||||
before_len = len(inds)
|
||||
inds = list(set(inds) & set(pos_inds))
|
||||
after_len = len(inds)
|
||||
if len(inds) > num_per_gt:
|
||||
inds = np.random.choice(inds, size=num_per_gt, replace=False)
|
||||
sampled_inds.extend(list(inds)) # combine as a new sampler
|
||||
if len(sampled_inds) < num_expected:
|
||||
num_extra = num_expected - len(sampled_inds)
|
||||
extra_inds = np.array(list(set(pos_inds) - set(sampled_inds)))
|
||||
assert len(sampled_inds) + len(extra_inds) == len(pos_inds), \
|
||||
"sum of sampled_inds({}) and extra_inds({}) length must be equal with pos_inds({})!".format(
|
||||
len(sampled_inds), len(extra_inds), len(pos_inds))
|
||||
if len(extra_inds) > num_extra:
|
||||
extra_inds = np.random.choice(
|
||||
extra_inds, size=num_extra, replace=False)
|
||||
sampled_inds.extend(extra_inds.tolist())
|
||||
elif len(sampled_inds) > num_expected:
|
||||
sampled_inds = np.random.choice(
|
||||
sampled_inds, size=num_expected, replace=False)
|
||||
return paddle.to_tensor(sampled_inds)
|
||||
|
||||
|
||||
def libra_sample_via_interval(max_overlaps, full_set, num_expected, floor_thr,
|
||||
num_bins, bg_thresh):
|
||||
max_iou = max_overlaps.max()
|
||||
iou_interval = (max_iou - floor_thr) / num_bins
|
||||
per_num_expected = int(num_expected / num_bins)
|
||||
|
||||
sampled_inds = []
|
||||
for i in range(num_bins):
|
||||
start_iou = floor_thr + i * iou_interval
|
||||
end_iou = floor_thr + (i + 1) * iou_interval
|
||||
|
||||
tmp_set = set(
|
||||
np.where(
|
||||
np.logical_and(max_overlaps >= start_iou, max_overlaps <
|
||||
end_iou))[0])
|
||||
tmp_inds = list(tmp_set & full_set)
|
||||
|
||||
if len(tmp_inds) > per_num_expected:
|
||||
tmp_sampled_set = np.random.choice(
|
||||
tmp_inds, size=per_num_expected, replace=False)
|
||||
else:
|
||||
tmp_sampled_set = np.array(tmp_inds, dtype=np.int32)
|
||||
sampled_inds.append(tmp_sampled_set)
|
||||
|
||||
sampled_inds = np.concatenate(sampled_inds)
|
||||
if len(sampled_inds) < num_expected:
|
||||
num_extra = num_expected - len(sampled_inds)
|
||||
extra_inds = np.array(list(full_set - set(sampled_inds)))
|
||||
assert len(sampled_inds) + len(extra_inds) == len(full_set), \
|
||||
"sum of sampled_inds({}) and extra_inds({}) length must be equal with full_set({})!".format(
|
||||
len(sampled_inds), len(extra_inds), len(full_set))
|
||||
|
||||
if len(extra_inds) > num_extra:
|
||||
extra_inds = np.random.choice(extra_inds, num_extra, replace=False)
|
||||
sampled_inds = np.concatenate([sampled_inds, extra_inds])
|
||||
|
||||
return sampled_inds
|
||||
|
||||
|
||||
def libra_sample_neg(max_overlaps,
|
||||
max_classes,
|
||||
neg_inds,
|
||||
num_expected,
|
||||
floor_thr=-1,
|
||||
floor_fraction=0,
|
||||
num_bins=3,
|
||||
bg_thresh=0.5):
|
||||
if len(neg_inds) <= num_expected:
|
||||
return neg_inds
|
||||
else:
|
||||
# balance sampling for negative samples
|
||||
neg_set = set(neg_inds.tolist())
|
||||
if floor_thr > 0:
|
||||
floor_set = set(
|
||||
np.where(
|
||||
np.logical_and(max_overlaps >= 0, max_overlaps < floor_thr))
|
||||
[0])
|
||||
iou_sampling_set = set(np.where(max_overlaps >= floor_thr)[0])
|
||||
elif floor_thr == 0:
|
||||
floor_set = set(np.where(max_overlaps == 0)[0])
|
||||
iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
|
||||
else:
|
||||
floor_set = set()
|
||||
iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
|
||||
floor_thr = 0
|
||||
|
||||
floor_neg_inds = list(floor_set & neg_set)
|
||||
iou_sampling_neg_inds = list(iou_sampling_set & neg_set)
|
||||
|
||||
num_expected_iou_sampling = int(num_expected * (1 - floor_fraction))
|
||||
if len(iou_sampling_neg_inds) > num_expected_iou_sampling:
|
||||
if num_bins >= 2:
|
||||
iou_sampled_inds = libra_sample_via_interval(
|
||||
max_overlaps,
|
||||
set(iou_sampling_neg_inds), num_expected_iou_sampling,
|
||||
floor_thr, num_bins, bg_thresh)
|
||||
else:
|
||||
iou_sampled_inds = np.random.choice(
|
||||
iou_sampling_neg_inds,
|
||||
size=num_expected_iou_sampling,
|
||||
replace=False)
|
||||
else:
|
||||
iou_sampled_inds = np.array(iou_sampling_neg_inds, dtype=np.int32)
|
||||
num_expected_floor = num_expected - len(iou_sampled_inds)
|
||||
if len(floor_neg_inds) > num_expected_floor:
|
||||
sampled_floor_inds = np.random.choice(
|
||||
floor_neg_inds, size=num_expected_floor, replace=False)
|
||||
else:
|
||||
sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int32)
|
||||
sampled_inds = np.concatenate((sampled_floor_inds, iou_sampled_inds))
|
||||
if len(sampled_inds) < num_expected:
|
||||
num_extra = num_expected - len(sampled_inds)
|
||||
extra_inds = np.array(list(neg_set - set(sampled_inds)))
|
||||
if len(extra_inds) > num_extra:
|
||||
extra_inds = np.random.choice(
|
||||
extra_inds, size=num_extra, replace=False)
|
||||
sampled_inds = np.concatenate((sampled_inds, extra_inds))
|
||||
return paddle.to_tensor(sampled_inds)
|
||||
|
||||
|
||||
def libra_label_box(anchors, gt_boxes, gt_classes, positive_overlap,
|
||||
negative_overlap, num_classes):
|
||||
# TODO: use paddle API to speed up
|
||||
gt_classes = gt_classes.numpy()
|
||||
gt_overlaps = np.zeros((anchors.shape[0], num_classes))
|
||||
matches = np.zeros((anchors.shape[0]), dtype=np.int32)
|
||||
if len(gt_boxes) > 0:
|
||||
proposal_to_gt_overlaps = bbox_overlaps(anchors, gt_boxes).numpy()
|
||||
overlaps_argmax = proposal_to_gt_overlaps.argmax(axis=1)
|
||||
overlaps_max = proposal_to_gt_overlaps.max(axis=1)
|
||||
# Boxes which with non-zero overlap with gt boxes
|
||||
overlapped_boxes_ind = np.where(overlaps_max > 0)[0]
|
||||
overlapped_boxes_gt_classes = gt_classes[overlaps_argmax[
|
||||
overlapped_boxes_ind]]
|
||||
|
||||
for idx in range(len(overlapped_boxes_ind)):
|
||||
gt_overlaps[overlapped_boxes_ind[idx], overlapped_boxes_gt_classes[
|
||||
idx]] = overlaps_max[overlapped_boxes_ind[idx]]
|
||||
matches[overlapped_boxes_ind[idx]] = overlaps_argmax[
|
||||
overlapped_boxes_ind[idx]]
|
||||
|
||||
gt_overlaps = paddle.to_tensor(gt_overlaps)
|
||||
matches = paddle.to_tensor(matches)
|
||||
|
||||
matched_vals = paddle.max(gt_overlaps, axis=1)
|
||||
match_labels = paddle.full(matches.shape, -1, dtype='int32')
|
||||
match_labels = paddle.where(matched_vals < negative_overlap,
|
||||
paddle.zeros_like(match_labels), match_labels)
|
||||
match_labels = paddle.where(matched_vals >= positive_overlap,
|
||||
paddle.ones_like(match_labels), match_labels)
|
||||
|
||||
return matches, match_labels, matched_vals
|
||||
|
||||
|
||||
def libra_sample_bbox(matches,
|
||||
match_labels,
|
||||
matched_vals,
|
||||
gt_classes,
|
||||
batch_size_per_im,
|
||||
num_classes,
|
||||
fg_fraction,
|
||||
fg_thresh,
|
||||
bg_thresh,
|
||||
num_bins,
|
||||
use_random=True,
|
||||
is_cascade_rcnn=False):
|
||||
rois_per_image = int(batch_size_per_im)
|
||||
fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))
|
||||
bg_rois_per_im = rois_per_image - fg_rois_per_im
|
||||
|
||||
if is_cascade_rcnn:
|
||||
fg_inds = paddle.nonzero(matched_vals >= fg_thresh)
|
||||
bg_inds = paddle.nonzero(matched_vals < bg_thresh)
|
||||
else:
|
||||
matched_vals_np = matched_vals.numpy()
|
||||
match_labels_np = match_labels.numpy()
|
||||
|
||||
# sample fg
|
||||
fg_inds = paddle.nonzero(matched_vals >= fg_thresh).flatten()
|
||||
fg_nums = int(np.minimum(fg_rois_per_im, fg_inds.shape[0]))
|
||||
if (fg_inds.shape[0] > fg_nums) and use_random:
|
||||
fg_inds = libra_sample_pos(matched_vals_np, match_labels_np,
|
||||
fg_inds.numpy(), fg_rois_per_im)
|
||||
fg_inds = fg_inds[:fg_nums]
|
||||
|
||||
# sample bg
|
||||
bg_inds = paddle.nonzero(matched_vals < bg_thresh).flatten()
|
||||
bg_nums = int(np.minimum(rois_per_image - fg_nums, bg_inds.shape[0]))
|
||||
if (bg_inds.shape[0] > bg_nums) and use_random:
|
||||
bg_inds = libra_sample_neg(
|
||||
matched_vals_np,
|
||||
match_labels_np,
|
||||
bg_inds.numpy(),
|
||||
bg_rois_per_im,
|
||||
num_bins=num_bins,
|
||||
bg_thresh=bg_thresh)
|
||||
bg_inds = bg_inds[:bg_nums]
|
||||
|
||||
sampled_inds = paddle.concat([fg_inds, bg_inds])
|
||||
|
||||
gt_classes = paddle.gather(gt_classes, matches)
|
||||
gt_classes = paddle.where(match_labels == 0,
|
||||
paddle.ones_like(gt_classes) * num_classes,
|
||||
gt_classes)
|
||||
gt_classes = paddle.where(match_labels == -1,
|
||||
paddle.ones_like(gt_classes) * -1, gt_classes)
|
||||
sampled_gt_classes = paddle.gather(gt_classes, sampled_inds)
|
||||
|
||||
return sampled_inds, sampled_gt_classes
|
||||
|
||||
|
||||
def libra_generate_proposal_target(rpn_rois,
|
||||
gt_classes,
|
||||
gt_boxes,
|
||||
batch_size_per_im,
|
||||
fg_fraction,
|
||||
fg_thresh,
|
||||
bg_thresh,
|
||||
num_classes,
|
||||
use_random=True,
|
||||
is_cascade_rcnn=False,
|
||||
max_overlaps=None,
|
||||
num_bins=3):
|
||||
|
||||
rois_with_gt = []
|
||||
tgt_labels = []
|
||||
tgt_bboxes = []
|
||||
sampled_max_overlaps = []
|
||||
tgt_gt_inds = []
|
||||
new_rois_num = []
|
||||
|
||||
for i, rpn_roi in enumerate(rpn_rois):
|
||||
max_overlap = max_overlaps[i] if is_cascade_rcnn else None
|
||||
gt_bbox = gt_boxes[i]
|
||||
gt_class = paddle.squeeze(gt_classes[i], axis=-1)
|
||||
if is_cascade_rcnn:
|
||||
rpn_roi = filter_roi(rpn_roi, max_overlap)
|
||||
bbox = paddle.concat([rpn_roi, gt_bbox])
|
||||
|
||||
# Step1: label bbox
|
||||
matches, match_labels, matched_vals = libra_label_box(
|
||||
bbox, gt_bbox, gt_class, fg_thresh, bg_thresh, num_classes)
|
||||
|
||||
# Step2: sample bbox
|
||||
sampled_inds, sampled_gt_classes = libra_sample_bbox(
|
||||
matches, match_labels, matched_vals, gt_class, batch_size_per_im,
|
||||
num_classes, fg_fraction, fg_thresh, bg_thresh, num_bins,
|
||||
use_random, is_cascade_rcnn)
|
||||
|
||||
# Step3: make output
|
||||
rois_per_image = paddle.gather(bbox, sampled_inds)
|
||||
sampled_gt_ind = paddle.gather(matches, sampled_inds)
|
||||
sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
|
||||
sampled_overlap = paddle.gather(matched_vals, sampled_inds)
|
||||
|
||||
rois_per_image.stop_gradient = True
|
||||
sampled_gt_ind.stop_gradient = True
|
||||
sampled_bbox.stop_gradient = True
|
||||
sampled_overlap.stop_gradient = True
|
||||
|
||||
tgt_labels.append(sampled_gt_classes)
|
||||
tgt_bboxes.append(sampled_bbox)
|
||||
rois_with_gt.append(rois_per_image)
|
||||
sampled_max_overlaps.append(sampled_overlap)
|
||||
tgt_gt_inds.append(sampled_gt_ind)
|
||||
new_rois_num.append(paddle.shape(sampled_inds)[0:1])
|
||||
new_rois_num = paddle.concat(new_rois_num)
|
||||
# rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num
|
||||
return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num
|
||||
@@ -0,0 +1,481 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
import paddle
|
||||
from ppdet.core.workspace import register, serializable
|
||||
|
||||
from .target import rpn_anchor_target, generate_proposal_target, generate_mask_target, libra_generate_proposal_target
|
||||
import numpy as np
|
||||
|
||||
|
||||
@register
|
||||
@serializable
|
||||
class RPNTargetAssign(object):
|
||||
__shared__ = ['assign_on_cpu']
|
||||
"""
|
||||
RPN targets assignment module
|
||||
|
||||
The assignment consists of three steps:
|
||||
1. Match anchor and ground-truth box, label the anchor with foreground
|
||||
or background sample
|
||||
2. Sample anchors to keep the properly ratio between foreground and
|
||||
background
|
||||
3. Generate the targets for classification and regression branch
|
||||
|
||||
|
||||
Args:
|
||||
batch_size_per_im (int): Total number of RPN samples per image.
|
||||
default 256
|
||||
fg_fraction (float): Fraction of anchors that is labeled
|
||||
foreground, default 0.5
|
||||
positive_overlap (float): Minimum overlap required between an anchor
|
||||
and ground-truth box for the (anchor, gt box) pair to be
|
||||
a foreground sample. default 0.7
|
||||
negative_overlap (float): Maximum overlap allowed between an anchor
|
||||
and ground-truth box for the (anchor, gt box) pair to be
|
||||
a background sample. default 0.3
|
||||
ignore_thresh(float): Threshold for ignoring the is_crowd ground-truth
|
||||
if the value is larger than zero.
|
||||
use_random (bool): Use random sampling to choose foreground and
|
||||
background boxes, default true.
|
||||
assign_on_cpu (bool): In case the number of gt box is too large,
|
||||
compute IoU on CPU, default false.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
batch_size_per_im=256,
|
||||
fg_fraction=0.5,
|
||||
positive_overlap=0.7,
|
||||
negative_overlap=0.3,
|
||||
ignore_thresh=-1.,
|
||||
use_random=True,
|
||||
assign_on_cpu=False):
|
||||
super(RPNTargetAssign, self).__init__()
|
||||
self.batch_size_per_im = batch_size_per_im
|
||||
self.fg_fraction = fg_fraction
|
||||
self.positive_overlap = positive_overlap
|
||||
self.negative_overlap = negative_overlap
|
||||
self.ignore_thresh = ignore_thresh
|
||||
self.use_random = use_random
|
||||
self.assign_on_cpu = assign_on_cpu
|
||||
|
||||
def __call__(self, inputs, anchors):
|
||||
"""
|
||||
inputs: ground-truth instances.
|
||||
anchor_box (Tensor): [num_anchors, 4], num_anchors are all anchors in all feature maps.
|
||||
"""
|
||||
gt_boxes = inputs['gt_bbox']
|
||||
is_crowd = inputs.get('is_crowd', None)
|
||||
batch_size = len(gt_boxes)
|
||||
tgt_labels, tgt_bboxes, tgt_deltas = rpn_anchor_target(
|
||||
anchors,
|
||||
gt_boxes,
|
||||
self.batch_size_per_im,
|
||||
self.positive_overlap,
|
||||
self.negative_overlap,
|
||||
self.fg_fraction,
|
||||
self.use_random,
|
||||
batch_size,
|
||||
self.ignore_thresh,
|
||||
is_crowd,
|
||||
assign_on_cpu=self.assign_on_cpu)
|
||||
norm = self.batch_size_per_im * batch_size
|
||||
|
||||
return tgt_labels, tgt_bboxes, tgt_deltas, norm
|
||||
|
||||
|
||||
@register
|
||||
class BBoxAssigner(object):
|
||||
__shared__ = ['num_classes', 'assign_on_cpu']
|
||||
"""
|
||||
RCNN targets assignment module
|
||||
|
||||
The assignment consists of three steps:
|
||||
1. Match RoIs and ground-truth box, label the RoIs with foreground
|
||||
or background sample
|
||||
2. Sample anchors to keep the properly ratio between foreground and
|
||||
background
|
||||
3. Generate the targets for classification and regression branch
|
||||
|
||||
Args:
|
||||
batch_size_per_im (int): Total number of RoIs per image.
|
||||
default 512
|
||||
fg_fraction (float): Fraction of RoIs that is labeled
|
||||
foreground, default 0.25
|
||||
fg_thresh (float): Minimum overlap required between a RoI
|
||||
and ground-truth box for the (roi, gt box) pair to be
|
||||
a foreground sample. default 0.5
|
||||
bg_thresh (float): Maximum overlap allowed between a RoI
|
||||
and ground-truth box for the (roi, gt box) pair to be
|
||||
a background sample. default 0.5
|
||||
ignore_thresh(float): Threshold for ignoring the is_crowd ground-truth
|
||||
if the value is larger than zero.
|
||||
use_random (bool): Use random sampling to choose foreground and
|
||||
background boxes, default true
|
||||
cascade_iou (list[iou]): The list of overlap to select foreground and
|
||||
background of each stage, which is only used In Cascade RCNN.
|
||||
num_classes (int): The number of class.
|
||||
assign_on_cpu (bool): In case the number of gt box is too large,
|
||||
compute IoU on CPU, default false.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
batch_size_per_im=512,
|
||||
fg_fraction=.25,
|
||||
fg_thresh=.5,
|
||||
bg_thresh=.5,
|
||||
ignore_thresh=-1.,
|
||||
use_random=True,
|
||||
cascade_iou=[0.5, 0.6, 0.7],
|
||||
num_classes=80,
|
||||
assign_on_cpu=False):
|
||||
super(BBoxAssigner, self).__init__()
|
||||
self.batch_size_per_im = batch_size_per_im
|
||||
self.fg_fraction = fg_fraction
|
||||
self.fg_thresh = fg_thresh
|
||||
self.bg_thresh = bg_thresh
|
||||
self.ignore_thresh = ignore_thresh
|
||||
self.use_random = use_random
|
||||
self.cascade_iou = cascade_iou
|
||||
self.num_classes = num_classes
|
||||
self.assign_on_cpu = assign_on_cpu
|
||||
|
||||
def __call__(self,
|
||||
rpn_rois,
|
||||
rpn_rois_num,
|
||||
inputs,
|
||||
stage=0,
|
||||
is_cascade=False,
|
||||
add_gt_as_proposals=True):
|
||||
gt_classes = inputs['gt_class']
|
||||
gt_boxes = inputs['gt_bbox']
|
||||
is_crowd = inputs.get('is_crowd', None)
|
||||
# rois, tgt_labels, tgt_bboxes, tgt_gt_inds
|
||||
# new_rois_num
|
||||
outs = generate_proposal_target(
|
||||
rpn_rois, gt_classes, gt_boxes, self.batch_size_per_im,
|
||||
self.fg_fraction, self.fg_thresh, self.bg_thresh, self.num_classes,
|
||||
self.ignore_thresh, is_crowd, self.use_random, is_cascade,
|
||||
self.cascade_iou[stage], self.assign_on_cpu, add_gt_as_proposals)
|
||||
rois = outs[0]
|
||||
rois_num = outs[-1]
|
||||
# tgt_labels, tgt_bboxes, tgt_gt_inds
|
||||
targets = outs[1:4]
|
||||
return rois, rois_num, targets
|
||||
|
||||
|
||||
@register
|
||||
class BBoxLibraAssigner(object):
|
||||
__shared__ = ['num_classes']
|
||||
"""
|
||||
Libra-RCNN targets assignment module
|
||||
|
||||
The assignment consists of three steps:
|
||||
1. Match RoIs and ground-truth box, label the RoIs with foreground
|
||||
or background sample
|
||||
2. Sample anchors to keep the properly ratio between foreground and
|
||||
background
|
||||
3. Generate the targets for classification and regression branch
|
||||
|
||||
Args:
|
||||
batch_size_per_im (int): Total number of RoIs per image.
|
||||
default 512
|
||||
fg_fraction (float): Fraction of RoIs that is labeled
|
||||
foreground, default 0.25
|
||||
fg_thresh (float): Minimum overlap required between a RoI
|
||||
and ground-truth box for the (roi, gt box) pair to be
|
||||
a foreground sample. default 0.5
|
||||
bg_thresh (float): Maximum overlap allowed between a RoI
|
||||
and ground-truth box for the (roi, gt box) pair to be
|
||||
a background sample. default 0.5
|
||||
use_random (bool): Use random sampling to choose foreground and
|
||||
background boxes, default true
|
||||
cascade_iou (list[iou]): The list of overlap to select foreground and
|
||||
background of each stage, which is only used In Cascade RCNN.
|
||||
num_classes (int): The number of class.
|
||||
num_bins (int): The number of libra_sample.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
batch_size_per_im=512,
|
||||
fg_fraction=.25,
|
||||
fg_thresh=.5,
|
||||
bg_thresh=.5,
|
||||
use_random=True,
|
||||
cascade_iou=[0.5, 0.6, 0.7],
|
||||
num_classes=80,
|
||||
num_bins=3):
|
||||
super(BBoxLibraAssigner, self).__init__()
|
||||
self.batch_size_per_im = batch_size_per_im
|
||||
self.fg_fraction = fg_fraction
|
||||
self.fg_thresh = fg_thresh
|
||||
self.bg_thresh = bg_thresh
|
||||
self.use_random = use_random
|
||||
self.cascade_iou = cascade_iou
|
||||
self.num_classes = num_classes
|
||||
self.num_bins = num_bins
|
||||
|
||||
def __call__(self,
|
||||
rpn_rois,
|
||||
rpn_rois_num,
|
||||
inputs,
|
||||
stage=0,
|
||||
is_cascade=False):
|
||||
gt_classes = inputs['gt_class']
|
||||
gt_boxes = inputs['gt_bbox']
|
||||
# rois, tgt_labels, tgt_bboxes, tgt_gt_inds
|
||||
outs = libra_generate_proposal_target(
|
||||
rpn_rois, gt_classes, gt_boxes, self.batch_size_per_im,
|
||||
self.fg_fraction, self.fg_thresh, self.bg_thresh, self.num_classes,
|
||||
self.use_random, is_cascade, self.cascade_iou[stage], self.num_bins)
|
||||
rois = outs[0]
|
||||
rois_num = outs[-1]
|
||||
# tgt_labels, tgt_bboxes, tgt_gt_inds
|
||||
targets = outs[1:4]
|
||||
return rois, rois_num, targets
|
||||
|
||||
|
||||
@register
|
||||
@serializable
|
||||
class MaskAssigner(object):
|
||||
__shared__ = ['num_classes', 'mask_resolution']
|
||||
"""
|
||||
Mask targets assignment module
|
||||
|
||||
The assignment consists of three steps:
|
||||
1. Select RoIs labels with foreground.
|
||||
2. Encode the RoIs and corresponding gt polygons to generate
|
||||
mask target
|
||||
|
||||
Args:
|
||||
num_classes (int): The number of class
|
||||
mask_resolution (int): The resolution of mask target, default 14
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes=80, mask_resolution=14):
|
||||
super(MaskAssigner, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.mask_resolution = mask_resolution
|
||||
|
||||
def __call__(self, rois, tgt_labels, tgt_gt_inds, inputs):
|
||||
gt_segms = inputs['gt_poly']
|
||||
|
||||
outs = generate_mask_target(gt_segms, rois, tgt_labels, tgt_gt_inds,
|
||||
self.num_classes, self.mask_resolution)
|
||||
|
||||
# mask_rois, mask_rois_num, tgt_classes, tgt_masks, mask_index, tgt_weights
|
||||
return outs
|
||||
|
||||
|
||||
@register
|
||||
class RBoxAssigner(object):
|
||||
"""
|
||||
assigner of rbox
|
||||
Args:
|
||||
pos_iou_thr (float): threshold of pos samples
|
||||
neg_iou_thr (float): threshold of neg samples
|
||||
min_iou_thr (float): the min threshold of samples
|
||||
ignore_iof_thr (int): the ignored threshold
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pos_iou_thr=0.5,
|
||||
neg_iou_thr=0.4,
|
||||
min_iou_thr=0.0,
|
||||
ignore_iof_thr=-2):
|
||||
super(RBoxAssigner, self).__init__()
|
||||
|
||||
self.pos_iou_thr = pos_iou_thr
|
||||
self.neg_iou_thr = neg_iou_thr
|
||||
self.min_iou_thr = min_iou_thr
|
||||
self.ignore_iof_thr = ignore_iof_thr
|
||||
|
||||
def anchor_valid(self, anchors):
|
||||
"""
|
||||
|
||||
Args:
|
||||
anchor: M x 4
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if anchors.ndim == 3:
|
||||
anchors = anchors.reshape(-1, anchors.shape[-1])
|
||||
assert anchors.ndim == 2
|
||||
anchor_num = anchors.shape[0]
|
||||
anchor_valid = np.ones((anchor_num), np.int32)
|
||||
anchor_inds = np.arange(anchor_num)
|
||||
return anchor_inds
|
||||
|
||||
def rbox2delta(self,
|
||||
proposals,
|
||||
gt,
|
||||
means=[0, 0, 0, 0, 0],
|
||||
stds=[1, 1, 1, 1, 1]):
|
||||
"""
|
||||
Args:
|
||||
proposals: tensor [N, 5]
|
||||
gt: gt [N, 5]
|
||||
means: means [5]
|
||||
stds: stds [5]
|
||||
Returns:
|
||||
|
||||
"""
|
||||
proposals = proposals.astype(np.float64)
|
||||
|
||||
PI = np.pi
|
||||
|
||||
gt_widths = gt[..., 2]
|
||||
gt_heights = gt[..., 3]
|
||||
gt_angle = gt[..., 4]
|
||||
|
||||
proposals_widths = proposals[..., 2]
|
||||
proposals_heights = proposals[..., 3]
|
||||
proposals_angle = proposals[..., 4]
|
||||
|
||||
coord = gt[..., 0:2] - proposals[..., 0:2]
|
||||
dx = (np.cos(proposals[..., 4]) * coord[..., 0] +
|
||||
np.sin(proposals[..., 4]) * coord[..., 1]) / proposals_widths
|
||||
dy = (-np.sin(proposals[..., 4]) * coord[..., 0] +
|
||||
np.cos(proposals[..., 4]) * coord[..., 1]) / proposals_heights
|
||||
dw = np.log(gt_widths / proposals_widths)
|
||||
dh = np.log(gt_heights / proposals_heights)
|
||||
da = (gt_angle - proposals_angle)
|
||||
|
||||
da = (da + PI / 4) % PI - PI / 4
|
||||
da /= PI
|
||||
|
||||
deltas = np.stack([dx, dy, dw, dh, da], axis=-1)
|
||||
means = np.array(means, dtype=deltas.dtype)
|
||||
stds = np.array(stds, dtype=deltas.dtype)
|
||||
deltas = (deltas - means) / stds
|
||||
deltas = deltas.astype(np.float32)
|
||||
return deltas
|
||||
|
||||
def assign_anchor(self,
|
||||
anchors,
|
||||
gt_bboxes,
|
||||
gt_labels,
|
||||
pos_iou_thr,
|
||||
neg_iou_thr,
|
||||
min_iou_thr=0.0,
|
||||
ignore_iof_thr=-2):
|
||||
assert anchors.shape[1] == 4 or anchors.shape[1] == 5
|
||||
assert gt_bboxes.shape[1] == 4 or gt_bboxes.shape[1] == 5
|
||||
anchors_xc_yc = anchors
|
||||
gt_bboxes_xc_yc = gt_bboxes
|
||||
|
||||
# calc rbox iou
|
||||
anchors_xc_yc = anchors_xc_yc.astype(np.float32)
|
||||
gt_bboxes_xc_yc = gt_bboxes_xc_yc.astype(np.float32)
|
||||
anchors_xc_yc = paddle.to_tensor(anchors_xc_yc)
|
||||
gt_bboxes_xc_yc = paddle.to_tensor(gt_bboxes_xc_yc)
|
||||
|
||||
try:
|
||||
from ext_op import rbox_iou
|
||||
except Exception as e:
|
||||
print("import custom_ops error, try install ext_op " \
|
||||
"following ppdet/ext_op/README.md", e)
|
||||
sys.stdout.flush()
|
||||
sys.exit(-1)
|
||||
|
||||
iou = rbox_iou(gt_bboxes_xc_yc, anchors_xc_yc)
|
||||
iou = iou.numpy()
|
||||
iou = iou.T
|
||||
|
||||
# every gt's anchor's index
|
||||
gt_bbox_anchor_inds = iou.argmax(axis=0)
|
||||
gt_bbox_anchor_iou = iou[gt_bbox_anchor_inds, np.arange(iou.shape[1])]
|
||||
gt_bbox_anchor_iou_inds = np.where(iou == gt_bbox_anchor_iou)[0]
|
||||
|
||||
# every anchor's gt bbox's index
|
||||
anchor_gt_bbox_inds = iou.argmax(axis=1)
|
||||
anchor_gt_bbox_iou = iou[np.arange(iou.shape[0]), anchor_gt_bbox_inds]
|
||||
|
||||
# (1) set labels=-2 as default
|
||||
labels = np.ones((iou.shape[0], ), dtype=np.int32) * ignore_iof_thr
|
||||
|
||||
# (2) assign ignore
|
||||
labels[anchor_gt_bbox_iou < min_iou_thr] = ignore_iof_thr
|
||||
|
||||
# (3) assign neg_ids -1
|
||||
assign_neg_ids1 = anchor_gt_bbox_iou >= min_iou_thr
|
||||
assign_neg_ids2 = anchor_gt_bbox_iou < neg_iou_thr
|
||||
assign_neg_ids = np.logical_and(assign_neg_ids1, assign_neg_ids2)
|
||||
labels[assign_neg_ids] = -1
|
||||
|
||||
# anchor_gt_bbox_iou_inds
|
||||
# (4) assign max_iou as pos_ids >=0
|
||||
anchor_gt_bbox_iou_inds = anchor_gt_bbox_inds[gt_bbox_anchor_iou_inds]
|
||||
# gt_bbox_anchor_iou_inds = np.logical_and(gt_bbox_anchor_iou_inds, anchor_gt_bbox_iou >= min_iou_thr)
|
||||
labels[gt_bbox_anchor_iou_inds] = gt_labels[anchor_gt_bbox_iou_inds]
|
||||
|
||||
# (5) assign >= pos_iou_thr as pos_ids
|
||||
iou_pos_iou_thr_ids = anchor_gt_bbox_iou >= pos_iou_thr
|
||||
iou_pos_iou_thr_ids_box_inds = anchor_gt_bbox_inds[iou_pos_iou_thr_ids]
|
||||
labels[iou_pos_iou_thr_ids] = gt_labels[iou_pos_iou_thr_ids_box_inds]
|
||||
return anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels
|
||||
|
||||
def __call__(self, anchors, gt_bboxes, gt_labels, is_crowd):
|
||||
|
||||
assert anchors.ndim == 2
|
||||
assert anchors.shape[1] == 5
|
||||
assert gt_bboxes.ndim == 2
|
||||
assert gt_bboxes.shape[1] == 5
|
||||
|
||||
pos_iou_thr = self.pos_iou_thr
|
||||
neg_iou_thr = self.neg_iou_thr
|
||||
min_iou_thr = self.min_iou_thr
|
||||
ignore_iof_thr = self.ignore_iof_thr
|
||||
|
||||
anchor_num = anchors.shape[0]
|
||||
|
||||
gt_bboxes = gt_bboxes
|
||||
is_crowd_slice = is_crowd
|
||||
not_crowd_inds = np.where(is_crowd_slice == 0)
|
||||
|
||||
# Step1: match anchor and gt_bbox
|
||||
anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels = self.assign_anchor(
|
||||
anchors, gt_bboxes,
|
||||
gt_labels.reshape(-1), pos_iou_thr, neg_iou_thr, min_iou_thr,
|
||||
ignore_iof_thr)
|
||||
|
||||
# Step2: sample anchor
|
||||
pos_inds = np.where(labels >= 0)[0]
|
||||
neg_inds = np.where(labels == -1)[0]
|
||||
|
||||
# Step3: make output
|
||||
anchors_num = anchors.shape[0]
|
||||
bbox_targets = np.zeros_like(anchors)
|
||||
bbox_weights = np.zeros_like(anchors)
|
||||
bbox_gt_bboxes = np.zeros_like(anchors)
|
||||
pos_labels = np.zeros(anchors_num, dtype=np.int32)
|
||||
pos_labels_weights = np.zeros(anchors_num, dtype=np.float32)
|
||||
|
||||
pos_sampled_anchors = anchors[pos_inds]
|
||||
pos_sampled_gt_boxes = gt_bboxes[anchor_gt_bbox_inds[pos_inds]]
|
||||
if len(pos_inds) > 0:
|
||||
pos_bbox_targets = self.rbox2delta(pos_sampled_anchors,
|
||||
pos_sampled_gt_boxes)
|
||||
bbox_targets[pos_inds, :] = pos_bbox_targets
|
||||
bbox_gt_bboxes[pos_inds, :] = pos_sampled_gt_boxes
|
||||
bbox_weights[pos_inds, :] = 1.0
|
||||
|
||||
pos_labels[pos_inds] = labels[pos_inds]
|
||||
pos_labels_weights[pos_inds] = 1.0
|
||||
|
||||
if len(neg_inds) > 0:
|
||||
pos_labels_weights[neg_inds] = 1.0
|
||||
return (pos_labels, pos_labels_weights, bbox_targets, bbox_weights,
|
||||
bbox_gt_bboxes, pos_inds, neg_inds)
|
||||
Reference in New Issue
Block a user