更换文档检测模型

This commit is contained in:
2024-08-27 14:42:45 +08:00
parent aea6f19951
commit 1514e09c40
2072 changed files with 254336 additions and 4967 deletions

View File

@@ -0,0 +1,5 @@
from . import rpn_head
from . import embedding_rpn_head
from .rpn_head import *
from .embedding_rpn_head import *

View File

@@ -0,0 +1,266 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The code is based on
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/anchor_generator.py
import math
import paddle
import paddle.nn as nn
import numpy as np
from ppdet.core.workspace import register
__all__ = ['AnchorGenerator', 'RetinaAnchorGenerator', 'S2ANetAnchorGenerator']
@register
class AnchorGenerator(nn.Layer):
    """
    Generate anchors according to the feature maps

    Args:
        anchor_sizes (list[float] | list[list[float]]): The anchor sizes at
            each feature point. list[float] means all feature levels share the
            same sizes. list[list[float]] means the anchor sizes for
            each level. The sizes stand for the scale of input size.
        aspect_ratios (list[float] | list[list[float]]): The aspect ratios at
            each feature point. list[float] means all feature levels share the
            same ratios. list[list[float]] means the aspect ratios for
            each level.
        strides (list[float]): The strides of feature maps which generate
            anchors
        variance (list[float]): Stored as ``self.variance``; it is not read
            anywhere else inside this class.
        offset (float): The offset of the coordinate of anchors, default 0.
    """

    def __init__(self,
                 anchor_sizes=[32, 64, 128, 256, 512],
                 aspect_ratios=[0.5, 1.0, 2.0],
                 strides=[16.0],
                 variance=[1.0, 1.0, 1.0, 1.0],
                 offset=0.):
        super(AnchorGenerator, self).__init__()
        self.anchor_sizes = anchor_sizes
        self.aspect_ratios = aspect_ratios
        self.strides = strides
        self.variance = variance
        # One base-anchor tensor per stride (i.e. per feature level).
        self.cell_anchors = self._calculate_anchors(len(strides))
        self.offset = offset

    def _broadcast_params(self, params, num_features):
        """Expand ``params`` into one entry per feature level.

        A flat list is shared by every level; a nested list holding a single
        entry is repeated ``num_features`` times; any other nested list is
        returned unchanged.
        """
        if isinstance(params[0], (list, tuple)):
            if len(params) == 1:
                return list(params) * num_features
            return params
        # list[float]: every level shares the same values
        return [params] * num_features

    def generate_cell_anchors(self, sizes, aspect_ratios):
        """Build the origin-centred base anchors for one feature level.

        Each anchor is ``[x0, y0, x1, y1]`` with area ``size ** 2`` and
        ``height / width == aspect_ratio``. Sizes vary in the outer loop and
        aspect ratios in the inner loop.

        Returns:
            Tensor: float32 tensor with one row of 4 values per
                (size, aspect_ratio) pair.
        """

        def _centered_box(size, ratio):
            width = math.sqrt(size**2.0 / ratio)
            height = ratio * width
            return [-width / 2.0, -height / 2.0, width / 2.0, height / 2.0]

        boxes = [
            _centered_box(size, ratio) for size in sizes
            for ratio in aspect_ratios
        ]
        return paddle.to_tensor(boxes, dtype='float32')

    def _calculate_anchors(self, num_features):
        """Create the per-level base anchors and register them as buffers."""
        level_sizes = self._broadcast_params(self.anchor_sizes, num_features)
        level_ratios = self._broadcast_params(self.aspect_ratios,
                                              num_features)
        cell_anchors = [
            self.generate_cell_anchors(sizes, ratios)
            for sizes, ratios in zip(level_sizes, level_ratios)
        ]
        # Registered as non-persistable buffers, keyed by each tensor's name.
        for base in cell_anchors:
            self.register_buffer(base.name, base, persistable=False)
        return cell_anchors

    def _create_grid_offsets(self, size, stride, offset):
        """Return flattened x / y shifts for every cell of a feature map.

        ``size[0]`` is used as the grid height and ``size[1]`` as the grid
        width; shifts start at ``offset * stride`` and advance by ``stride``.
        """
        n_rows = size[0]
        n_cols = size[1]
        xs = paddle.arange(
            offset * stride, n_cols * stride, step=stride, dtype='float32')
        ys = paddle.arange(
            offset * stride, n_rows * stride, step=stride, dtype='float32')
        grid_y, grid_x = paddle.meshgrid(ys, xs)
        flat_x = paddle.reshape(grid_x, [-1])
        flat_y = paddle.reshape(grid_y, [-1])
        return flat_x, flat_y

    def _grid_anchors(self, grid_sizes):
        """Tile the base anchors of each level over that level's grid.

        Returns:
            list[Tensor]: one ``[-1, 4]`` anchor tensor per feature level.
        """
        per_level = []
        levels = zip(grid_sizes, self.strides, self.cell_anchors)
        for size, stride, base in levels:
            shift_x, shift_y = self._create_grid_offsets(size, stride,
                                                         self.offset)
            # (x, y, x, y) shift applied to (x0, y0, x1, y1)
            shifts = paddle.stack((shift_x, shift_y, shift_x, shift_y), axis=1)
            shifts = paddle.reshape(shifts, [-1, 1, 4])
            base = paddle.reshape(base, [1, -1, 4])
            # Broadcast: every grid location gets every base anchor.
            per_level.append(paddle.reshape(shifts + base, [-1, 4]))
        return per_level

    def forward(self, input):
        """Generate anchors for a list of feature maps.

        Args:
            input (list[Tensor]): feature maps; the last two dimensions of
                each are used as the grid height and width.

        Returns:
            list[Tensor]: anchors for each feature map.
        """
        grid_sizes = []
        for feature_map in input:
            grid_sizes.append(paddle.shape(feature_map)[-2:])
        return self._grid_anchors(grid_sizes)

    @property
    def num_anchors(self):
        """
        Returns:
            int: number of anchors at every pixel
                location, on that feature map.
                For example, if at every pixel we use anchors of 3 aspect
                ratios and 5 sizes, the number of anchors is 15.
                For FPN models, `num_anchors` on every feature map is the same.
        """
        first_level = self.cell_anchors[0]
        return len(first_level)
@register
class RetinaAnchorGenerator(AnchorGenerator):
    """
    AnchorGenerator whose per-level anchor sizes are derived from the strides.

    For each stride ``s`` the sizes are
    ``s * octave_base_scale * 2 ** (i / scales_per_octave)`` for
    ``i`` in ``range(scales_per_octave)``.

    Args:
        octave_base_scale (int): base multiplier applied to every stride.
        scales_per_octave (int): number of sizes generated per level.
        aspect_ratios (list[float]): forwarded to AnchorGenerator.
        strides (list[float]): forwarded to AnchorGenerator; also drives the
            generated anchor sizes.
        variance (list[float]): forwarded to AnchorGenerator.
        offset (float): forwarded to AnchorGenerator.
    """

    def __init__(self,
                 octave_base_scale=4,
                 scales_per_octave=3,
                 aspect_ratios=[0.5, 1.0, 2.0],
                 strides=[8.0, 16.0, 32.0, 64.0, 128.0],
                 variance=[1.0, 1.0, 1.0, 1.0],
                 offset=0.0):
        # One list of sizes per stride / feature level.
        anchor_sizes = [[
            s * octave_base_scale * 2**(i / scales_per_octave)
            for i in range(scales_per_octave)
        ] for s in strides]
        super(RetinaAnchorGenerator, self).__init__(
            anchor_sizes=anchor_sizes,
            aspect_ratios=aspect_ratios,
            strides=strides,
            variance=variance,
            offset=offset)
@register
class S2ANetAnchorGenerator(nn.Layer):
    """
    AnchorGenerator by paddle

    Args:
        base_size: side length of the base box that scales / ratios are
            applied to.
        scales: values converted to a tensor and multiplied into the base
            width and height.
        ratios: values converted to a tensor; their square root scales the
            height and the reciprocal of that scales the width.
        scale_major (bool): selects the multiplication order used in
            ``gen_base_anchors``.
        ctr: optional ``(x_ctr, y_ctr)`` centre; when None the centre is
            ``0.5 * (base_size - 1)`` on both axes.
    """

    def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
        super(S2ANetAnchorGenerator, self).__init__()
        self.base_size = base_size
        self.scales = paddle.to_tensor(scales)
        self.ratios = paddle.to_tensor(ratios)
        self.scale_major = scale_major
        self.ctr = ctr
        self.base_anchors = self.gen_base_anchors()

    @property
    def num_base_anchors(self):
        """Number of rows in ``self.base_anchors``."""
        return self.base_anchors.shape[0]

    def gen_base_anchors(self):
        """Compute the rounded ``(x1, y1, x2, y2)`` base anchors."""
        w = h = self.base_size
        if self.ctr is not None:
            x_ctr, y_ctr = self.ctr
        else:
            x_ctr = 0.5 * (w - 1)
            y_ctr = 0.5 * (h - 1)

        h_ratios = paddle.sqrt(self.ratios)
        w_ratios = 1 / h_ratios
        # NOTE(review): both branches multiply the same three factors
        # elementwise and differ only in operand order; confirm whether an
        # outer product over ratios x scales was intended.
        if self.scale_major:
            ws = (w * w_ratios[:] * self.scales[:]).reshape([-1])
            hs = (h * h_ratios[:] * self.scales[:]).reshape([-1])
        else:
            ws = (w * self.scales[:] * w_ratios[:]).reshape([-1])
            hs = (h * self.scales[:] * h_ratios[:]).reshape([-1])

        half_w = 0.5 * (ws - 1)
        half_h = 0.5 * (hs - 1)
        corners = [
            x_ctr - half_w, y_ctr - half_h, x_ctr + half_w, y_ctr + half_h
        ]
        return paddle.round(paddle.stack(corners, axis=-1))

    def _meshgrid(self, x, y, row_major=True):
        """Flattened meshgrid of ``x`` and ``y``.

        Returns ``(xx, yy)`` when ``row_major`` is true, otherwise
        ``(yy, xx)``.
        """
        grid_y, grid_x = paddle.meshgrid(y, x)
        flat_y = grid_y.reshape([-1])
        flat_x = grid_x.reshape([-1])
        return (flat_x, flat_y) if row_major else (flat_y, flat_x)

    def forward(self, featmap_size, stride=16):
        """Generate rotated-box anchors for one feature map.

        Args:
            featmap_size: ``(feat_h, feat_w)`` of the feature map.
            stride (int): multiplier projecting grid indices back to the
                original area.

        Returns:
            Tensor: ``(feat_h * feat_w, 5)`` boxes as
                ``(x_ctr, y_ctr, w, h, angle)``.
        """
        feat_h = featmap_size[0]
        feat_w = featmap_size[1]
        # featmap_size*stride project it to original area
        xs = paddle.arange(0, feat_w, 1, 'int32') * stride
        ys = paddle.arange(0, feat_h, 1, 'int32') * stride
        grid_x, grid_y = self._meshgrid(xs, ys)
        shifts = paddle.stack([grid_x, grid_y, grid_x, grid_y], axis=-1)

        shifted = self.base_anchors[:, :] + shifts[:, :]
        shifted = shifted.cast(paddle.float32).reshape([feat_h * feat_w, 4])
        return self.rect2rbox(shifted)

    def valid_flags(self, featmap_size, valid_size):
        """Flag anchors whose grid cell lies inside the valid region.

        Args:
            featmap_size: ``(feat_h, feat_w)``.
            valid_size: ``(valid_h, valid_w)``; must not exceed
                ``featmap_size`` on either axis.

        Returns:
            Tensor: flattened int32 flags, repeated ``num_base_anchors``
                times per grid cell.
        """
        feat_h, feat_w = featmap_size
        valid_h, valid_w = valid_size
        assert valid_h <= feat_h and valid_w <= feat_w
        col_flags = paddle.zeros([feat_w], dtype='int32')
        row_flags = paddle.zeros([feat_h], dtype='int32')
        col_flags[:valid_w] = 1
        row_flags[:valid_h] = 1
        grid_cols, grid_rows = self._meshgrid(col_flags, row_flags)
        flags = grid_cols & grid_rows
        flags = paddle.reshape(flags, [-1, 1])
        flags = paddle.expand(flags, [-1, self.num_base_anchors])
        return flags.reshape([-1])

    def rect2rbox(self, bboxes):
        """
        :param bboxes: shape (L, 4) (xmin, ymin, xmax, ymax)
        :return: dbboxes: shape (L, 5) (x_ctr, y_ctr, w, h, angle)
        """
        x1, y1, x2, y2 = paddle.split(bboxes, 4, axis=-1)
        center_x = (x1 + x2) / 2.0
        center_y = (y1 + y2) / 2.0
        span_x = paddle.abs(x2 - x1)
        span_y = paddle.abs(y2 - y1)
        # The longer edge becomes the width, the shorter one the height.
        long_edge = paddle.maximum(span_x, span_y)
        short_edge = paddle.minimum(span_x, span_y)
        # set angle: pi/2 where the box is taller than wide, otherwise 0
        is_tall = paddle.cast(span_x < span_y, paddle.float32)
        angle = is_tall * np.pi / 2.0
        return paddle.concat(
            (center_x, center_y, long_edge, short_edge, angle), axis=-1)

