# fcb_photo_review/paddle_detection/ppdet/modeling/losses/fcos_loss.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ppdet.modeling import ops
from functools import partial
__all__ = ['FCOSLoss', 'FCOSLossMILC', 'FCOSLossCR']
def flatten_tensor(inputs, channel_first=False):
"""
Flatten a Tensor
Args:
inputs (Tensor): 4-D Tensor with shape [N, C, H, W] or [N, H, W, C]
channel_first (bool): if True, the dimension order of the input Tensor
is [N, C, H, W]; otherwise it is [N, H, W, C]
Return:
output_channel_last (Tensor): The flattened Tensor in channel_last style
"""
if channel_first:
input_channel_last = paddle.transpose(inputs, perm=[0, 2, 3, 1])
else:
input_channel_last = inputs
output_channel_last = paddle.flatten(
input_channel_last, start_axis=0, stop_axis=2)
return output_channel_last
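# --- Illustrative example (not part of the original file) ---
# A minimal shape check for flatten_tensor, assuming a dummy channel-first
# input. The helper is defined for documentation only and is never called by
# the library code.
def _example_flatten_tensor():
    x = paddle.randn([2, 80, 8, 8])  # [N, C, H, W], e.g. 80-class FCOS logits
    flat = flatten_tensor(x, channel_first=True)
    # [N, C, H, W] -> [N, H, W, C] -> [N * H * W, C]
    assert list(flat.shape) == [2 * 8 * 8, 80]
    return flat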
@register
class FCOSLoss(nn.Layer):
"""
FCOSLoss
Args:
loss_alpha (float): alpha in focal loss
loss_gamma (float): gamma in focal loss
iou_loss_type (str): location loss type, IoU/GIoU/LINEAR_IoU
reg_weights (float): weight for location loss
quality (str): quality branch, centerness/iou
"""
def __init__(self,
loss_alpha=0.25,
loss_gamma=2.0,
iou_loss_type="giou",
reg_weights=1.0,
quality='centerness'):
super(FCOSLoss, self).__init__()
self.loss_alpha = loss_alpha
self.loss_gamma = loss_gamma
self.iou_loss_type = iou_loss_type
self.reg_weights = reg_weights
self.quality = quality
def _iou_loss(self,
pred,
targets,
positive_mask,
weights=None,
return_iou=False):
"""
Calculate the loss for location prediction
Args:
pred (Tensor): bounding boxes prediction
targets (Tensor): targets for positive samples
positive_mask (Tensor): mask of positive samples
weights (Tensor): weights for each positive sample
return_iou (bool): whether to return the IoU values instead of the loss
Return:
loss (Tensor): location loss
"""
plw = pred[:, 0] * positive_mask
pth = pred[:, 1] * positive_mask
prw = pred[:, 2] * positive_mask
pbh = pred[:, 3] * positive_mask
tlw = targets[:, 0] * positive_mask
tth = targets[:, 1] * positive_mask
trw = targets[:, 2] * positive_mask
tbh = targets[:, 3] * positive_mask
tlw.stop_gradient = True
trw.stop_gradient = True
tth.stop_gradient = True
tbh.stop_gradient = True
ilw = paddle.minimum(plw, tlw)
irw = paddle.minimum(prw, trw)
ith = paddle.minimum(pth, tth)
ibh = paddle.minimum(pbh, tbh)
clw = paddle.maximum(plw, tlw)
crw = paddle.maximum(prw, trw)
cth = paddle.maximum(pth, tth)
cbh = paddle.maximum(pbh, tbh)
area_predict = (plw + prw) * (pth + pbh)
area_target = (tlw + trw) * (tth + tbh)
area_inter = (ilw + irw) * (ith + ibh)
ious = (area_inter + 1.0) / (
area_predict + area_target - area_inter + 1.0)
ious = ious * positive_mask
if return_iou:
return ious
if self.iou_loss_type.lower() == "linear_iou":
loss = 1.0 - ious
elif self.iou_loss_type.lower() == "giou":
area_uniou = area_predict + area_target - area_inter
area_circum = (clw + crw) * (cth + cbh) + 1e-7
giou = ious - (area_circum - area_uniou) / area_circum
loss = 1.0 - giou
elif self.iou_loss_type.lower() == "iou":
loss = 0.0 - paddle.log(ious)
else:
raise KeyError
if weights is not None:
loss = loss * weights
return loss
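# Example (illustration only, not part of the original file): the four columns
# of pred/targets are (left, top, right, bottom) distances from the anchor
# point. With a perfect prediction the GIoU variant above collapses to ~0:
#   pred = target = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0]])
#   FCOSLoss()._iou_loss(pred, target, paddle.ones([1]))   # ~= 0.0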
def forward(self, cls_logits, bboxes_reg, centerness, tag_labels,
tag_bboxes, tag_center):
"""
Calculate the loss for classification, location and centerness
Args:
cls_logits (list): list of Tensor, which is predicted
score for all anchor points with shape [N, M, C]
bboxes_reg (list): list of Tensor, which is predicted
offsets for all anchor points with shape [N, M, 4]
centerness (list): list of Tensor, which is predicted
centerness for all anchor points with shape [N, M, 1]
tag_labels (list): list of Tensor, which is category
targets for each anchor point
tag_bboxes (list): list of Tensor, which is bounding
boxes targets for positive samples
tag_center (list): list of Tensor, which is centerness
targets for positive samples
Return:
loss (dict): loss composed of the classification loss, the bounding box
regression loss and the quality (centerness or IoU) loss
"""
cls_logits_flatten_list = []
bboxes_reg_flatten_list = []
centerness_flatten_list = []
tag_labels_flatten_list = []
tag_bboxes_flatten_list = []
tag_center_flatten_list = []
num_lvl = len(cls_logits)
for lvl in range(num_lvl):
cls_logits_flatten_list.append(
flatten_tensor(cls_logits[lvl], True))
bboxes_reg_flatten_list.append(
flatten_tensor(bboxes_reg[lvl], True))
centerness_flatten_list.append(
flatten_tensor(centerness[lvl], True))
tag_labels_flatten_list.append(
flatten_tensor(tag_labels[lvl], False))
tag_bboxes_flatten_list.append(
flatten_tensor(tag_bboxes[lvl], False))
tag_center_flatten_list.append(
flatten_tensor(tag_center[lvl], False))
cls_logits_flatten = paddle.concat(cls_logits_flatten_list, axis=0)
bboxes_reg_flatten = paddle.concat(bboxes_reg_flatten_list, axis=0)
centerness_flatten = paddle.concat(centerness_flatten_list, axis=0)
tag_labels_flatten = paddle.concat(tag_labels_flatten_list, axis=0)
tag_bboxes_flatten = paddle.concat(tag_bboxes_flatten_list, axis=0)
tag_center_flatten = paddle.concat(tag_center_flatten_list, axis=0)
tag_labels_flatten.stop_gradient = True
tag_bboxes_flatten.stop_gradient = True
tag_center_flatten.stop_gradient = True
mask_positive_bool = tag_labels_flatten > 0
mask_positive_bool.stop_gradient = True
mask_positive_float = paddle.cast(mask_positive_bool, dtype="float32")
mask_positive_float.stop_gradient = True
num_positive_fp32 = paddle.sum(mask_positive_float)
num_positive_fp32.stop_gradient = True
num_positive_int32 = paddle.cast(num_positive_fp32, dtype="int32")
num_positive_int32 = num_positive_int32 * 0 + 1
num_positive_int32.stop_gradient = True
normalize_sum = paddle.sum(tag_center_flatten * mask_positive_float)
normalize_sum.stop_gradient = True
# 1. cls_logits: sigmoid_focal_loss
# expand onehot labels
num_classes = cls_logits_flatten.shape[-1]
tag_labels_flatten = paddle.squeeze(tag_labels_flatten, axis=-1)
tag_labels_flatten_bin = F.one_hot(
tag_labels_flatten, num_classes=1 + num_classes)
tag_labels_flatten_bin = tag_labels_flatten_bin[:, 1:]
# sigmoid_focal_loss
cls_loss = F.sigmoid_focal_loss(
cls_logits_flatten, tag_labels_flatten_bin) / num_positive_fp32
if self.quality == 'centerness':
# 2. bboxes_reg: giou_loss
mask_positive_float = paddle.squeeze(mask_positive_float, axis=-1)
tag_center_flatten = paddle.squeeze(tag_center_flatten, axis=-1)
reg_loss = self._iou_loss(
bboxes_reg_flatten,
tag_bboxes_flatten,
mask_positive_float,
weights=tag_center_flatten)
reg_loss = reg_loss * mask_positive_float / normalize_sum
# 3. centerness: sigmoid_cross_entropy_with_logits_loss
centerness_flatten = paddle.squeeze(centerness_flatten, axis=-1)
quality_loss = ops.sigmoid_cross_entropy_with_logits(
centerness_flatten, tag_center_flatten)
quality_loss = quality_loss * mask_positive_float / num_positive_fp32
elif self.quality == 'iou':
# 2. bboxes_reg: giou_loss
mask_positive_float = paddle.squeeze(mask_positive_float, axis=-1)
tag_center_flatten = paddle.squeeze(tag_center_flatten, axis=-1)
reg_loss = self._iou_loss(
bboxes_reg_flatten,
tag_bboxes_flatten,
mask_positive_float,
weights=None)
reg_loss = reg_loss * mask_positive_float / num_positive_fp32
# num_positive_fp32 is num_foreground
# 3. centerness: sigmoid_cross_entropy_with_logits_loss
centerness_flatten = paddle.squeeze(centerness_flatten, axis=-1)
gt_ious = self._iou_loss(
bboxes_reg_flatten,
tag_bboxes_flatten,
mask_positive_float,
weights=None,
return_iou=True)
quality_loss = ops.sigmoid_cross_entropy_with_logits(
centerness_flatten, gt_ious)
quality_loss = quality_loss * mask_positive_float / num_positive_fp32
else:
raise Exception(f'Unknown quality type: {self.quality}')
loss_all = {
"loss_cls": paddle.sum(cls_loss),
"loss_box": paddle.sum(reg_loss),
"loss_quality": paddle.sum(quality_loss),
}
return loss_all
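# --- Illustrative example (not part of the original file) ---
# A rough sketch of the per-FPN-level layout FCOSLoss.forward expects:
# predictions are channel-first [N, C, H, W], targets are channel-last
# [N, H, W, C]. Level sizes and the class count are arbitrary choices for the
# sketch; the helper is defined for documentation only and never called.
def _example_fcos_loss_forward(num_classes=80):
    loss_fn = FCOSLoss(quality='centerness')
    n, hws = 2, [(8, 8), (4, 4)]  # batch size and two tiny FPN levels
    cls_logits = [paddle.randn([n, num_classes, h, w]) for h, w in hws]
    bboxes_reg = [paddle.rand([n, 4, h, w]) + 0.1 for h, w in hws]
    centerness = [paddle.randn([n, 1, h, w]) for h, w in hws]
    # targets: label 0 is background, 1..num_classes are foreground classes
    tag_labels = [
        paddle.randint(0, num_classes + 1, [n, h, w, 1]) for h, w in hws
    ]
    tag_bboxes = [paddle.rand([n, h, w, 4]) + 0.1 for h, w in hws]
    tag_center = [paddle.rand([n, h, w, 1]) for h, w in hws]
    losses = loss_fn(cls_logits, bboxes_reg, centerness, tag_labels,
                     tag_bboxes, tag_center)
    return losses  # dict with 'loss_cls', 'loss_box' and 'loss_quality'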
@register
class FCOSLossMILC(FCOSLoss):
"""
FCOSLossMILC for ARSL in semi-supervised detection (SSOD)
Args:
loss_alpha (float): alpha in focal loss
loss_gamma (float): gamma in focal loss
iou_loss_type (str): location loss type, IoU/GIoU/LINEAR_IoU
reg_weights (float): weight for location loss
"""
def __init__(self,
loss_alpha=0.25,
loss_gamma=2.0,
iou_loss_type="giou",
reg_weights=1.0):
super(FCOSLossMILC, self).__init__()
self.loss_alpha = loss_alpha
self.loss_gamma = loss_gamma
self.iou_loss_type = iou_loss_type
self.reg_weights = reg_weights
def iou_loss(self, pred, targets, weights=None, avg_factor=None):
"""
Calculate the loss for location prediction
Args:
pred (Tensor): bounding boxes prediction
targets (Tensor): targets for positive samples
weights (Tensor): weights for each positive sample
avg_factor (float): factor used to average the summed loss
Return:
loss (Tensor): location loss
"""
plw = pred[:, 0]
pth = pred[:, 1]
prw = pred[:, 2]
pbh = pred[:, 3]
tlw = targets[:, 0]
tth = targets[:, 1]
trw = targets[:, 2]
tbh = targets[:, 3]
tlw.stop_gradient = True
trw.stop_gradient = True
tth.stop_gradient = True
tbh.stop_gradient = True
ilw = paddle.minimum(plw, tlw)
irw = paddle.minimum(prw, trw)
ith = paddle.minimum(pth, tth)
ibh = paddle.minimum(pbh, tbh)
clw = paddle.maximum(plw, tlw)
crw = paddle.maximum(prw, trw)
cth = paddle.maximum(pth, tth)
cbh = paddle.maximum(pbh, tbh)
area_predict = (plw + prw) * (pth + pbh)
area_target = (tlw + trw) * (tth + tbh)
area_inter = (ilw + irw) * (ith + ibh)
ious = (area_inter + 1.0) / (
area_predict + area_target - area_inter + 1.0)
ious = ious
if self.iou_loss_type.lower() == "linear_iou":
loss = 1.0 - ious
elif self.iou_loss_type.lower() == "giou":
area_uniou = area_predict + area_target - area_inter
area_circum = (clw + crw) * (cth + cbh) + 1e-7
giou = ious - (area_circum - area_uniou) / area_circum
loss = 1.0 - giou
elif self.iou_loss_type.lower() == "iou":
loss = 0.0 - paddle.log(ious)
else:
raise KeyError
if weights is not None:
loss = loss * weights
loss = paddle.sum(loss)
if avg_factor is not None:
loss = loss / avg_factor
return loss
# temp function: calculate iou between bbox and target
def _bbox_overlap_align(self, pred, targets):
assert pred.shape[0] == targets.shape[0], \
'the pred should be aligned with target.'
plw = pred[:, 0]
pth = pred[:, 1]
prw = pred[:, 2]
pbh = pred[:, 3]
tlw = targets[:, 0]
tth = targets[:, 1]
trw = targets[:, 2]
tbh = targets[:, 3]
ilw = paddle.minimum(plw, tlw)
irw = paddle.minimum(prw, trw)
ith = paddle.minimum(pth, tth)
ibh = paddle.minimum(pbh, tbh)
area_predict = (plw + prw) * (pth + pbh)
area_target = (tlw + trw) * (tth + tbh)
area_inter = (ilw + irw) * (ith + ibh)
ious = (area_inter + 1.0) / (
area_predict + area_target - area_inter + 1.0)
return ious
def iou_based_soft_label_loss(self,
pred,
target,
alpha=0.75,
gamma=2.0,
iou_weighted=False,
implicit_iou=None,
avg_factor=None):
assert pred.shape == target.shape
pred = F.sigmoid(pred)
target = target.cast(pred.dtype)
if implicit_iou is not None:
pred = pred * implicit_iou
if iou_weighted:
focal_weight = (pred - target).abs().pow(gamma) * target * (target > 0.0).cast('float32') + \
alpha * (pred - target).abs().pow(gamma) * \
(target <= 0.0).cast('float32')
else:
focal_weight = (pred - target).abs().pow(gamma) * (target > 0.0).cast('float32') + \
alpha * (pred - target).abs().pow(gamma) * \
(target <= 0.0).cast('float32')
# focal loss
loss = F.binary_cross_entropy(
pred, target, reduction='none') * focal_weight
if avg_factor is not None:
loss = loss / avg_factor
return loss
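# Example (illustration only, not part of the original file): the targets are
# soft labels, i.e. the predicted-vs-target box IoU at the GT class and 0
# elsewhere, so for instance
#   FCOSLossMILC().iou_based_soft_label_loss(
#       paddle.to_tensor([[2.0, -2.0]]),   # logits, shape [points, classes]
#       paddle.to_tensor([[0.9, 0.0]]),    # soft targets in [0, 1]
#       avg_factor=1.0)
# returns the element-wise loss, which the caller sums afterwards.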
def forward(self, cls_logits, bboxes_reg, centerness, tag_labels,
tag_bboxes, tag_center):
"""
Calculate the loss for classification, location and centerness
Args:
cls_logits (list): list of Tensor, which is predicted
score for all anchor points with shape [N, M, C]
bboxes_reg (list): list of Tensor, which is predicted
offsets for all anchor points with shape [N, M, 4]
centerness (list): list of Tensor, which is predicted
centerness for all anchor points with shape [N, M, 1]
tag_labels (list): list of Tensor, which is category
targets for each anchor point
tag_bboxes (list): list of Tensor, which is bounding
boxes targets for positive samples
tag_center (list): list of Tensor, which is centerness
targets for positive samples
Return:
loss (dict): loss composed of the classification loss, the bounding box
regression loss and the IoU loss
"""
cls_logits_flatten_list = []
bboxes_reg_flatten_list = []
centerness_flatten_list = []
tag_labels_flatten_list = []
tag_bboxes_flatten_list = []
tag_center_flatten_list = []
num_lvl = len(cls_logits)
for lvl in range(num_lvl):
cls_logits_flatten_list.append(
flatten_tensor(cls_logits[lvl], True))
bboxes_reg_flatten_list.append(
flatten_tensor(bboxes_reg[lvl], True))
centerness_flatten_list.append(
flatten_tensor(centerness[lvl], True))
tag_labels_flatten_list.append(
flatten_tensor(tag_labels[lvl], False))
tag_bboxes_flatten_list.append(
flatten_tensor(tag_bboxes[lvl], False))
tag_center_flatten_list.append(
flatten_tensor(tag_center[lvl], False))
cls_logits_flatten = paddle.concat(cls_logits_flatten_list, axis=0)
bboxes_reg_flatten = paddle.concat(bboxes_reg_flatten_list, axis=0)
centerness_flatten = paddle.concat(centerness_flatten_list, axis=0)
tag_labels_flatten = paddle.concat(tag_labels_flatten_list, axis=0)
tag_bboxes_flatten = paddle.concat(tag_bboxes_flatten_list, axis=0)
tag_center_flatten = paddle.concat(tag_center_flatten_list, axis=0)
tag_labels_flatten.stop_gradient = True
tag_bboxes_flatten.stop_gradient = True
tag_center_flatten.stop_gradient = True
# find positive index
mask_positive_bool = tag_labels_flatten > 0
mask_positive_bool.stop_gradient = True
mask_positive_float = paddle.cast(mask_positive_bool, dtype="float32")
mask_positive_float.stop_gradient = True
num_positive_fp32 = paddle.sum(mask_positive_float)
num_positive_fp32.stop_gradient = True
num_positive_int32 = paddle.cast(num_positive_fp32, dtype="int32")
num_positive_int32 = num_positive_int32 * 0 + 1
num_positive_int32.stop_gradient = True
# centerness target is used as reg weight
normalize_sum = paddle.sum(tag_center_flatten * mask_positive_float)
normalize_sum.stop_gradient = True
# 1. IoU-Based soft label loss
# calculate iou
with paddle.no_grad():
pos_ind = paddle.nonzero(
tag_labels_flatten.reshape([-1]) > 0).reshape([-1])
pos_pred = bboxes_reg_flatten[pos_ind]
pos_target = tag_bboxes_flatten[pos_ind]
bbox_iou = self._bbox_overlap_align(pos_pred, pos_target)
# pos labels
pos_labels = tag_labels_flatten[pos_ind].squeeze(1)
cls_target = paddle.zeros(cls_logits_flatten.shape)
cls_target[pos_ind, pos_labels - 1] = bbox_iou
cls_loss = self.iou_based_soft_label_loss(
cls_logits_flatten,
cls_target,
implicit_iou=F.sigmoid(centerness_flatten),
avg_factor=num_positive_fp32)
# 2. bboxes_reg: giou_loss
mask_positive_float = paddle.squeeze(mask_positive_float, axis=-1)
tag_center_flatten = paddle.squeeze(tag_center_flatten, axis=-1)
reg_loss = self._iou_loss(
bboxes_reg_flatten,
tag_bboxes_flatten,
mask_positive_float,
weights=tag_center_flatten)
reg_loss = reg_loss * mask_positive_float / normalize_sum
# 3. iou loss
pos_iou_pred = paddle.squeeze(centerness_flatten, axis=-1)[pos_ind]
loss_iou = ops.sigmoid_cross_entropy_with_logits(pos_iou_pred, bbox_iou)
loss_iou = loss_iou / num_positive_fp32 * 0.5
loss_all = {
"loss_cls": paddle.sum(cls_loss),
"loss_box": paddle.sum(reg_loss),
'loss_iou': paddle.sum(loss_iou),
}
return loss_all
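# Note (illustration only, not part of the original file): FCOSLossMILC
# consumes the same per-level inputs as FCOSLoss above (see
# _example_fcos_loss_forward); only the classification target changes to an
# IoU-based soft label, and the returned dict contains 'loss_cls', 'loss_box'
# and 'loss_iou'.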
# Concatenate multi-level feature maps per image
def levels_to_images(mlvl_tensor):
batch_size = mlvl_tensor[0].shape[0]
batch_list = [[] for _ in range(batch_size)]
channels = mlvl_tensor[0].shape[1]
for t in mlvl_tensor:
t = t.transpose([0, 2, 3, 1])
t = t.reshape([batch_size, -1, channels])
for img in range(batch_size):
batch_list[img].append(t[img])
return [paddle.concat(item, axis=0) for item in batch_list]
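# --- Illustrative example (not part of the original file) ---
# A minimal shape check for levels_to_images, assuming two dummy FPN levels:
# per-level [N, C, H, W] tensors become one [sum(H * W), C] tensor per image.
def _example_levels_to_images():
    lvls = [paddle.randn([2, 80, 8, 8]), paddle.randn([2, 80, 4, 4])]
    per_img = levels_to_images(lvls)
    assert len(per_img) == 2
    assert list(per_img[0].shape) == [8 * 8 + 4 * 4, 80]
    return per_img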
def multi_apply(func, *args, **kwargs):
"""Apply function to a list of arguments.
Note:
This function applies ``func`` to multiple inputs and maps the
multiple outputs of ``func`` into separate lists. Each list
contains the same type of output, one entry per input.
Args:
func (Function): A function that will be applied to a list of
arguments
Returns:
tuple(list): A tuple of lists; each list gathers one kind of result
returned by the function
"""
pfunc = partial(func, **kwargs) if kwargs else func
map_results = map(pfunc, *args)
return tuple(map(list, zip(*map_results)))
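# --- Illustrative example (not part of the original file) ---
# multi_apply maps a function over per-image inputs and regroups the outputs:
# a function returning (square, cube) per element yields two lists.
def _example_multi_apply():
    def square_and_cube(x):
        return x * x, x * x * x
    squares, cubes = multi_apply(square_and_cube, [1, 2, 3])
    assert squares == [1, 4, 9] and cubes == [1, 8, 27]
    return squares, cubes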
@register
class FCOSLossCR(FCOSLossMILC):
"""
FCOSLoss with Consistency Regularization, used for ARSL in semi-supervised detection (SSOD)
"""
def __init__(self,
iou_loss_type="giou",
cls_weight=2.0,
reg_weight=2.0,
iou_weight=0.5,
hard_neg_mining_flag=True):
super(FCOSLossCR, self).__init__()
self.iou_loss_type = iou_loss_type
self.cls_weight = cls_weight
self.reg_weight = reg_weight
self.iou_weight = iou_weight
self.hard_neg_mining_flag = hard_neg_mining_flag
def iou_loss(self, pred, targets, weights=None, avg_factor=None):
"""
Calculate the loss for location prediction
Args:
pred (Tensor): bounding boxes prediction
targets (Tensor): targets for positive samples
weights (Tensor): weights for each positive sample
avg_factor (float): factor used to average the summed loss
Return:
loss (Tensor): location loss
"""
plw = pred[:, 0]
pth = pred[:, 1]
prw = pred[:, 2]
pbh = pred[:, 3]
tlw = targets[:, 0]
tth = targets[:, 1]
trw = targets[:, 2]
tbh = targets[:, 3]
tlw.stop_gradient = True
trw.stop_gradient = True
tth.stop_gradient = True
tbh.stop_gradient = True
ilw = paddle.minimum(plw, tlw)
irw = paddle.minimum(prw, trw)
ith = paddle.minimum(pth, tth)
ibh = paddle.minimum(pbh, tbh)
clw = paddle.maximum(plw, tlw)
crw = paddle.maximum(prw, trw)
cth = paddle.maximum(pth, tth)
cbh = paddle.maximum(pbh, tbh)
area_predict = (plw + prw) * (pth + pbh)
area_target = (tlw + trw) * (tth + tbh)
area_inter = (ilw + irw) * (ith + ibh)
ious = (area_inter + 1.0) / (
area_predict + area_target - area_inter + 1.0)
ious = ious
if self.iou_loss_type.lower() == "linear_iou":
loss = 1.0 - ious
elif self.iou_loss_type.lower() == "giou":
area_uniou = area_predict + area_target - area_inter
area_circum = (clw + crw) * (cth + cbh) + 1e-7
giou = ious - (area_circum - area_uniou) / area_circum
loss = 1.0 - giou
elif self.iou_loss_type.lower() == "iou":
loss = 0.0 - paddle.log(ious)
else:
raise KeyError
if weights is not None:
loss = loss * weights
loss = paddle.sum(loss)
if avg_factor is not None:
loss = loss / avg_factor
return loss
# calculate iou between bbox and target
def bbox_overlap_align(self, pred, targets):
assert pred.shape[0] == targets.shape[0], \
'the pred should be aligned with target.'
plw = pred[:, 0]
pth = pred[:, 1]
prw = pred[:, 2]
pbh = pred[:, 3]
tlw = targets[:, 0]
tth = targets[:, 1]
trw = targets[:, 2]
tbh = targets[:, 3]
ilw = paddle.minimum(plw, tlw)
irw = paddle.minimum(prw, trw)
ith = paddle.minimum(pth, tth)
ibh = paddle.minimum(pbh, tbh)
area_predict = (plw + prw) * (pth + pbh)
area_target = (tlw + trw) * (tth + tbh)
area_inter = (ilw + irw) * (ith + ibh)
ious = (area_inter + 1.0) / (
area_predict + area_target - area_inter + 1.0)
return ious
# cls loss: IoU-based soft label with joint IoU
def quality_focal_loss(self,
stu_cls,
targets,
quality=None,
weights=None,
alpha=0.75,
gamma=2.0,
avg_factor=None):
stu_cls = F.sigmoid(stu_cls)
if quality is not None:
stu_cls = stu_cls * F.sigmoid(quality)
focal_weight = (stu_cls - targets).abs().pow(gamma) * (targets > 0.0).cast('float32') + \
alpha * (stu_cls - targets).abs().pow(gamma) * \
(targets <= 0.0).cast('float32')
loss = F.binary_cross_entropy(
stu_cls, targets, reduction='none') * focal_weight
if weights is not None:
loss = loss * weights.reshape([-1, 1])
loss = paddle.sum(loss)
if avg_factor is not None:
loss = loss / avg_factor
return loss
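# Example (illustration only, not part of the original file): the student's
# class probabilities are multiplied by the sigmoid of its IoU branch before
# being pulled towards the teacher's joint soft scores, roughly
#   FCOSLossCR().quality_focal_loss(stu_cls, cls_targets,
#                                   quality=stu_iou, avg_factor=cls_avg_factor)
# with stu_cls of shape [points, classes], stu_iou of shape [points, 1] and
# cls_targets holding teacher scores for the selected samples, 0 elsewhere
# (see forward below for how cls_avg_factor is computed).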
# generate points according to feature maps
def compute_locations_by_level(self, fpn_stride, h, w):
"""
Compute locations of anchor points of each FPN layer
Return:
Anchor point locations of the current FPN feature map
"""
shift_x = paddle.arange(0, w * fpn_stride, fpn_stride)
shift_y = paddle.arange(0, h * fpn_stride, fpn_stride)
shift_x = paddle.unsqueeze(shift_x, axis=0)
shift_y = paddle.unsqueeze(shift_y, axis=1)
shift_x = paddle.expand(shift_x, shape=[h, w])
shift_y = paddle.expand(shift_y, shape=[h, w])
shift_x = paddle.reshape(shift_x, shape=[-1])
shift_y = paddle.reshape(shift_y, shape=[-1])
location = paddle.stack(
[shift_x, shift_y], axis=-1) + float(fpn_stride) / 2
return location
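# Example (illustration only, not part of the original file): for
# fpn_stride=8, h=2, w=2 the returned anchor points sit at the centre of
# each stride-8 cell, i.e. [[4, 4], [12, 4], [4, 12], [12, 12]], with shape
# [h * w, 2].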
# decode bbox from ltrb to x1y1x2y2
def decode_bbox(self, ltrb, points):
assert ltrb.shape[0] == points.shape[0], \
"When decoding bbox in one image, the num of loc should be same with points."
bbox_decoding = paddle.stack(
[
points[:, 0] - ltrb[:, 0], points[:, 1] - ltrb[:, 1],
points[:, 0] + ltrb[:, 2], points[:, 1] + ltrb[:, 3]
],
axis=1)
return bbox_decoding
# encode bbox from x1y1x2y2 to ltrb
def encode_bbox(self, bbox, points):
assert bbox.shape[0] == points.shape[0], \
"When encoding bbox in one image, the num of bbox should be same with points."
bbox_encoding = paddle.stack(
[
points[:, 0] - bbox[:, 0], points[:, 1] - bbox[:, 1],
bbox[:, 2] - points[:, 0], bbox[:, 3] - points[:, 1]
],
axis=1)
return bbox_encoding
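# Example (illustration only, not part of the original file): decode_bbox and
# encode_bbox are exact inverses. With points=[[16., 16.]] and
# ltrb=[[4., 6., 8., 10.]], decode_bbox yields [[12., 10., 24., 26.]] and
# encode_bbox maps that box back to the original ltrb distances.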
def calcualate_iou(self, gt_bbox, predict_bbox):
# bbox area
gt_area = (gt_bbox[:, 2] - gt_bbox[:, 0]) * \
(gt_bbox[:, 3] - gt_bbox[:, 1])
predict_area = (predict_bbox[:, 2] - predict_bbox[:, 0]) * \
(predict_bbox[:, 3] - predict_bbox[:, 1])
# overlap area
lt = paddle.fmax(gt_bbox[:, None, :2], predict_bbox[None, :, :2])
rb = paddle.fmin(gt_bbox[:, None, 2:], predict_bbox[None, :, 2:])
wh = paddle.clip(rb - lt, min=0)
overlap = wh[..., 0] * wh[..., 1]
# iou
iou = overlap / (gt_area[:, None] + predict_area[None, :] - overlap)
return iou
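# Example (illustration only, not part of the original file): the result is a
# pairwise IoU matrix of shape [num_gt, num_pred]. For gt=[[0., 0., 10., 10.]]
# and predictions [[0., 0., 10., 10.], [5., 5., 15., 15.]] it is roughly
# [[1.0, 0.143]] (25 overlap / 175 union for the shifted box).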
# select potential positives from hard negatives
def hard_neg_mining(self,
cls_score,
loc_ltrb,
quality,
pos_ind,
hard_neg_ind,
loc_mask,
loc_targets,
iou_thresh=0.6):
# get points locations and strides
points_list = []
strides_list = []
scale_list = []
scale = [0, 1, 2, 3, 4]
for fpn_scale, fpn_stride, HW in zip(scale, self.fpn_stride,
self.lvl_hw):
h, w = HW
lvl_points = self.compute_locations_by_level(fpn_stride, h, w)
points_list.append(lvl_points)
lvl_strides = paddle.full([h * w, 1], fpn_stride)
strides_list.append(lvl_strides)
lvl_scales = paddle.full([h * w, 1], fpn_scale)
scale_list.append(lvl_scales)
points = paddle.concat(points_list, axis=0)
strides = paddle.concat(strides_list, axis=0)
scales = paddle.concat(scale_list, axis=0)
# cls scores
cls_vals = F.sigmoid(cls_score) * F.sigmoid(quality)
max_vals = paddle.max(cls_vals, axis=-1)
class_ind = paddle.argmax(cls_vals, axis=-1)
### calculate iou between positive and hard negative
# decode pos bbox
pos_cls = max_vals[pos_ind]
pos_loc = loc_ltrb[pos_ind].reshape([-1, 4])
pos_strides = strides[pos_ind]
pos_points = points[pos_ind].reshape([-1, 2])
pos_loc = pos_loc * pos_strides
pos_bbox = self.decode_bbox(pos_loc, pos_points)
pos_scales = scales[pos_ind]
# decode hard negative bbox
hard_neg_loc = loc_ltrb[hard_neg_ind].reshape([-1, 4])
hard_neg_strides = strides[hard_neg_ind]
hard_neg_points = points[hard_neg_ind].reshape([-1, 2])
hard_neg_loc = hard_neg_loc * hard_neg_strides
hard_neg_bbox = self.decode_bbox(hard_neg_loc, hard_neg_points)
hard_neg_scales = scales[hard_neg_ind]
# iou between pos bbox and hard negative bbox
hard_neg_pos_iou = self.calcualate_iou(hard_neg_bbox, pos_bbox)
### select potential positives from hard negatives
# scale flag
scale_temp = paddle.abs(
pos_scales.reshape([-1])[None, :] - hard_neg_scales.reshape([-1])
[:, None])
scale_flag = (scale_temp <= 1.)
# iou flag
iou_flag = (hard_neg_pos_iou >= iou_thresh)
# same class flag
pos_class = class_ind[pos_ind]
hard_neg_class = class_ind[hard_neg_ind]
class_flag = pos_class[None, :] - hard_neg_class[:, None]
class_flag = (class_flag == 0)
# hard negative point inside positive bbox flag
ltrb_temp = paddle.stack(
[
hard_neg_points[:, None, 0] - pos_bbox[None, :, 0],
hard_neg_points[:, None, 1] - pos_bbox[None, :, 1],
pos_bbox[None, :, 2] - hard_neg_points[:, None, 0],
pos_bbox[None, :, 3] - hard_neg_points[:, None, 1]
],
axis=-1)
inside_flag = ltrb_temp.min(axis=-1) > 0
# reset iou
valid_flag = (iou_flag & class_flag & inside_flag & scale_flag)
invalid_iou = paddle.zeros_like(hard_neg_pos_iou)
hard_neg_pos_iou = paddle.where(valid_flag, hard_neg_pos_iou,
invalid_iou)
pos_hard_neg_max_iou = hard_neg_pos_iou.max(axis=-1)
# select potential positives
potential_pos_ind = (pos_hard_neg_max_iou > 0.)
num_potential_pos = paddle.nonzero(potential_pos_ind).shape[0]
if num_potential_pos == 0:
return None
### calculate loc targets: aggregate all matching bboxes as the bbox targets of potential positives
# prepare data
potential_points = hard_neg_points[potential_pos_ind].reshape([-1, 2])
potential_strides = hard_neg_strides[potential_pos_ind]
potential_valid_flag = valid_flag[potential_pos_ind]
potential_pos_ind = hard_neg_ind[potential_pos_ind]
# get cls and box of matching positives
pos_cls = max_vals[pos_ind]
expand_pos_bbox = paddle.expand(
pos_bbox,
shape=[num_potential_pos, pos_bbox.shape[0], pos_bbox.shape[1]])
expand_pos_cls = paddle.expand(
pos_cls, shape=[num_potential_pos, pos_cls.shape[0]])
invalid_cls = paddle.zeros_like(expand_pos_cls)
expand_pos_cls = paddle.where(potential_valid_flag, expand_pos_cls,
invalid_cls)
expand_pos_cls = paddle.unsqueeze(expand_pos_cls, axis=-1)
# aggregate box based on cls_score
agg_bbox = (expand_pos_bbox * expand_pos_cls).sum(axis=1) \
/ expand_pos_cls.sum(axis=1)
agg_ltrb = self.encode_bbox(agg_bbox, potential_points)
agg_ltrb = agg_ltrb / potential_strides
# loc target for all pos
loc_targets[potential_pos_ind] = agg_ltrb
loc_mask[potential_pos_ind] = 1.
return loc_mask, loc_targets
# get training targets
def get_targets_per_img(self, tea_cls, tea_loc, tea_iou, stu_cls, stu_loc,
stu_iou):
### sample selection
# prepare data
tea_cls_scores = F.sigmoid(tea_cls) * F.sigmoid(tea_iou)
class_ind = paddle.argmax(tea_cls_scores, axis=-1)
max_vals = paddle.max(tea_cls_scores, axis=-1)
cls_mask = paddle.zeros_like(
max_vals
) # set cls valid mask: pos is 1, hard_negative and negative are 0.
num_pos, num_hard_neg = 0, 0
# mean-std selection
# use nonzero to turn the boolean mask into integer indices, because they are later combined into a two-dim index.
# use squeeze rather than reshape to avoid errors when no score is larger than the threshold.
candidate_ind = paddle.nonzero(max_vals >= 0.1).squeeze(axis=-1)
num_candidate = candidate_ind.shape[0]
if num_candidate > 0:
# pos thresh = mean + std to select pos samples
candidate_score = max_vals[candidate_ind]
candidate_score_mean = candidate_score.mean()
candidate_score_std = candidate_score.std()
pos_thresh = (candidate_score_mean + candidate_score_std).clip(
max=0.4)
# select pos
pos_ind = paddle.nonzero(max_vals >= pos_thresh).squeeze(axis=-1)
num_pos = pos_ind.shape[0]
# select hard negatives as potential pos
hard_neg_ind = (max_vals >= 0.1) & (max_vals < pos_thresh)
hard_neg_ind = paddle.nonzero(hard_neg_ind).squeeze(axis=-1)
num_hard_neg = hard_neg_ind.shape[0]
# if there is no positive, directly select the top-10 scoring points as positives.
if (num_pos == 0):
num_pos = 10
_, pos_ind = paddle.topk(max_vals, k=num_pos)
cls_mask[pos_ind] = 1.
### Consistency Regularization Training targets
# cls targets
pos_class_ind = class_ind[pos_ind]
cls_targets = paddle.zeros_like(tea_cls)
cls_targets[pos_ind, pos_class_ind] = tea_cls_scores[pos_ind,
pos_class_ind]
# hard negative cls target
if num_hard_neg != 0:
cls_targets[hard_neg_ind] = tea_cls_scores[hard_neg_ind]
# loc targets
loc_targets = paddle.zeros_like(tea_loc)
loc_targets[pos_ind] = tea_loc[pos_ind]
# iou targets
iou_targets = paddle.zeros(
shape=[tea_iou.shape[0]], dtype=tea_iou.dtype)
iou_targets[pos_ind] = F.sigmoid(
paddle.squeeze(
tea_iou, axis=-1)[pos_ind])
loc_mask = cls_mask.clone()
# select potential positives from hard negatives for loc task training
if (num_hard_neg > 0) and self.hard_neg_mining_flag:
results = self.hard_neg_mining(tea_cls, tea_loc, tea_iou, pos_ind,
hard_neg_ind, loc_mask, loc_targets)
if results is not None:
loc_mask, loc_targets = results
loc_pos_ind = paddle.nonzero(loc_mask > 0.).squeeze(axis=-1)
iou_targets[loc_pos_ind] = F.sigmoid(
paddle.squeeze(
tea_iou, axis=-1)[loc_pos_ind])
return cls_mask, loc_mask, \
cls_targets, loc_targets, iou_targets
def forward(self, student_prediction, teacher_prediction):
stu_cls_lvl, stu_loc_lvl, stu_iou_lvl = student_prediction
tea_cls_lvl, tea_loc_lvl, tea_iou_lvl, self.fpn_stride = teacher_prediction
# H and W of level (used for aggregating targets)
self.lvl_hw = []
for t in tea_cls_lvl:
_, _, H, W = t.shape
self.lvl_hw.append([H, W])
# levels to images
stu_cls_img = levels_to_images(stu_cls_lvl)
stu_loc_img = levels_to_images(stu_loc_lvl)
stu_iou_img = levels_to_images(stu_iou_lvl)
tea_cls_img = levels_to_images(tea_cls_lvl)
tea_loc_img = levels_to_images(tea_loc_lvl)
tea_iou_img = levels_to_images(tea_iou_lvl)
with paddle.no_grad():
cls_mask, loc_mask, \
cls_targets, loc_targets, iou_targets = multi_apply(
self.get_targets_per_img,
tea_cls_img,
tea_loc_img,
tea_iou_img,
stu_cls_img,
stu_loc_img,
stu_iou_img
)
# flatten predictions
stu_cls = paddle.concat(stu_cls_img, axis=0)
stu_loc = paddle.concat(stu_loc_img, axis=0)
stu_iou = paddle.concat(stu_iou_img, axis=0)
# flatten targets
cls_mask = paddle.concat(cls_mask, axis=0)
loc_mask = paddle.concat(loc_mask, axis=0)
cls_targets = paddle.concat(cls_targets, axis=0)
loc_targets = paddle.concat(loc_targets, axis=0)
iou_targets = paddle.concat(iou_targets, axis=0)
### Training Weights and avg factor
# find positives
cls_pos_ind = paddle.nonzero(cls_mask > 0.).squeeze(axis=-1)
loc_pos_ind = paddle.nonzero(loc_mask > 0.).squeeze(axis=-1)
# cls weight
cls_sample_weights = paddle.ones([cls_targets.shape[0]])
cls_avg_factor = paddle.max(cls_targets[cls_pos_ind],
axis=-1).sum().item()
# loc weight
loc_sample_weights = paddle.max(cls_targets[loc_pos_ind], axis=-1)
loc_avg_factor = loc_sample_weights.sum().item()
# iou weight
iou_sample_weights = paddle.ones([loc_pos_ind.shape[0]])
iou_avg_factor = loc_pos_ind.shape[0]
### unsupervised loss
# cls loss
loss_cls = self.quality_focal_loss(
stu_cls,
cls_targets,
quality=stu_iou,
weights=cls_sample_weights,
avg_factor=cls_avg_factor) * self.cls_weight
# iou loss
pos_stu_iou = paddle.squeeze(stu_iou, axis=-1)[loc_pos_ind]
pos_iou_targets = iou_targets[loc_pos_ind]
loss_iou = F.binary_cross_entropy(
F.sigmoid(pos_stu_iou), pos_iou_targets,
reduction='none') * iou_sample_weights
loss_iou = loss_iou.sum() / iou_avg_factor * self.iou_weight
# box loss
pos_stu_loc = stu_loc[loc_pos_ind]
pos_loc_targets = loc_targets[loc_pos_ind]
loss_box = self.iou_loss(
pos_stu_loc,
pos_loc_targets,
weights=loc_sample_weights,
avg_factor=loc_avg_factor)
loss_box = loss_box * self.reg_weight
loss_all = {
"loss_cls": loss_cls,
"loss_box": loss_box,
"loss_iou": loss_iou,
}
return loss_all
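# --- Illustrative example (not part of the original file) ---
# A rough sketch of the dense student/teacher layout consumed by
# FCOSLossCR.forward: three lists of per-level [N, C, H, W] tensors
# (classification, ltrb regression, IoU) for each model, with the FPN strides
# appended to the teacher tuple. Level sizes, strides and the class count are
# arbitrary choices; the helper is defined for documentation only and never
# called by the library code.
def _example_fcos_loss_cr(num_classes=80):
    loss_fn = FCOSLossCR()
    n, hws, strides = 2, [(8, 8), (4, 4)], [8, 16]

    def make_preds():
        cls = [paddle.randn([n, num_classes, h, w]) for h, w in hws]
        loc = [paddle.rand([n, 4, h, w]) + 0.1 for h, w in hws]
        iou = [paddle.randn([n, 1, h, w]) for h, w in hws]
        return cls, loc, iou

    student_prediction = make_preds()
    teacher_prediction = make_preds() + (strides,)
    losses = loss_fn(student_prediction, teacher_prediction)
    return losses  # dict with 'loss_cls', 'loss_box' and 'loss_iou'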