View File

@@ -0,0 +1,63 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is referenced from: https://github.com/open-mmlab/mmdetection
import paddle
from paddle import nn
from ppdet.core.workspace import register
__all__ = ['EmbeddingRPNHead']
@register
class EmbeddingRPNHead(nn.Layer):
    """
    Proposal head whose boxes and features are learnable embeddings.

    Args:
        num_proposals (int): number of rows in both embedding tables.
        proposal_embedding_dim (int): width of the proposal feature
            embedding.
    """
    __shared__ = ['proposal_embedding_dim']

    def __init__(self, num_proposals, proposal_embedding_dim=256):
        super(EmbeddingRPNHead, self).__init__()
        self.num_proposals = num_proposals
        self.proposal_embedding_dim = proposal_embedding_dim
        self._init_layers()
        self._init_weights()

    def _init_layers(self):
        """Create the box (4 values) and feature embedding tables."""
        count = self.num_proposals
        self.init_proposal_bboxes = nn.Embedding(count, 4)
        self.init_proposal_features = nn.Embedding(
            count, self.proposal_embedding_dim)

    def _init_weights(self):
        """Initialise every proposal box to ``(0.5, 0.5, 1.0, 1.0)``."""
        box_weight = self.init_proposal_bboxes.weight
        start_boxes = paddle.empty_like(box_weight)
        start_boxes[:, :2] = 0.5  # first two columns (centre in forward())
        start_boxes[:, 2:] = 1.0  # last two columns (size in forward())
        box_weight.set_value(start_boxes)

    @staticmethod
    def bbox_cxcywh_to_xyxy(x):
        """Convert ``(cx, cy, w, h)`` boxes to ``(x1, y1, x2, y2)``."""
        center, size = paddle.split(x, 2, axis=-1)
        half = 0.5 * size
        return paddle.concat([center - half, center + half], axis=-1)

    def forward(self, img_whwh):
        """Return the proposals and features for a batch.

        Args:
            img_whwh (Tensor): per-image scale multiplied into the xyxy
                boxes; its first dimension is used as the batch size.

        Returns:
            tuple: ``(proposal_bboxes, proposal_features)``.
        """
        boxes = self.bbox_cxcywh_to_xyxy(
            self.init_proposal_bboxes.weight.clone())
        # Broadcast the shared boxes over the batch and scale per image.
        boxes = boxes.unsqueeze(0) * img_whwh.unsqueeze(1)

        features = self.init_proposal_features.weight.clone()
        batch = img_whwh.shape[0]
        features = features.unsqueeze(0).tile([batch, 1, 1])
        return boxes, features

View File

@@ -0,0 +1,83 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from ppdet.core.workspace import register, serializable
from .. import ops
@register
@serializable
class ProposalGenerator(object):
    """
    Proposal generation module

    For more details, please refer to the document of generate_proposals
    in ppdet/modeling/ops.py

    Args:
        pre_nms_top_n (int): Number of total bboxes to be kept per
            image before NMS. default 12000
        post_nms_top_n (int): Number of total bboxes to be kept per
            image after NMS. default 2000
        nms_thresh (float): Threshold in NMS. default 0.5
        min_size (float): Remove predicted boxes with either height or
            width < min_size. default 0.1
        eta (float): Apply in adaptive NMS, if adaptive `threshold > 0.5`,
            `adaptive_threshold = adaptive_threshold * eta` in each iteration.
            default 1.
        topk_after_collect (bool): whether to adopt topk after batch
            collection. If topk_after_collect is true, box filter will not be
            used after NMS at each image in proposal generation. default false
    """

    def __init__(self,
                 pre_nms_top_n=12000,
                 post_nms_top_n=2000,
                 nms_thresh=.5,
                 min_size=.1,
                 eta=1.,
                 topk_after_collect=False):
        super(ProposalGenerator, self).__init__()
        self.pre_nms_top_n = pre_nms_top_n
        self.post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.min_size = min_size
        self.eta = eta
        self.topk_after_collect = topk_after_collect

    def __call__(self, scores, bbox_deltas, anchors, im_shape):
        """Run proposal generation.

        Returns:
            tuple: ``(rpn_rois, rpn_rois_prob, rpn_rois_num,
                post_nms_top_n)`` where the last item is this generator's
                configured ``post_nms_top_n``.
        """
        # When top-k is applied after collection, keep pre_nms_top_n boxes
        # after NMS here instead of post_nms_top_n.
        if self.topk_after_collect:
            keep_after_nms = self.pre_nms_top_n
        else:
            keep_after_nms = self.post_nms_top_n

        unit_variances = paddle.ones_like(anchors)

        # Use the framework op when this paddle build provides it, else the
        # project-local implementation.
        vision_ops = paddle.vision.ops
        if hasattr(vision_ops, "generate_proposals"):
            proposal_fn = getattr(vision_ops, "generate_proposals")
        else:
            proposal_fn = ops.generate_proposals

        rois, rois_prob, rois_num = proposal_fn(
            scores,
            bbox_deltas,
            im_shape,
            anchors,
            unit_variances,
            pre_nms_top_n=self.pre_nms_top_n,
            post_nms_top_n=keep_after_nms,
            nms_thresh=self.nms_thresh,
            min_size=self.min_size,
            eta=self.eta,
            return_rois_num=True)
        return rois, rois_prob, rois_num, self.post_nms_top_n

View File

@@ -0,0 +1,313 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal
from ppdet.core.workspace import register
from .anchor_generator import AnchorGenerator
from .target_layer import RPNTargetAssign
from .proposal_generator import ProposalGenerator
from ..cls_utils import _get_class_default_kwargs
class RPNFeat(nn.Layer):
    """
    Feature extraction in RPN head

    Args:
        in_channel (int): Input channel
        out_channel (int): Output channel
    """

    def __init__(self, in_channel=1024, out_channel=1024):
        super(RPNFeat, self).__init__()
        weight_init = paddle.ParamAttr(initializer=Normal(mean=0., std=0.01))
        # rpn feat is shared with each level
        self.rpn_conv = nn.Conv2D(
            in_channels=in_channel,
            out_channels=out_channel,
            kernel_size=3,
            padding=1,
            weight_attr=weight_init)
        self.rpn_conv.skip_quant = True

    def forward(self, feats):
        """Apply the shared 3x3 conv followed by ReLU to every feature map.

        Args:
            feats (list[Tensor]): feature maps, one per level.

        Returns:
            list[Tensor]: activated features in the same order.
        """
        return [F.relu(self.rpn_conv(feat)) for feat in feats]
@register
class RPNHead(nn.Layer):
    """
    Region Proposal Network

    Args:
        anchor_generator (dict): configure of anchor generation
        rpn_target_assign (dict): configure of rpn targets assignment
        train_proposal (dict): configure of proposals generation
            at the stage of training
        test_proposal (dict): configure of proposals generation
            at the stage of prediction
        in_channel (int): channel of input feature maps which can be
            derived by from_config
        export_onnx (bool): when True, proposals are generated through the
            batch-size-1 path used for ONNX export
        loss_rpn_bbox (callable|None): regression loss; when None an L1
            sum is used
    """
    __shared__ = ['export_onnx']
    __inject__ = ['loss_rpn_bbox']

    def __init__(self,
                 anchor_generator=_get_class_default_kwargs(AnchorGenerator),
                 rpn_target_assign=_get_class_default_kwargs(RPNTargetAssign),
                 train_proposal=_get_class_default_kwargs(ProposalGenerator,
                                                          12000, 2000),
                 test_proposal=_get_class_default_kwargs(ProposalGenerator),
                 in_channel=1024,
                 export_onnx=False,
                 loss_rpn_bbox=None):
        super(RPNHead, self).__init__()
        # Each component may arrive as a ready object or as a kwargs dict.
        if isinstance(anchor_generator, dict):
            anchor_generator = AnchorGenerator(**anchor_generator)
        if isinstance(rpn_target_assign, dict):
            rpn_target_assign = RPNTargetAssign(**rpn_target_assign)
        if isinstance(train_proposal, dict):
            train_proposal = ProposalGenerator(**train_proposal)
        if isinstance(test_proposal, dict):
            test_proposal = ProposalGenerator(**test_proposal)

        self.anchor_generator = anchor_generator
        self.rpn_target_assign = rpn_target_assign
        self.train_proposal = train_proposal
        self.test_proposal = test_proposal
        self.export_onnx = export_onnx
        self.loss_rpn_bbox = loss_rpn_bbox

        num_anchors = self.anchor_generator.num_anchors
        self.rpn_feat = RPNFeat(in_channel, in_channel)

        def _shared_1x1_conv(out_channels):
            # rpn head is shared with each level
            conv = nn.Conv2D(
                in_channels=in_channel,
                out_channels=out_channels,
                kernel_size=1,
                padding=0,
                weight_attr=paddle.ParamAttr(initializer=Normal(
                    mean=0., std=0.01)))
            conv.skip_quant = True
            return conv

        # rpn roi classification scores
        self.rpn_rois_score = _shared_1x1_conv(num_anchors)
        # rpn roi bbox regression deltas
        self.rpn_rois_delta = _shared_1x1_conv(4 * num_anchors)

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Derive ``in_channel`` from the (first) input shape."""
        # FPN share same rpn head
        if isinstance(input_shape, (list, tuple)):
            input_shape = input_shape[0]
        return {'in_channel': input_shape.channels}

    def forward(self, feats, inputs):
        """Predict scores / deltas, generate proposals and (in training) loss.

        Returns:
            tuple: ``(rois, rois_num, loss)``; ``loss`` is None outside
                training.
        """
        rpn_feats = self.rpn_feat(feats)
        scores, deltas = [], []
        for level_feat in rpn_feats:
            scores.append(self.rpn_rois_score(level_feat))
            deltas.append(self.rpn_rois_delta(level_feat))

        anchors = self.anchor_generator(rpn_feats)
        rois, rois_num = self._gen_proposal(scores, deltas, anchors, inputs)

        loss = None
        if self.training:
            loss = self.get_loss(scores, deltas, anchors, inputs)
        return rois, rois_num, loss

    def _gen_proposal(self, scores, bbox_deltas, anchors, inputs):
        """
        scores (list[Tensor]): Multi-level scores prediction
        bbox_deltas (list[Tensor]): Multi-level deltas prediction
        anchors (list[Tensor]): Multi-level anchors
        inputs (dict): ground truth info
        """
        prop_gen = self.train_proposal if self.training else self.test_proposal
        im_shape = inputs['im_shape']

        # Collect multi-level proposals for each batch
        # Get 'topk' of them as final output
        if self.export_onnx:
            # bs = 1 when exporting onnx
            onnx_rois_per_level = []
            onnx_prob_per_level = []
            onnx_num_per_level = []
            for level_score, level_delta, level_anchor in zip(
                    scores, bbox_deltas, anchors):
                onnx_level_rois, onnx_level_prob, onnx_level_num, onnx_post_nms_top_n = prop_gen(
                    scores=level_score[0:1],
                    bbox_deltas=level_delta[0:1],
                    anchors=level_anchor,
                    im_shape=im_shape[0:1])
                onnx_rois_per_level.append(onnx_level_rois)
                onnx_prob_per_level.append(onnx_level_prob)
                onnx_num_per_level.append(onnx_level_num)

            onnx_all_rois = paddle.concat(onnx_rois_per_level)
            onnx_all_prob = paddle.concat(onnx_prob_per_level).flatten()

            # k = min(post_nms_top_n, number of collected rois)
            onnx_top_n = paddle.to_tensor(onnx_post_nms_top_n).cast('int32')
            onnx_num_rois = paddle.shape(onnx_all_prob)[0].cast('int32')
            k = paddle.minimum(onnx_top_n, onnx_num_rois)
            onnx_topk_prob, onnx_topk_inds = paddle.topk(onnx_all_prob, k)
            onnx_topk_rois = paddle.gather(onnx_all_rois, onnx_topk_inds)
            # TODO(wangguanzhong): Now bs_rois_collect in export_onnx is moved outside conditional branch
            # due to problems in dy2static of paddle. Will fix it when updating paddle framework.
            # bs_rois_collect = [onnx_topk_rois]
            # bs_rois_num_collect = paddle.shape(onnx_topk_rois)[0]
        else:
            bs_rois_collect = []
            bs_rois_num_collect = []
            batch_size = paddle.slice(paddle.shape(im_shape), [0], [0], [1])

            # Generate proposals for each level and each batch.
            # Discard batch-computing to avoid sorting bbox cross different batches.
            for i in range(batch_size):
                level_rois = []
                level_probs = []
                level_nums = []
                for level_score, level_delta, level_anchor in zip(
                        scores, bbox_deltas, anchors):
                    rois_i, prob_i, num_i, post_nms_top_n = prop_gen(
                        scores=level_score[i:i + 1],
                        bbox_deltas=level_delta[i:i + 1],
                        anchors=level_anchor,
                        im_shape=im_shape[i:i + 1])
                    level_rois.append(rois_i)
                    level_probs.append(prob_i)
                    level_nums.append(num_i)

                if len(scores) > 1:
                    # Multi-level: merge, then keep at most post_nms_top_n.
                    merged_rois = paddle.concat(level_rois)
                    merged_prob = paddle.concat(level_probs).flatten()
                    num_rois = paddle.shape(merged_prob)[0].cast('int32')
                    if num_rois > post_nms_top_n:
                        topk_prob, topk_inds = paddle.topk(merged_prob,
                                                           post_nms_top_n)
                        topk_rois = paddle.gather(merged_rois, topk_inds)
                    else:
                        topk_rois = merged_rois
                        topk_prob = merged_prob
                else:
                    topk_rois = level_rois[0]
                    topk_prob = level_probs[0].flatten()

                bs_rois_collect.append(topk_rois)
                bs_rois_num_collect.append(paddle.shape(topk_rois)[0:1])

            bs_rois_num_collect = paddle.concat(bs_rois_num_collect)

        if self.export_onnx:
            output_rois = [onnx_topk_rois]
            output_rois_num = paddle.shape(onnx_topk_rois)[0]
        else:
            output_rois = bs_rois_collect
            output_rois_num = bs_rois_num_collect
        return output_rois, output_rois_num

    def get_loss(self, pred_scores, pred_deltas, anchors, inputs):
        """
        pred_scores (list[Tensor]): Multi-level scores prediction
        pred_deltas (list[Tensor]): Multi-level deltas prediction
        anchors (list[Tensor]): Multi-level anchors
        inputs (dict): ground truth info, including im, gt_bbox, gt_score
        """

        def _flatten_levels(level_preds, last_dim):
            # Move the second axis last, then flatten the middle axes per
            # image and join the levels.
            flat = [
                paddle.reshape(
                    paddle.transpose(
                        v, perm=[0, 2, 3, 1]),
                    shape=(v.shape[0], -1, last_dim)) for v in level_preds
            ]
            return paddle.concat(flat, axis=1)

        anchors = paddle.concat(
            [paddle.reshape(a, shape=(-1, 4)) for a in anchors])
        scores = _flatten_levels(pred_scores, 1)
        deltas = _flatten_levels(pred_deltas, 4)

        score_tgt, bbox_tgt, loc_tgt, norm = self.rpn_target_assign(inputs,
                                                                    anchors)

        scores = paddle.reshape(x=scores, shape=(-1, ))
        deltas = paddle.reshape(x=deltas, shape=(-1, 4))

        score_tgt = paddle.concat(score_tgt)
        score_tgt.stop_gradient = True

        # label 1 -> positive; label >= 0 -> takes part in the cls loss
        pos_ind = paddle.nonzero(score_tgt == 1)
        valid_ind = paddle.nonzero(score_tgt >= 0)

        # cls loss
        if valid_ind.shape[0] != 0:
            score_pred = paddle.gather(scores, valid_ind)
            score_label = paddle.gather(score_tgt, valid_ind).cast('float32')
            score_label.stop_gradient = True
            loss_rpn_cls = F.binary_cross_entropy_with_logits(
                logit=score_pred, label=score_label, reduction="sum")
        else:
            loss_rpn_cls = paddle.zeros([1], dtype='float32')

        # reg loss
        if pos_ind.shape[0] != 0:
            loc_pred = paddle.gather(deltas, pos_ind)
            loc_tgt = paddle.concat(loc_tgt)
            loc_tgt = paddle.gather(loc_tgt, pos_ind)
            loc_tgt.stop_gradient = True

            if self.loss_rpn_bbox is None:
                loss_rpn_reg = paddle.abs(loc_pred - loc_tgt).sum()
            else:
                loss_rpn_reg = self.loss_rpn_bbox(loc_pred, loc_tgt).sum()
        else:
            loss_rpn_reg = paddle.zeros([1], dtype='float32')

        return {
            'loss_rpn_cls': loss_rpn_cls / norm,
            'loss_rpn_reg': loss_rpn_reg / norm
        }

View File

@@ -0,0 +1,678 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from ..bbox_utils import bbox2delta, bbox_overlaps
def rpn_anchor_target(anchors,
                      gt_boxes,
                      rpn_batch_size_per_im,
                      rpn_positive_overlap,
                      rpn_negative_overlap,
                      rpn_fg_fraction,
                      use_random=True,
                      batch_size=1,
                      ignore_thresh=-1,
                      is_crowd=None,
                      weights=[1., 1., 1., 1.],
                      assign_on_cpu=False):
    """Build per-image RPN training targets for a batch.

    For every image the anchors are matched against the ground-truth boxes,
    a subset of foreground / background anchors is sampled, and labels
    (1 foreground, 0 background, -1 ignored) plus regression targets are
    produced.

    Returns:
        tuple: ``(tgt_labels, tgt_bboxes, tgt_deltas)``, three lists with one
            tensor per image.
    """
    tgt_labels, tgt_bboxes, tgt_deltas = [], [], []
    for img_idx in range(batch_size):
        gt_bbox = gt_boxes[img_idx]
        crowd_flags = is_crowd[img_idx] if is_crowd else None

        # Step1: match anchor and gt_bbox
        matches, match_labels = label_box(
            anchors, gt_bbox, rpn_positive_overlap, rpn_negative_overlap, True,
            ignore_thresh, crowd_flags, assign_on_cpu)

        # Step2: sample anchor
        fg_inds, bg_inds = subsample_labels(match_labels, rpn_batch_size_per_im,
                                            rpn_fg_fraction, 0, use_random)

        # Fill with the ignore label (-1), then set positive and negative labels
        labels = paddle.full(match_labels.shape, -1, dtype='int32')
        if bg_inds.shape[0] > 0:
            labels = paddle.scatter(labels, bg_inds, paddle.zeros_like(bg_inds))
        if fg_inds.shape[0] > 0:
            labels = paddle.scatter(labels, fg_inds, paddle.ones_like(fg_inds))

        # Step3: make output
        num_matches = matches.shape[0]
        if gt_bbox.shape[0] != 0:
            matched_gt_boxes = paddle.gather(gt_bbox, matches)
            tgt_delta = bbox2delta(anchors, matched_gt_boxes, weights)
            matched_gt_boxes.stop_gradient = True
            tgt_delta.stop_gradient = True
        else:
            # No ground truth in this image: emit all-zero targets.
            matched_gt_boxes = paddle.zeros([num_matches, 4])
            tgt_delta = paddle.zeros([num_matches, 4])
        labels.stop_gradient = True

        tgt_labels.append(labels)
        tgt_bboxes.append(matched_gt_boxes)
        tgt_deltas.append(tgt_delta)

    return tgt_labels, tgt_bboxes, tgt_deltas
def label_box(anchors,
              gt_boxes,
              positive_overlap,
              negative_overlap,
              allow_low_quality,
              ignore_thresh,
              is_crowd=None,
              assign_on_cpu=False):
    """Match every anchor to a ground-truth box and label it.

    Args:
        anchors (Tensor): candidate boxes; ``anchors.shape[0]`` is used as the
            number of anchors.
        gt_boxes (Tensor): ground-truth boxes; ``gt_boxes.shape[0]`` is used
            as the number of ground truths.
        positive_overlap (float): anchors whose best IoU is >= this value are
            labelled 1.
        negative_overlap (float): anchors whose best IoU is > -1 and < this
            value are labelled 0.
        allow_low_quality (bool): when True, every anchor that attains a
            ground truth's highest (non-zero) IoU is also labelled 1.
        ignore_thresh (float): when > 0, anchors whose IoU with a crowd
            ground truth exceeds it are marked ignored.
        is_crowd (Tensor|None): crowd flags per ground truth
            (assumed broadcastable against the IoU matrix - TODO confirm).
        assign_on_cpu (bool): compute the IoU matrix with the CPU set as the
            current device, then restore the previous device.

    Returns:
        tuple: ``(matches, match_labels)``, both flattened; ``matches`` holds
            the index of the best ground truth per anchor and
            ``match_labels`` is 1 / 0 / -1 (positive / negative / ignored).
    """
    if assign_on_cpu:
        device = paddle.device.get_device()
        paddle.set_device("cpu")
        try:
            iou = bbox_overlaps(gt_boxes, anchors)
        finally:
            # Always restore the caller's device, even if the IoU
            # computation raises; otherwise the whole process would be left
            # running on the CPU.
            paddle.set_device(device)
    else:
        iou = bbox_overlaps(gt_boxes, anchors)

    n_gt = gt_boxes.shape[0]
    if n_gt == 0 or is_crowd is None:
        n_gt_crowd = 0
    else:
        n_gt_crowd = paddle.nonzero(is_crowd).shape[0]

    if iou.shape[0] == 0 or n_gt_crowd == n_gt:
        # No truth, assign everything to background
        default_matches = paddle.full((iou.shape[1], ), 0, dtype='int64')
        default_match_labels = paddle.full((iou.shape[1], ), 0, dtype='int32')
        return default_matches, default_match_labels

    # if ignore_thresh > 0, remove anchor if it is close to
    # one of the crowded ground-truth
    if n_gt_crowd > 0:
        N_a = anchors.shape[0]
        ones = paddle.ones([N_a])
        mask = is_crowd * ones

        if ignore_thresh > 0:
            crowd_iou = iou * mask
            valid = (paddle.sum((crowd_iou > ignore_thresh).cast('int32'),
                                axis=0) > 0).cast('float32')
            # Anchors overlapping a crowd box get IoU -1 (ignored).
            iou = iou * (1 - valid) - valid

        # ignore the iou between anchor and crowded ground-truth
        iou = iou * (1 - mask) - mask

    # Best ground truth (and its IoU) for every anchor.
    matched_vals, matches = paddle.topk(iou, k=1, axis=0)
    match_labels = paddle.full(matches.shape, -1, dtype='int32')
    # set ignored anchor with iou = -1
    neg_cond = paddle.logical_and(matched_vals > -1,
                                  matched_vals < negative_overlap)
    match_labels = paddle.where(neg_cond,
                                paddle.zeros_like(match_labels), match_labels)
    match_labels = paddle.where(matched_vals >= positive_overlap,
                                paddle.ones_like(match_labels), match_labels)
    if allow_low_quality:
        # Promote anchors that are the best match of some ground truth,
        # even when their IoU is below positive_overlap.
        highest_quality_foreach_gt = iou.max(axis=1, keepdim=True)
        pred_inds_with_highest_quality = paddle.logical_and(
            iou > 0, iou == highest_quality_foreach_gt).cast('int32').sum(
                0, keepdim=True)
        match_labels = paddle.where(pred_inds_with_highest_quality > 0,
                                    paddle.ones_like(match_labels),
                                    match_labels)

    matches = matches.flatten()
    match_labels = match_labels.flatten()

    return matches, match_labels
def subsample_labels(labels,
                     num_samples,
                     fg_fraction,
                     bg_label=0,
                     use_random=True):
    """Sample foreground and background indices from ``labels``.

    Foreground entries are those that are neither -1 nor ``bg_label``;
    background entries equal ``bg_label``. At most
    ``int(num_samples * fg_fraction)`` foreground indices are taken and the
    remaining budget is filled with background indices.

    Args:
        labels (Tensor): label per candidate.
        num_samples (int): total sampling budget.
        fg_fraction (float): fraction of the budget reserved for foreground.
        bg_label (int): label value that denotes background.
        use_random (bool): pick via a random permutation when True, else take
            the leading entries.

    Returns:
        tuple: ``(fg_inds, bg_inds)`` int32 index tensors.
    """
    is_foreground = paddle.logical_and(labels != -1, labels != bg_label)
    positive = paddle.nonzero(is_foreground)
    negative = paddle.nonzero(labels == bg_label)

    fg_num = min(positive.numel(), int(num_samples * fg_fraction))
    bg_num = min(negative.numel(), num_samples - fg_num)
    if fg_num == 0 and bg_num == 0:
        fg_inds = paddle.zeros([0], dtype='int32')
        bg_inds = paddle.zeros([0], dtype='int32')
        return fg_inds, bg_inds

    def _take(candidates, count):
        # The permutation is drawn even when use_random is False, matching
        # the established random-number consumption of this function.
        flat = candidates.cast('int32').flatten()
        perm = paddle.randperm(flat.numel(), dtype='int32')
        perm = paddle.slice(perm, axes=[0], starts=[0], ends=[count])
        if use_random:
            return paddle.gather(flat, perm)
        return paddle.slice(flat, axes=[0], starts=[0], ends=[count])

    # randomly select positive and negative examples
    bg_inds = _take(negative, bg_num)
    if fg_num == 0:
        fg_inds = paddle.zeros([0], dtype='int32')
        return fg_inds, bg_inds

    fg_inds = _take(positive, fg_num)
    return fg_inds, bg_inds
def generate_proposal_target(rpn_rois,
                             gt_classes,
                             gt_boxes,
                             batch_size_per_im,
                             fg_fraction,
                             fg_thresh,
                             bg_thresh,
                             num_classes,
                             ignore_thresh=-1.,
                             is_crowd=None,
                             use_random=True,
                             is_cascade=False,
                             cascade_iou=0.5,
                             assign_on_cpu=False,
                             add_gt_as_proposals=True):
    """Label and sample RoIs per image to build second-stage targets.

    Returns:
        tuple: ``(rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds,
            new_rois_num)``; the first four are lists with one tensor per
            image and the last is the concatenated number of sampled indices
            per image.
    """
    rois_with_gt = []
    tgt_labels = []
    tgt_bboxes = []
    tgt_gt_inds = []
    new_rois_num = []

    # In cascade rcnn, the threshold for foreground and background
    # is used from cascade_iou
    if is_cascade:
        fg_thresh = cascade_iou
        bg_thresh = cascade_iou

    for img_idx, rpn_roi in enumerate(rpn_rois):
        gt_bbox = gt_boxes[img_idx]
        crowd_flags = is_crowd[img_idx] if is_crowd else None
        gt_class = paddle.squeeze(gt_classes[img_idx], axis=-1)

        # Concat RoIs and gt boxes except cascade rcnn or none gt
        bbox = rpn_roi
        if add_gt_as_proposals and gt_bbox.shape[0] > 0:
            bbox = paddle.concat([rpn_roi, gt_bbox])

        # Step1: label bbox
        matches, match_labels = label_box(bbox, gt_bbox, fg_thresh, bg_thresh,
                                          False, ignore_thresh, crowd_flags,
                                          assign_on_cpu)
        # Step2: sample bbox
        sampled_inds, sampled_gt_classes = sample_bbox(
            matches, match_labels, gt_class, batch_size_per_im, fg_fraction,
            num_classes, use_random, is_cascade)

        # Step3: make output
        # Cascade stages keep every box; otherwise keep the sampled subset.
        if is_cascade:
            rois_per_image = bbox
            sampled_gt_ind = matches
        else:
            rois_per_image = paddle.gather(bbox, sampled_inds)
            sampled_gt_ind = paddle.gather(matches, sampled_inds)

        if gt_bbox.shape[0] > 0:
            sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
        else:
            sampled_bbox = paddle.zeros(
                [rois_per_image.shape[0], 4], dtype='float32')

        rois_per_image.stop_gradient = True
        sampled_gt_ind.stop_gradient = True
        sampled_bbox.stop_gradient = True

        tgt_labels.append(sampled_gt_classes)
        tgt_bboxes.append(sampled_bbox)
        rois_with_gt.append(rois_per_image)
        tgt_gt_inds.append(sampled_gt_ind)
        new_rois_num.append(paddle.shape(sampled_inds)[0:1])

    new_rois_num = paddle.concat(new_rois_num)
    return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num
def sample_bbox(matches,
                match_labels,
                gt_classes,
                batch_size_per_im,
                fg_fraction,
                num_classes,
                use_random=True,
                is_cascade=False):
    """Assign a class to every box and sample a training subset.

    Boxes labelled 0 receive class ``num_classes`` (background), boxes
    labelled -1 receive class -1 (ignored), the rest take the class of their
    matched ground truth.

    Returns:
        tuple: ``(sampled_inds, sampled_gt_classes)``. With ``is_cascade``
            every index is returned together with the unsampled classes.
    """
    if gt_classes.shape[0] == 0:
        # No truth, assign everything to background
        gt_classes = paddle.ones(matches.shape, dtype='int32') * num_classes
    else:
        gt_classes = paddle.gather(gt_classes, matches)
        background = paddle.ones_like(gt_classes) * num_classes
        gt_classes = paddle.where(match_labels == 0, background, gt_classes)
        ignored = paddle.ones_like(gt_classes) * -1
        gt_classes = paddle.where(match_labels == -1, ignored, gt_classes)

    if is_cascade:
        return paddle.arange(matches.shape[0]), gt_classes

    fg_inds, bg_inds = subsample_labels(gt_classes,
                                        int(batch_size_per_im), fg_fraction,
                                        num_classes, use_random)
    nothing_sampled = fg_inds.shape[0] == 0 and bg_inds.shape[0] == 0
    if nothing_sampled:
        # fake output labeled with -1 when all boxes are neither
        # foreground nor background
        sampled_inds = paddle.zeros([1], dtype='int32')
    else:
        sampled_inds = paddle.concat([fg_inds, bg_inds])
    sampled_gt_classes = paddle.gather(gt_classes, sampled_inds)
    return sampled_inds, sampled_gt_classes
def polygons_to_mask(polygons, height, width):
    """
    Convert the polygons to mask format

    Args:
        polygons (list[ndarray]): each array has shape (Nx2,)
        height (int): mask height
        width (int): mask width
    Returns:
        ndarray: a bool mask of shape (height, width)
    """
    import pycocotools.mask as mask_util
    assert len(polygons) > 0, "COCOAPI does not support empty polygons"
    # Encode each polygon as RLE, merge them into one RLE, then decode it.
    merged_rle = mask_util.merge(
        mask_util.frPyObjects(polygons, height, width))
    decoded = mask_util.decode(merged_rle)
    return decoded.astype(np.bool_)
def rasterize_polygons_within_box(poly, box, resolution):
    """Rasterize polygons into a ``resolution x resolution`` mask of a box.

    The polygons are translated so the box's top-left corner becomes the
    origin, scaled so the box spans ``resolution`` pixels on both axes, and
    then rasterized.

    Args:
        poly (list): polygons, each a flat sequence of alternating x / y
            coordinates (even positions are shifted by ``box[0]``, odd ones
            by ``box[1]``).
        box: ``(x0, y0, x1, y1)`` of the region to rasterize.
        resolution (int): side length of the output mask.

    Returns:
        Tensor: int32 mask produced from ``polygons_to_mask``.
    """
    w, h = box[2] - box[0], box[3] - box[1]
    # np.array always copies. np.asarray would hand back the caller's own
    # array when it is already float64, and the in-place shifts / scales
    # below would then corrupt the ground-truth polygons (which may be
    # reused for several boxes).
    polygons = [np.array(p, dtype=np.float64) for p in poly]

    # 1. Shift the polygons w.r.t. the box
    for p in polygons:
        p[0::2] = p[0::2] - box[0]
        p[1::2] = p[1::2] - box[1]

    # 2. Rescale the polygons to the target resolution
    # (box sides are clamped to 0.1 to avoid dividing by zero)
    ratio_h = resolution / max(h, 0.1)
    ratio_w = resolution / max(w, 0.1)

    if ratio_h == ratio_w:
        for p in polygons:
            p *= ratio_h
    else:
        for p in polygons:
            p[0::2] *= ratio_w
            p[1::2] *= ratio_h

    # 3. Rasterize the polygons with coco api
    mask = polygons_to_mask(polygons, resolution, resolution)
    mask = paddle.to_tensor(mask, dtype='int32')
    return mask
def generate_mask_target(gt_segms, rois, labels_int32, sampled_gt_inds,
                         num_classes, resolution):
    """
    Build the mask-branch training targets from the foreground RoIs of
    every image in the batch.

    Args:
        gt_segms: per-image ground-truth polygons, indexed by gt index.
        rois: per-image RoIs.
        labels_int32: per-image RoI labels; ``-1`` and ``num_classes`` are
            treated as non-foreground.
        sampled_gt_inds: per-image index of the gt matched to each RoI.
        num_classes (int): label value that marks background.
        resolution (int): side length of each mask target.
    Returns:
        tuple: (mask_rois, mask_rois_num, tgt_classes, tgt_masks,
        mask_index, tgt_weights). ``mask_rois`` stays a per-image list, the
        other five are concatenated over images.
    """
    mask_rois, mask_rois_num = [], []
    tgt_masks, tgt_classes = [], []
    mask_index, tgt_weights = [], []

    for im_id in range(len(rois)):
        labels_per_im = labels_int32[im_id]

        # select rois labeled with foreground
        is_fg = paddle.logical_and(labels_per_im != -1,
                                   labels_per_im != num_classes)
        fg_inds = paddle.nonzero(is_fg)
        if fg_inds.numel() == 0:
            # generate fake roi if foreground is empty
            has_fg = False
            fg_inds = paddle.ones([1, 1], dtype='int64')
        else:
            has_fg = True

        matched_gt = paddle.gather(sampled_gt_inds[im_id], fg_inds)
        fg_rois = paddle.gather(rois[im_id], fg_inds)

        # Copy the foreground roi to cpu
        # to generate mask target with ground-truth
        boxes = fg_rois.numpy()
        gt_segms_per_im = gt_segms[im_id]
        matched_gt = matched_gt.numpy()
        num_fg = fg_inds.reshape([-1]).numpy().shape[0]

        if len(gt_segms_per_im) > 0:
            new_segm = [gt_segms_per_im[i] for i in matched_gt]
            results = [
                rasterize_polygons_within_box(new_segm[j], boxes[j],
                                              resolution)
                for j in range(num_fg)
            ]
        else:
            # No polygons for this image: use one all-ones placeholder mask.
            results = [
                paddle.ones(
                    [resolution, resolution], dtype='int32')
            ]

        fg_classes = paddle.gather(labels_per_im, fg_inds)
        weight = paddle.ones([fg_rois.shape[0]], dtype='float32')
        if not has_fg:
            # now all sampled classes are background
            # which will cause error in loss calculation,
            # make fake classes with weight of 0.
            fg_classes = paddle.zeros([1], dtype='int32')
            weight = weight - 1

        tgt_mask = paddle.stack(results)
        tgt_mask.stop_gradient = True
        fg_rois.stop_gradient = True

        mask_index.append(fg_inds)
        mask_rois.append(fg_rois)
        mask_rois_num.append(paddle.shape(fg_rois)[0:1])
        tgt_classes.append(fg_classes)
        tgt_masks.append(tgt_mask)
        tgt_weights.append(weight)

    mask_index = paddle.concat(mask_index)
    mask_rois_num = paddle.concat(mask_rois_num)
    tgt_classes = paddle.concat(tgt_classes, axis=0)
    tgt_masks = paddle.concat(tgt_masks, axis=0)
    tgt_weights = paddle.concat(tgt_weights, axis=0)

    return (mask_rois, mask_rois_num, tgt_classes, tgt_masks, mask_index,
            tgt_weights)
def libra_sample_pos(max_overlaps, max_classes, pos_inds, num_expected):
    """
    Pick ``num_expected`` positive indices, spreading the picks across the
    distinct values found in ``max_classes[pos_inds]``.

    Args:
        max_overlaps: unused here; kept for a signature parallel to
            ``libra_sample_neg``.
        max_classes (ndarray): per-box group id used to balance sampling.
        pos_inds (ndarray): candidate positive indices.
        num_expected (int): number of indices wanted.
    Returns:
        ``pos_inds`` itself (unchanged, not converted) when it already has
        at most ``num_expected`` entries; otherwise a paddle Tensor of the
        sampled indices.
    """
    if len(pos_inds) <= num_expected:
        return pos_inds

    pos_set = set(pos_inds)
    group_ids = np.unique(max_classes[pos_inds])
    quota = int(round(num_expected / float(len(group_ids))) + 1)

    sampled_inds = []
    for group_id in group_ids:
        members = np.nonzero(max_classes == group_id)[0]
        members = list(set(members) & pos_set)
        if len(members) > quota:
            members = np.random.choice(members, size=quota, replace=False)
        sampled_inds.extend(list(members))  # combine as a new sampler

    if len(sampled_inds) < num_expected:
        # Top up from the positives that no group selected.
        num_extra = num_expected - len(sampled_inds)
        extra_inds = np.array(list(pos_set - set(sampled_inds)))
        assert len(sampled_inds) + len(extra_inds) == len(pos_inds), \
            "sum of sampled_inds({}) and extra_inds({}) length must be equal with pos_inds({})!".format(
                len(sampled_inds), len(extra_inds), len(pos_inds))
        if len(extra_inds) > num_extra:
            extra_inds = np.random.choice(
                extra_inds, size=num_extra, replace=False)
        sampled_inds.extend(extra_inds.tolist())
    elif len(sampled_inds) > num_expected:
        # The per-group quota rounds up, so trim any overshoot at random.
        sampled_inds = np.random.choice(
            sampled_inds, size=num_expected, replace=False)

    return paddle.to_tensor(sampled_inds)
def libra_sample_via_interval(max_overlaps, full_set, num_expected, floor_thr,
                              num_bins, bg_thresh):
    """
    Sample indices evenly across IoU intervals.

    The range ``[floor_thr, max_overlaps.max())`` is split into ``num_bins``
    equal-width bins. Up to ``int(num_expected / num_bins)`` indices are
    taken from each bin; any shortfall is filled at random from the members
    of ``full_set`` that were not picked.

    Args:
        max_overlaps (ndarray): per-box IoU value.
        full_set (set[int]): indices that are allowed to be sampled.
        num_expected (int): number of indices wanted.
        floor_thr (float): lower bound of the first bin.
        num_bins (int): number of IoU bins.
        bg_thresh (float): unused in this function.
    Returns:
        ndarray: integer array of sampled indices (at most
        ``num_expected`` of them).
    """
    max_iou = max_overlaps.max()
    iou_interval = (max_iou - floor_thr) / num_bins
    per_num_expected = int(num_expected / num_bins)

    sampled_inds = []
    for i in range(num_bins):
        start_iou = floor_thr + i * iou_interval
        end_iou = floor_thr + (i + 1) * iou_interval

        in_bin = np.logical_and(max_overlaps >= start_iou,
                                max_overlaps < end_iou)
        tmp_inds = list(set(np.where(in_bin)[0]) & full_set)

        if len(tmp_inds) > per_num_expected:
            tmp_sampled_set = np.random.choice(
                tmp_inds, size=per_num_expected, replace=False)
        else:
            tmp_sampled_set = np.array(tmp_inds, dtype=np.int32)
        sampled_inds.append(tmp_sampled_set)

    sampled_inds = np.concatenate(sampled_inds)
    if len(sampled_inds) < num_expected:
        num_extra = num_expected - len(sampled_inds)
        # Pin an integer dtype: an empty difference would otherwise become
        # an empty float64 array and the concatenation below would silently
        # turn every returned index into a float.
        extra_inds = np.array(
            list(full_set - set(sampled_inds)), dtype=np.int64)
        assert len(sampled_inds) + len(extra_inds) == len(full_set), \
            "sum of sampled_inds({}) and extra_inds({}) length must be equal with full_set({})!".format(
                len(sampled_inds), len(extra_inds), len(full_set))

        if len(extra_inds) > num_extra:
            extra_inds = np.random.choice(extra_inds, num_extra, replace=False)
        sampled_inds = np.concatenate([sampled_inds, extra_inds])

    return sampled_inds
def libra_sample_neg(max_overlaps,
                     max_classes,
                     neg_inds,
                     num_expected,
                     floor_thr=-1,
                     floor_fraction=0,
                     num_bins=3,
                     bg_thresh=0.5):
    """
    Pick ``num_expected`` negative indices with IoU-balanced sampling.

    Negatives are split into a "floor" set (low IoU) and an "IoU sampling"
    set according to ``floor_thr``. The IoU sampling set supplies
    ``int(num_expected * (1 - floor_fraction))`` indices (binned when
    ``num_bins >= 2``), the floor set supplies the remainder, and any
    shortfall is filled at random from the unused negatives.

    Args:
        max_overlaps (ndarray): per-box IoU value.
        max_classes: unused in this function.
        neg_inds (ndarray): candidate negative indices.
        num_expected (int): number of indices wanted.
        floor_thr (float): IoU boundary between the two sets; a negative
            value disables the floor set.
        floor_fraction (float): share of samples reserved for the floor set.
        num_bins (int): number of IoU bins for interval sampling.
        bg_thresh (float): forwarded to ``libra_sample_via_interval``.
    Returns:
        ``neg_inds`` itself (unchanged, not converted) when it already has
        at most ``num_expected`` entries; otherwise a paddle Tensor of the
        sampled indices.
    """
    if len(neg_inds) <= num_expected:
        return neg_inds

    # balance sampling for negative samples
    neg_set = set(neg_inds.tolist())

    if floor_thr > 0:
        below_floor = np.logical_and(max_overlaps >= 0,
                                     max_overlaps < floor_thr)
        floor_set = set(np.where(below_floor)[0])
        iou_sampling_set = set(np.where(max_overlaps >= floor_thr)[0])
    elif floor_thr == 0:
        floor_set = set(np.where(max_overlaps == 0)[0])
        iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
    else:
        floor_set = set()
        iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
        # A negative floor is only a switch; bins start from zero IoU.
        floor_thr = 0

    floor_neg_inds = list(floor_set & neg_set)
    iou_sampling_neg_inds = list(iou_sampling_set & neg_set)

    # Step 1: draw from the IoU sampling set.
    num_expected_iou_sampling = int(num_expected * (1 - floor_fraction))
    if len(iou_sampling_neg_inds) <= num_expected_iou_sampling:
        iou_sampled_inds = np.array(iou_sampling_neg_inds, dtype=np.int32)
    elif num_bins >= 2:
        iou_sampled_inds = libra_sample_via_interval(
            max_overlaps,
            set(iou_sampling_neg_inds), num_expected_iou_sampling,
            floor_thr, num_bins, bg_thresh)
    else:
        iou_sampled_inds = np.random.choice(
            iou_sampling_neg_inds,
            size=num_expected_iou_sampling,
            replace=False)

    # Step 2: draw the remainder from the floor set.
    num_expected_floor = num_expected - len(iou_sampled_inds)
    if len(floor_neg_inds) > num_expected_floor:
        sampled_floor_inds = np.random.choice(
            floor_neg_inds, size=num_expected_floor, replace=False)
    else:
        sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int32)

    sampled_inds = np.concatenate((sampled_floor_inds, iou_sampled_inds))

    # Step 3: top up from whatever negatives are still unused.
    if len(sampled_inds) < num_expected:
        num_extra = num_expected - len(sampled_inds)
        extra_inds = np.array(list(neg_set - set(sampled_inds)))
        if len(extra_inds) > num_extra:
            extra_inds = np.random.choice(
                extra_inds, size=num_extra, replace=False)
        sampled_inds = np.concatenate((sampled_inds, extra_inds))

    return paddle.to_tensor(sampled_inds)
def libra_label_box(anchors, gt_boxes, gt_classes, positive_overlap,
                    negative_overlap, num_classes):
    """
    Match every box to its best-overlapping gt and label it.

    Args:
        anchors (Tensor): candidate boxes; ``anchors.shape[0]`` is the
            number of boxes.
        gt_boxes (Tensor): ground-truth boxes of one image.
        gt_classes (Tensor): class id of every gt box (converted with
            ``.numpy()`` and used as a column index).
        positive_overlap (float): boxes whose best overlap is >= this value
            are labeled 1.
        negative_overlap (float): boxes whose best overlap is < this value
            are labeled 0.
        num_classes (int): number of columns of the per-class overlap table.
    Returns:
        tuple: (matches, match_labels, matched_vals) as paddle Tensors.
        ``matches`` is the matched gt index (0 for boxes without overlap),
        ``match_labels`` is 1 / 0 / -1 (-1 for boxes between the two
        thresholds), ``matched_vals`` is the best overlap of each box.
    """
    gt_classes = gt_classes.numpy()
    num_boxes = anchors.shape[0]
    gt_overlaps = np.zeros((num_boxes, num_classes))
    matches = np.zeros((num_boxes), dtype=np.int32)

    if len(gt_boxes) > 0:
        proposal_to_gt_overlaps = bbox_overlaps(anchors, gt_boxes).numpy()
        overlaps_argmax = proposal_to_gt_overlaps.argmax(axis=1)
        overlaps_max = proposal_to_gt_overlaps.max(axis=1)

        # Boxes which with non-zero overlap with gt boxes
        overlapped = np.where(overlaps_max > 0)[0]
        # reshape(-1) keeps the (row, class) pairing one-to-one even if the
        # class array carries a trailing unit axis.
        overlapped_classes = gt_classes[overlaps_argmax[overlapped]].reshape(
            -1)

        # Vectorized form of the former per-box loop; every row index in
        # `overlapped` is unique, so the scatter has no write conflicts.
        gt_overlaps[overlapped, overlapped_classes] = overlaps_max[overlapped]
        matches[overlapped] = overlaps_argmax[overlapped]

    gt_overlaps = paddle.to_tensor(gt_overlaps)
    matches = paddle.to_tensor(matches)

    matched_vals = paddle.max(gt_overlaps, axis=1)
    # Start as "ignore" (-1), then overwrite negatives and positives.
    match_labels = paddle.full(matches.shape, -1, dtype='int32')
    match_labels = paddle.where(matched_vals < negative_overlap,
                                paddle.zeros_like(match_labels), match_labels)
    match_labels = paddle.where(matched_vals >= positive_overlap,
                                paddle.ones_like(match_labels), match_labels)

    return matches, match_labels, matched_vals
def libra_sample_bbox(matches,
                      match_labels,
                      matched_vals,
                      gt_classes,
                      batch_size_per_im,
                      num_classes,
                      fg_fraction,
                      fg_thresh,
                      bg_thresh,
                      num_bins,
                      use_random=True,
                      is_cascade_rcnn=False):
    """
    Sample foreground and background boxes of one image and produce their
    class targets.

    Args:
        matches (Tensor): matched gt index of every box.
        match_labels (Tensor): 1 / 0 / -1 label of every box.
        matched_vals (Tensor): best overlap of every box.
        gt_classes (Tensor): class id of every gt box.
        batch_size_per_im (int): total number of boxes to sample.
        num_classes (int): class id assigned to background boxes.
        fg_fraction (float): target share of foreground boxes.
        fg_thresh (float): overlap at or above which a box is foreground.
        bg_thresh (float): overlap below which a box is background.
        num_bins (int): IoU bins used by ``libra_sample_neg``.
        use_random (bool): use the libra samplers when there are more
            candidates than needed; otherwise just truncate.
        is_cascade_rcnn (bool): keep every fg / bg candidate unsampled.
    Returns:
        tuple: (sampled_inds, sampled_gt_classes).
    """
    rois_per_image = int(batch_size_per_im)
    fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))
    bg_rois_per_im = rois_per_image - fg_rois_per_im

    if is_cascade_rcnn:
        fg_inds = paddle.nonzero(matched_vals >= fg_thresh)
        bg_inds = paddle.nonzero(matched_vals < bg_thresh)
    else:
        matched_vals_np = matched_vals.numpy()
        match_labels_np = match_labels.numpy()

        # sample fg
        # NOTE(review): `match_labels_np` is what libra_sample_pos groups
        # by; when it holds only the 1 / 0 / -1 labels every positive falls
        # in one group. Confirm whether the matched gt index was intended.
        fg_inds = paddle.nonzero(matched_vals >= fg_thresh).flatten()
        fg_nums = int(np.minimum(fg_rois_per_im, fg_inds.shape[0]))
        if use_random and (fg_inds.shape[0] > fg_nums):
            fg_inds = libra_sample_pos(matched_vals_np, match_labels_np,
                                       fg_inds.numpy(), fg_rois_per_im)
        fg_inds = fg_inds[:fg_nums]

        # sample bg
        # NOTE(review): the sampler is asked for `bg_rois_per_im` boxes
        # while the slice below allows up to `rois_per_image - fg_nums`;
        # with few foregrounds the batch may end up short. Confirm intent.
        bg_inds = paddle.nonzero(matched_vals < bg_thresh).flatten()
        bg_nums = int(np.minimum(rois_per_image - fg_nums, bg_inds.shape[0]))
        if use_random and (bg_inds.shape[0] > bg_nums):
            bg_inds = libra_sample_neg(
                matched_vals_np,
                match_labels_np,
                bg_inds.numpy(),
                bg_rois_per_im,
                num_bins=num_bins,
                bg_thresh=bg_thresh)
        bg_inds = bg_inds[:bg_nums]

    sampled_inds = paddle.concat([fg_inds, bg_inds])

    # Per-box class: matched gt class, overridden with `num_classes` for
    # background (label 0) and -1 for ignored boxes (label -1).
    box_classes = paddle.gather(gt_classes, matches)
    box_classes = paddle.where(match_labels == 0,
                               paddle.ones_like(box_classes) * num_classes,
                               box_classes)
    box_classes = paddle.where(match_labels == -1,
                               paddle.ones_like(box_classes) * -1,
                               box_classes)
    sampled_gt_classes = paddle.gather(box_classes, sampled_inds)

    return sampled_inds, sampled_gt_classes
def libra_generate_proposal_target(rpn_rois,
                                   gt_classes,
                                   gt_boxes,
                                   batch_size_per_im,
                                   fg_fraction,
                                   fg_thresh,
                                   bg_thresh,
                                   num_classes,
                                   use_random=True,
                                   is_cascade_rcnn=False,
                                   max_overlaps=None,
                                   num_bins=3):
    """
    Generate Libra-RCNN RoI targets for every image in the batch.

    For each image the gt boxes are appended to the proposals, every box
    is labeled (``libra_label_box``), a batch is sampled
    (``libra_sample_bbox``) and the sampled RoIs with their targets are
    collected.

    Args:
        rpn_rois: per-image proposals.
        gt_classes: per-image gt classes; the last axis is squeezed.
        gt_boxes: per-image gt boxes.
        batch_size_per_im, fg_fraction, fg_thresh, bg_thresh, num_classes,
            use_random, num_bins: forwarded to the labeling / sampling
            helpers.
        is_cascade_rcnn (bool): filter proposals with ``filter_roi`` first.
        max_overlaps: per-image overlaps, indexed only when
            ``is_cascade_rcnn`` is true.
    Returns:
        tuple: (rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds,
        new_rois_num); the first four are per-image lists, the last is a
        concatenated Tensor of per-image RoI counts.
    """
    rois_with_gt, tgt_labels, tgt_bboxes = [], [], []
    # Collected per image but not part of the return value.
    sampled_max_overlaps = []
    tgt_gt_inds, new_rois_num = [], []

    for im_id, rpn_roi in enumerate(rpn_rois):
        max_overlap = max_overlaps[im_id] if is_cascade_rcnn else None
        gt_bbox = gt_boxes[im_id]
        gt_class = paddle.squeeze(gt_classes[im_id], axis=-1)
        if is_cascade_rcnn:
            rpn_roi = filter_roi(rpn_roi, max_overlap)
        # The gt boxes are added to the candidate set.
        bbox = paddle.concat([rpn_roi, gt_bbox])

        # Step1: label bbox
        matches, match_labels, matched_vals = libra_label_box(
            bbox, gt_bbox, gt_class, fg_thresh, bg_thresh, num_classes)

        # Step2: sample bbox
        sampled_inds, sampled_gt_classes = libra_sample_bbox(
            matches, match_labels, matched_vals, gt_class, batch_size_per_im,
            num_classes, fg_fraction, fg_thresh, bg_thresh, num_bins,
            use_random, is_cascade_rcnn)

        # Step3: make output
        rois_per_image = paddle.gather(bbox, sampled_inds)
        sampled_gt_ind = paddle.gather(matches, sampled_inds)
        sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
        sampled_overlap = paddle.gather(matched_vals, sampled_inds)

        # Targets are constants for the loss; block gradient flow.
        for target in (rois_per_image, sampled_gt_ind, sampled_bbox,
                       sampled_overlap):
            target.stop_gradient = True

        tgt_labels.append(sampled_gt_classes)
        tgt_bboxes.append(sampled_bbox)
        rois_with_gt.append(rois_per_image)
        sampled_max_overlaps.append(sampled_overlap)
        tgt_gt_inds.append(sampled_gt_ind)
        new_rois_num.append(paddle.shape(sampled_inds)[0:1])

    new_rois_num = paddle.concat(new_rois_num)
    return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num

View File

@@ -0,0 +1,481 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import paddle
from ppdet.core.workspace import register, serializable
from .target import rpn_anchor_target, generate_proposal_target, generate_mask_target, libra_generate_proposal_target
import numpy as np
@register
@serializable
class RPNTargetAssign(object):
    """
    RPN targets assignment module

    The assignment consists of three steps:
        1. Match anchor and ground-truth box, label the anchor with foreground
           or background sample
        2. Sample anchors to keep the properly ratio between foreground and
           background
        3. Generate the targets for classification and regression branch

    Args:
        batch_size_per_im (int): Total number of RPN samples per image.
            default 256
        fg_fraction (float): Fraction of anchors that is labeled
            foreground, default 0.5
        positive_overlap (float): Minimum overlap required between an anchor
            and ground-truth box for the (anchor, gt box) pair to be
            a foreground sample. default 0.7
        negative_overlap (float): Maximum overlap allowed between an anchor
            and ground-truth box for the (anchor, gt box) pair to be
            a background sample. default 0.3
        ignore_thresh (float): Threshold for ignoring the is_crowd ground-truth
            if the value is larger than zero.
        use_random (bool): Use random sampling to choose foreground and
            background boxes, default true.
        assign_on_cpu (bool): In case the number of gt box is too large,
            compute IoU on CPU, default false.
    """
    # The docstring must be the first statement of the class body; placed
    # after this assignment it would be a discarded expression and
    # `RPNTargetAssign.__doc__` would be None.
    __shared__ = ['assign_on_cpu']

    def __init__(self,
                 batch_size_per_im=256,
                 fg_fraction=0.5,
                 positive_overlap=0.7,
                 negative_overlap=0.3,
                 ignore_thresh=-1.,
                 use_random=True,
                 assign_on_cpu=False):
        super(RPNTargetAssign, self).__init__()
        self.batch_size_per_im = batch_size_per_im
        self.fg_fraction = fg_fraction
        self.positive_overlap = positive_overlap
        self.negative_overlap = negative_overlap
        self.ignore_thresh = ignore_thresh
        self.use_random = use_random
        self.assign_on_cpu = assign_on_cpu

    def __call__(self, inputs, anchors):
        """
        Args:
            inputs (dict): ground-truth instances; reads ``'gt_bbox'`` and
                the optional ``'is_crowd'``.
            anchors (Tensor): [num_anchors, 4], num_anchors are all anchors
                in all feature maps.
        Returns:
            tuple: (tgt_labels, tgt_bboxes, tgt_deltas, norm), where
            ``norm`` is ``batch_size_per_im * batch_size``.
        """
        gt_boxes = inputs['gt_bbox']
        is_crowd = inputs.get('is_crowd', None)
        batch_size = len(gt_boxes)
        tgt_labels, tgt_bboxes, tgt_deltas = rpn_anchor_target(
            anchors,
            gt_boxes,
            self.batch_size_per_im,
            self.positive_overlap,
            self.negative_overlap,
            self.fg_fraction,
            self.use_random,
            batch_size,
            self.ignore_thresh,
            is_crowd,
            assign_on_cpu=self.assign_on_cpu)
        norm = self.batch_size_per_im * batch_size

        return tgt_labels, tgt_bboxes, tgt_deltas, norm
@register
class BBoxAssigner(object):
    """
    RCNN targets assignment module

    The assignment consists of three steps:
        1. Match RoIs and ground-truth box, label the RoIs with foreground
           or background sample
        2. Sample anchors to keep the properly ratio between foreground and
           background
        3. Generate the targets for classification and regression branch

    Args:
        batch_size_per_im (int): Total number of RoIs per image.
            default 512
        fg_fraction (float): Fraction of RoIs that is labeled
            foreground, default 0.25
        fg_thresh (float): Minimum overlap required between a RoI
            and ground-truth box for the (roi, gt box) pair to be
            a foreground sample. default 0.5
        bg_thresh (float): Maximum overlap allowed between a RoI
            and ground-truth box for the (roi, gt box) pair to be
            a background sample. default 0.5
        ignore_thresh (float): Threshold for ignoring the is_crowd ground-truth
            if the value is larger than zero.
        use_random (bool): Use random sampling to choose foreground and
            background boxes, default true
        cascade_iou (list[iou]): The list of overlap to select foreground and
            background of each stage, which is only used In Cascade RCNN.
        num_classes (int): The number of class.
        assign_on_cpu (bool): In case the number of gt box is too large,
            compute IoU on CPU, default false.
    """
    # The docstring must be the first statement of the class body; placed
    # after this assignment it would be a discarded expression and
    # `BBoxAssigner.__doc__` would be None.
    __shared__ = ['num_classes', 'assign_on_cpu']

    def __init__(self,
                 batch_size_per_im=512,
                 fg_fraction=.25,
                 fg_thresh=.5,
                 bg_thresh=.5,
                 ignore_thresh=-1.,
                 use_random=True,
                 cascade_iou=[0.5, 0.6, 0.7],
                 num_classes=80,
                 assign_on_cpu=False):
        super(BBoxAssigner, self).__init__()
        self.batch_size_per_im = batch_size_per_im
        self.fg_fraction = fg_fraction
        self.fg_thresh = fg_thresh
        self.bg_thresh = bg_thresh
        self.ignore_thresh = ignore_thresh
        self.use_random = use_random
        self.cascade_iou = cascade_iou
        self.num_classes = num_classes
        self.assign_on_cpu = assign_on_cpu

    def __call__(self,
                 rpn_rois,
                 rpn_rois_num,
                 inputs,
                 stage=0,
                 is_cascade=False,
                 add_gt_as_proposals=True):
        """
        Args:
            rpn_rois: proposals to assign.
            rpn_rois_num: not used by this method.
            inputs (dict): reads ``'gt_class'``, ``'gt_bbox'`` and the
                optional ``'is_crowd'``.
            stage (int): index into ``cascade_iou``.
            is_cascade (bool): forwarded to ``generate_proposal_target``.
            add_gt_as_proposals (bool): forwarded to
                ``generate_proposal_target``.
        Returns:
            tuple: (rois, rois_num, targets), where ``targets`` is
            (tgt_labels, tgt_bboxes, tgt_gt_inds).
        """
        gt_classes = inputs['gt_class']
        gt_boxes = inputs['gt_bbox']
        is_crowd = inputs.get('is_crowd', None)
        # rois, tgt_labels, tgt_bboxes, tgt_gt_inds
        # new_rois_num
        outs = generate_proposal_target(
            rpn_rois, gt_classes, gt_boxes, self.batch_size_per_im,
            self.fg_fraction, self.fg_thresh, self.bg_thresh, self.num_classes,
            self.ignore_thresh, is_crowd, self.use_random, is_cascade,
            self.cascade_iou[stage], self.assign_on_cpu, add_gt_as_proposals)
        rois = outs[0]
        rois_num = outs[-1]
        # tgt_labels, tgt_bboxes, tgt_gt_inds
        targets = outs[1:4]
        return rois, rois_num, targets
@register
class BBoxLibraAssigner(object):
    """
    Libra-RCNN targets assignment module

    The assignment consists of three steps:
        1. Match RoIs and ground-truth box, label the RoIs with foreground
           or background sample
        2. Sample anchors to keep the properly ratio between foreground and
           background
        3. Generate the targets for classification and regression branch

    Args:
        batch_size_per_im (int): Total number of RoIs per image.
            default 512
        fg_fraction (float): Fraction of RoIs that is labeled
            foreground, default 0.25
        fg_thresh (float): Minimum overlap required between a RoI
            and ground-truth box for the (roi, gt box) pair to be
            a foreground sample. default 0.5
        bg_thresh (float): Maximum overlap allowed between a RoI
            and ground-truth box for the (roi, gt box) pair to be
            a background sample. default 0.5
        use_random (bool): Use random sampling to choose foreground and
            background boxes, default true
        cascade_iou (list[iou]): The list of overlap to select foreground and
            background of each stage, which is only used In Cascade RCNN.
        num_classes (int): The number of class.
        num_bins (int): The number of libra_sample.
    """
    # The docstring must be the first statement of the class body; placed
    # after this assignment it would be a discarded expression and
    # `BBoxLibraAssigner.__doc__` would be None.
    __shared__ = ['num_classes']

    def __init__(self,
                 batch_size_per_im=512,
                 fg_fraction=.25,
                 fg_thresh=.5,
                 bg_thresh=.5,
                 use_random=True,
                 cascade_iou=[0.5, 0.6, 0.7],
                 num_classes=80,
                 num_bins=3):
        super(BBoxLibraAssigner, self).__init__()
        self.batch_size_per_im = batch_size_per_im
        self.fg_fraction = fg_fraction
        self.fg_thresh = fg_thresh
        self.bg_thresh = bg_thresh
        self.use_random = use_random
        self.cascade_iou = cascade_iou
        self.num_classes = num_classes
        self.num_bins = num_bins

    def __call__(self,
                 rpn_rois,
                 rpn_rois_num,
                 inputs,
                 stage=0,
                 is_cascade=False):
        """
        Args:
            rpn_rois: proposals to assign.
            rpn_rois_num: not used by this method.
            inputs (dict): reads ``'gt_class'`` and ``'gt_bbox'``.
            stage (int): index into ``cascade_iou``.
            is_cascade (bool): forwarded as ``is_cascade_rcnn``.
        Returns:
            tuple: (rois, rois_num, targets), where ``targets`` is
            (tgt_labels, tgt_bboxes, tgt_gt_inds).
        """
        gt_classes = inputs['gt_class']
        gt_boxes = inputs['gt_bbox']
        # rois, tgt_labels, tgt_bboxes, tgt_gt_inds
        # NOTE(review): the 11th positional slot of
        # libra_generate_proposal_target is `max_overlaps`, so the scalar
        # `cascade_iou[stage]` is what gets bound to it (made explicit with
        # keywords here, behavior unchanged). That function indexes
        # `max_overlaps[i]` when is_cascade is true, which a float cannot
        # satisfy — confirm what should be passed for the cascade case.
        outs = libra_generate_proposal_target(
            rpn_rois,
            gt_classes,
            gt_boxes,
            self.batch_size_per_im,
            self.fg_fraction,
            self.fg_thresh,
            self.bg_thresh,
            self.num_classes,
            self.use_random,
            is_cascade,
            max_overlaps=self.cascade_iou[stage],
            num_bins=self.num_bins)
        rois = outs[0]
        rois_num = outs[-1]
        # tgt_labels, tgt_bboxes, tgt_gt_inds
        targets = outs[1:4]
        return rois, rois_num, targets
@register
@serializable
class MaskAssigner(object):
    """
    Mask targets assignment module

    The assignment consists of two steps:
        1. Select RoIs labels with foreground.
        2. Encode the RoIs and corresponding gt polygons to generate
           mask target

    Args:
        num_classes (int): The number of class
        mask_resolution (int): The resolution of mask target, default 14
    """
    # The docstring must be the first statement of the class body; placed
    # after this assignment it would be a discarded expression and
    # `MaskAssigner.__doc__` would be None.
    __shared__ = ['num_classes', 'mask_resolution']

    def __init__(self, num_classes=80, mask_resolution=14):
        super(MaskAssigner, self).__init__()
        self.num_classes = num_classes
        self.mask_resolution = mask_resolution

    def __call__(self, rois, tgt_labels, tgt_gt_inds, inputs):
        """
        Args:
            rois: per-image RoIs.
            tgt_labels: per-image RoI labels.
            tgt_gt_inds: per-image matched gt index of every RoI.
            inputs (dict): reads ``'gt_poly'``.
        Returns:
            tuple: (mask_rois, mask_rois_num, tgt_classes, tgt_masks,
            mask_index, tgt_weights) as produced by
            ``generate_mask_target``.
        """
        gt_segms = inputs['gt_poly']

        outs = generate_mask_target(gt_segms, rois, tgt_labels, tgt_gt_inds,
                                    self.num_classes, self.mask_resolution)

        # mask_rois, mask_rois_num, tgt_classes, tgt_masks, mask_index, tgt_weights
        return outs
@register
class RBoxAssigner(object):
    """
    assigner of rbox

    Args:
        pos_iou_thr (float): threshold of pos samples
        neg_iou_thr (float): threshold of neg samples
        min_iou_thr (float): the min threshold of samples
        ignore_iof_thr (int): the ignored threshold
    """

    def __init__(self,
                 pos_iou_thr=0.5,
                 neg_iou_thr=0.4,
                 min_iou_thr=0.0,
                 ignore_iof_thr=-2):
        super(RBoxAssigner, self).__init__()

        self.pos_iou_thr = pos_iou_thr
        self.neg_iou_thr = neg_iou_thr
        self.min_iou_thr = min_iou_thr
        self.ignore_iof_thr = ignore_iof_thr

    def anchor_valid(self, anchors):
        """
        Return the index of every anchor.

        Args:
            anchors (ndarray): 2-D anchors, or 3-D anchors that are
                flattened over the leading two axes first.
        Returns:
            ndarray: ``np.arange(num_anchors)``.
        """
        if anchors.ndim == 3:
            anchors = anchors.reshape(-1, anchors.shape[-1])
        assert anchors.ndim == 2
        return np.arange(anchors.shape[0])

    def rbox2delta(self,
                   proposals,
                   gt,
                   means=[0, 0, 0, 0, 0],
                   stds=[1, 1, 1, 1, 1]):
        """
        Encode gt rotated boxes as regression deltas relative to proposals.

        Args:
            proposals: tensor [N, 5]
            gt: gt [N, 5]
            means: means [5]
            stds: stds [5]
        Returns:
            ndarray: float32 deltas (dx, dy, dw, dh, da) stacked on the
            last axis, normalized with ``means`` and ``stds``.
        """
        proposals = proposals.astype(np.float64)

        PI = np.pi

        gt_widths = gt[..., 2]
        gt_heights = gt[..., 3]
        gt_angle = gt[..., 4]

        proposals_widths = proposals[..., 2]
        proposals_heights = proposals[..., 3]
        proposals_angle = proposals[..., 4]

        # Center offset, rotated by the proposal angle and scaled by the
        # proposal size.
        coord = gt[..., 0:2] - proposals[..., 0:2]
        cos_a = np.cos(proposals_angle)
        sin_a = np.sin(proposals_angle)
        dx = (cos_a * coord[..., 0] + sin_a * coord[..., 1]) / proposals_widths
        dy = (-sin_a * coord[..., 0] + cos_a * coord[..., 1]
              ) / proposals_heights
        dw = np.log(gt_widths / proposals_widths)
        dh = np.log(gt_heights / proposals_heights)
        # Wrap the angle difference into [-PI/4, 3*PI/4), then express it
        # in units of PI.
        da = (gt_angle - proposals_angle)
        da = (da + PI / 4) % PI - PI / 4
        da /= PI

        deltas = np.stack([dx, dy, dw, dh, da], axis=-1)
        means = np.array(means, dtype=deltas.dtype)
        stds = np.array(stds, dtype=deltas.dtype)
        deltas = (deltas - means) / stds
        deltas = deltas.astype(np.float32)
        return deltas

    def assign_anchor(self,
                      anchors,
                      gt_bboxes,
                      gt_labels,
                      pos_iou_thr,
                      neg_iou_thr,
                      min_iou_thr=0.0,
                      ignore_iof_thr=-2):
        """
        Match anchors with gt boxes by rotated IoU and label every anchor.

        Args:
            anchors (ndarray): [num_anchors, 4 or 5].
            gt_bboxes (ndarray): [num_gt, 4 or 5].
            gt_labels (ndarray): class label of every gt box.
            pos_iou_thr (float): IoU at or above which an anchor is positive.
            neg_iou_thr (float): IoU below which an anchor is negative.
            min_iou_thr (float): lower IoU bound of the negative range.
            ignore_iof_thr (int): label value given to ignored anchors.
        Returns:
            tuple: (anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels). Labels
            are ``ignore_iof_thr`` for ignored anchors, -1 for negatives
            and the gt class label for positives.
        """
        assert anchors.shape[1] == 4 or anchors.shape[1] == 5
        assert gt_bboxes.shape[1] == 4 or gt_bboxes.shape[1] == 5
        anchors_xc_yc = anchors
        gt_bboxes_xc_yc = gt_bboxes

        # calc rbox iou
        anchors_xc_yc = anchors_xc_yc.astype(np.float32)
        gt_bboxes_xc_yc = gt_bboxes_xc_yc.astype(np.float32)
        anchors_xc_yc = paddle.to_tensor(anchors_xc_yc)
        gt_bboxes_xc_yc = paddle.to_tensor(gt_bboxes_xc_yc)

        try:
            from ext_op import rbox_iou
        except Exception as e:
            print("import custom_ops error, try install ext_op " \
                  "following ppdet/ext_op/README.md", e)
            sys.stdout.flush()
            sys.exit(-1)

        iou = rbox_iou(gt_bboxes_xc_yc, anchors_xc_yc)
        iou = iou.numpy()
        # Transpose so that rows are anchors and columns are gt boxes.
        iou = iou.T

        # every gt's anchor's index
        gt_bbox_anchor_inds = iou.argmax(axis=0)
        gt_bbox_anchor_iou = iou[gt_bbox_anchor_inds, np.arange(iou.shape[1])]
        # Rows of every anchor that ties the best IoU of some gt.
        gt_bbox_anchor_iou_inds = np.where(iou == gt_bbox_anchor_iou)[0]

        # every anchor's gt bbox's index
        anchor_gt_bbox_inds = iou.argmax(axis=1)
        anchor_gt_bbox_iou = iou[np.arange(iou.shape[0]), anchor_gt_bbox_inds]

        # (1) set labels=-2 as default
        labels = np.ones((iou.shape[0], ), dtype=np.int32) * ignore_iof_thr

        # (2) assign ignore
        labels[anchor_gt_bbox_iou < min_iou_thr] = ignore_iof_thr

        # (3) assign neg_ids -1
        assign_neg_ids1 = anchor_gt_bbox_iou >= min_iou_thr
        assign_neg_ids2 = anchor_gt_bbox_iou < neg_iou_thr
        assign_neg_ids = np.logical_and(assign_neg_ids1, assign_neg_ids2)
        labels[assign_neg_ids] = -1

        # (4) assign max_iou as pos_ids >=0
        anchor_gt_bbox_iou_inds = anchor_gt_bbox_inds[gt_bbox_anchor_iou_inds]
        labels[gt_bbox_anchor_iou_inds] = gt_labels[anchor_gt_bbox_iou_inds]

        # (5) assign >= pos_iou_thr as pos_ids
        iou_pos_iou_thr_ids = anchor_gt_bbox_iou >= pos_iou_thr
        iou_pos_iou_thr_ids_box_inds = anchor_gt_bbox_inds[iou_pos_iou_thr_ids]
        labels[iou_pos_iou_thr_ids] = gt_labels[iou_pos_iou_thr_ids_box_inds]
        return anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels

    def __call__(self, anchors, gt_bboxes, gt_labels, is_crowd):
        """
        Assign rbox targets to every anchor.

        Args:
            anchors (ndarray): [num_anchors, 5].
            gt_bboxes (ndarray): [num_gt, 5].
            gt_labels (ndarray): class label of every gt box; flattened
                before use.
            is_crowd: accepted for interface compatibility; not used.
        Returns:
            tuple: (pos_labels, pos_labels_weights, bbox_targets,
            bbox_weights, bbox_gt_bboxes, pos_inds, neg_inds).
        """
        assert anchors.ndim == 2
        assert anchors.shape[1] == 5
        assert gt_bboxes.ndim == 2
        assert gt_bboxes.shape[1] == 5

        # Step1: match anchor and gt_bbox
        anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels = self.assign_anchor(
            anchors, gt_bboxes,
            gt_labels.reshape(-1), self.pos_iou_thr, self.neg_iou_thr,
            self.min_iou_thr, self.ignore_iof_thr)

        # Step2: sample anchor
        pos_inds = np.where(labels >= 0)[0]
        neg_inds = np.where(labels == -1)[0]

        # Step3: make output
        anchors_num = anchors.shape[0]
        bbox_targets = np.zeros_like(anchors)
        bbox_weights = np.zeros_like(anchors)
        bbox_gt_bboxes = np.zeros_like(anchors)
        pos_labels = np.zeros(anchors_num, dtype=np.int32)
        pos_labels_weights = np.zeros(anchors_num, dtype=np.float32)

        pos_sampled_anchors = anchors[pos_inds]
        pos_sampled_gt_boxes = gt_bboxes[anchor_gt_bbox_inds[pos_inds]]
        if len(pos_inds) > 0:
            pos_bbox_targets = self.rbox2delta(pos_sampled_anchors,
                                               pos_sampled_gt_boxes)
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_gt_bboxes[pos_inds, :] = pos_sampled_gt_boxes
            bbox_weights[pos_inds, :] = 1.0

            pos_labels[pos_inds] = labels[pos_inds]
            pos_labels_weights[pos_inds] = 1.0

        # Negatives contribute to the classification loss as well.
        if len(neg_inds) > 0:
            pos_labels_weights[neg_inds] = 1.0
        return (pos_labels, pos_labels_weights, bbox_targets, bbox_weights,
                bbox_gt_bboxes, pos_inds, neg_inds)