# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import from __future__ import division from __future__ import print_function import math import paddle import paddle.nn as nn import paddle.nn.functional as F from paddle import ParamAttr from ppdet.core.workspace import register from ppdet.modeling import ops from ppdet.modeling.losses.iou_loss import GIoULoss from ppdet.utils.logger import setup_logger logger = setup_logger(__name__) __all__ = [ 'DistillYOLOv3Loss', 'KnowledgeDistillationKLDivLoss', 'DistillPPYOLOELoss', 'FGDFeatureLoss', 'CWDFeatureLoss', 'PKDFeatureLoss', 'MGDFeatureLoss', ] def parameter_init(mode="kaiming", value=0.): if mode == "kaiming": weight_attr = paddle.nn.initializer.KaimingUniform() elif mode == "constant": weight_attr = paddle.nn.initializer.Constant(value=value) else: weight_attr = paddle.nn.initializer.KaimingUniform() weight_init = ParamAttr(initializer=weight_attr) return weight_init def feature_norm(feat): # Normalize the feature maps to have zero mean and unit variances. assert len(feat.shape) == 4 N, C, H, W = feat.shape feat = feat.transpose([1, 0, 2, 3]).reshape([C, -1]) mean = feat.mean(axis=-1, keepdim=True) std = feat.std(axis=-1, keepdim=True) feat = (feat - mean) / (std + 1e-6) return feat.reshape([C, N, H, W]).transpose([1, 0, 2, 3]) @register class DistillYOLOv3Loss(nn.Layer): def __init__(self, weight=1000): super(DistillYOLOv3Loss, self).__init__() self.loss_weight = weight def obj_weighted_reg(self, sx, sy, sw, sh, tx, ty, tw, th, tobj): loss_x = ops.sigmoid_cross_entropy_with_logits(sx, F.sigmoid(tx)) loss_y = ops.sigmoid_cross_entropy_with_logits(sy, F.sigmoid(ty)) loss_w = paddle.abs(sw - tw) loss_h = paddle.abs(sh - th) loss = paddle.add_n([loss_x, loss_y, loss_w, loss_h]) weighted_loss = paddle.mean(loss * F.sigmoid(tobj)) return weighted_loss def obj_weighted_cls(self, scls, tcls, tobj): loss = ops.sigmoid_cross_entropy_with_logits(scls, F.sigmoid(tcls)) weighted_loss = paddle.mean(paddle.multiply(loss, F.sigmoid(tobj))) return weighted_loss def obj_loss(self, sobj, tobj): obj_mask = paddle.cast(tobj > 0., dtype="float32") obj_mask.stop_gradient = True loss = paddle.mean( ops.sigmoid_cross_entropy_with_logits(sobj, obj_mask)) return loss def forward(self, teacher_model, student_model): teacher_distill_pairs = teacher_model.yolo_head.loss.distill_pairs student_distill_pairs = student_model.yolo_head.loss.distill_pairs distill_reg_loss, distill_cls_loss, distill_obj_loss = [], [], [] for s_pair, t_pair in zip(student_distill_pairs, teacher_distill_pairs): distill_reg_loss.append( self.obj_weighted_reg(s_pair[0], s_pair[1], s_pair[2], s_pair[ 3], t_pair[0], t_pair[1], t_pair[2], t_pair[3], t_pair[4])) distill_cls_loss.append( self.obj_weighted_cls(s_pair[5], t_pair[5], t_pair[4])) distill_obj_loss.append(self.obj_loss(s_pair[4], t_pair[4])) distill_reg_loss = paddle.add_n(distill_reg_loss) distill_cls_loss = paddle.add_n(distill_cls_loss) distill_obj_loss = paddle.add_n(distill_obj_loss) loss = (distill_reg_loss + distill_cls_loss + distill_obj_loss ) * self.loss_weight return loss @register class KnowledgeDistillationKLDivLoss(nn.Layer): """Loss function for knowledge distilling using KL divergence. Args: reduction (str): Options are `'none'`, `'mean'` and `'sum'`. loss_weight (float): Loss weight of current loss. T (int): Temperature for distillation. """ def __init__(self, reduction='mean', loss_weight=1.0, T=10): super(KnowledgeDistillationKLDivLoss, self).__init__() assert reduction in ('none', 'mean', 'sum') assert T >= 1 self.reduction = reduction self.loss_weight = loss_weight self.T = T def knowledge_distillation_kl_div_loss(self, pred, soft_label, T, detach_target=True): r"""Loss function for knowledge distilling using KL divergence. Args: pred (Tensor): Predicted logits with shape (N, n + 1). soft_label (Tensor): Target logits with shape (N, N + 1). T (int): Temperature for distillation. detach_target (bool): Remove soft_label from automatic differentiation """ assert pred.shape == soft_label.shape target = F.softmax(soft_label / T, axis=1) if detach_target: target = target.detach() kd_loss = F.kl_div( F.log_softmax( pred / T, axis=1), target, reduction='none').mean(1) * (T * T) return kd_loss def forward(self, pred, soft_label, weight=None, avg_factor=None, reduction_override=None): """Forward function. Args: pred (Tensor): Predicted logits with shape (N, n + 1). soft_label (Tensor): Target logits with shape (N, N + 1). weight (Tensor, optional): The weight of loss for each prediction. Defaults to None. avg_factor (int, optional): Average factor that is used to average the loss. Defaults to None. reduction_override (str, optional): The reduction method used to override the original reduction method of the loss. Defaults to None. """ assert reduction_override in (None, 'none', 'mean', 'sum') reduction = (reduction_override if reduction_override else self.reduction) loss_kd_out = self.knowledge_distillation_kl_div_loss( pred, soft_label, T=self.T) if weight is not None: loss_kd_out = weight * loss_kd_out if avg_factor is None: if reduction == 'none': loss = loss_kd_out elif reduction == 'mean': loss = loss_kd_out.mean() elif reduction == 'sum': loss = loss_kd_out.sum() else: # if reduction is mean, then average the loss by avg_factor if reduction == 'mean': loss = loss_kd_out.sum() / avg_factor # if reduction is 'none', then do nothing, otherwise raise an error elif reduction != 'none': raise ValueError( 'avg_factor can not be used with reduction="sum"') loss_kd = self.loss_weight * loss return loss_kd @register class DistillPPYOLOELoss(nn.Layer): def __init__( self, loss_weight={'logits': 4.0, 'feat': 1.0}, logits_distill=True, logits_loss_weight={'class': 1.0, 'iou': 2.5, 'dfl': 0.5}, logits_ld_distill=False, logits_ld_params={'weight': 20000, 'T': 10}, feat_distill=True, feat_distiller='fgd', feat_distill_place='neck_feats', teacher_width_mult=1.0, # L student_width_mult=0.75, # M feat_out_channels=[768, 384, 192]): super(DistillPPYOLOELoss, self).__init__() self.loss_weight_logits = loss_weight['logits'] self.loss_weight_feat = loss_weight['feat'] self.logits_distill = logits_distill self.logits_ld_distill = logits_ld_distill self.feat_distill = feat_distill if logits_distill and self.loss_weight_logits > 0: self.bbox_loss_weight = logits_loss_weight['iou'] self.dfl_loss_weight = logits_loss_weight['dfl'] self.qfl_loss_weight = logits_loss_weight['class'] self.loss_bbox = GIoULoss() if logits_ld_distill: self.loss_kd = KnowledgeDistillationKLDivLoss( loss_weight=logits_ld_params['weight'], T=logits_ld_params['T']) if feat_distill and self.loss_weight_feat > 0: assert feat_distiller in ['cwd', 'fgd', 'pkd', 'mgd', 'mimic'] assert feat_distill_place in ['backbone_feats', 'neck_feats'] self.feat_distill_place = feat_distill_place self.t_channel_list = [ int(c * teacher_width_mult) for c in feat_out_channels ] self.s_channel_list = [ int(c * student_width_mult) for c in feat_out_channels ] self.distill_feat_loss_modules = [] for i in range(len(feat_out_channels)): if feat_distiller == 'cwd': feat_loss_module = CWDFeatureLoss( student_channels=self.s_channel_list[i], teacher_channels=self.t_channel_list[i], normalize=True) elif feat_distiller == 'fgd': feat_loss_module = FGDFeatureLoss( student_channels=self.s_channel_list[i], teacher_channels=self.t_channel_list[i], normalize=True, alpha_fgd=0.00001, beta_fgd=0.000005, gamma_fgd=0.00001, lambda_fgd=0.00000005) elif feat_distiller == 'pkd': feat_loss_module = PKDFeatureLoss( student_channels=self.s_channel_list[i], teacher_channels=self.t_channel_list[i], normalize=True, resize_stu=True) elif feat_distiller == 'mgd': feat_loss_module = MGDFeatureLoss( student_channels=self.s_channel_list[i], teacher_channels=self.t_channel_list[i], normalize=True, loss_func='ssim') elif feat_distiller == 'mimic': feat_loss_module = MimicFeatureLoss( student_channels=self.s_channel_list[i], teacher_channels=self.t_channel_list[i], normalize=True) else: raise ValueError self.distill_feat_loss_modules.append(feat_loss_module) def quality_focal_loss(self, pred_logits, soft_target_logits, beta=2.0, use_sigmoid=False, num_total_pos=None): if use_sigmoid: func = F.binary_cross_entropy_with_logits soft_target = F.sigmoid(soft_target_logits) pred_sigmoid = F.sigmoid(pred_logits) preds = pred_logits else: func = F.binary_cross_entropy soft_target = soft_target_logits pred_sigmoid = pred_logits preds = pred_sigmoid scale_factor = pred_sigmoid - soft_target loss = func( preds, soft_target, reduction='none') * scale_factor.abs().pow(beta) loss = loss.sum(1) if num_total_pos is not None: loss = loss.sum() / num_total_pos else: loss = loss.mean() return loss def bbox_loss(self, s_bbox, t_bbox, weight_targets=None): # [x,y,w,h] if weight_targets is not None: loss = paddle.sum(self.loss_bbox(s_bbox, t_bbox) * weight_targets) avg_factor = weight_targets.sum() loss = loss / avg_factor else: loss = paddle.mean(self.loss_bbox(s_bbox, t_bbox)) return loss def distribution_focal_loss(self, pred_corners, target_corners, weight_targets=None): target_corners_label = F.softmax(target_corners, axis=-1) loss_dfl = F.cross_entropy( pred_corners, target_corners_label, soft_label=True, reduction='none') loss_dfl = loss_dfl.sum(1) if weight_targets is not None: loss_dfl = loss_dfl * (weight_targets.expand([-1, 4]).reshape([-1])) loss_dfl = loss_dfl.sum(-1) / weight_targets.sum() else: loss_dfl = loss_dfl.mean(-1) return loss_dfl / 4.0 # 4 direction def main_kd(self, mask_positive, pred_scores, soft_cls, num_classes): num_pos = mask_positive.sum() if num_pos > 0: cls_mask = mask_positive.unsqueeze(-1).tile([1, 1, num_classes]) pred_scores_pos = paddle.masked_select( pred_scores, cls_mask).reshape([-1, num_classes]) soft_cls_pos = paddle.masked_select( soft_cls, cls_mask).reshape([-1, num_classes]) loss_kd = self.loss_kd( pred_scores_pos, soft_cls_pos, avg_factor=num_pos) else: loss_kd = paddle.zeros([1]) return loss_kd def forward(self, teacher_model, student_model): teacher_distill_pairs = teacher_model.yolo_head.distill_pairs student_distill_pairs = student_model.yolo_head.distill_pairs if self.logits_distill and self.loss_weight_logits > 0: distill_bbox_loss, distill_dfl_loss, distill_cls_loss = [], [], [] distill_cls_loss.append( self.quality_focal_loss( student_distill_pairs['pred_cls_scores'].reshape( (-1, student_distill_pairs['pred_cls_scores'].shape[-1] )), teacher_distill_pairs['pred_cls_scores'].detach().reshape( (-1, teacher_distill_pairs['pred_cls_scores'].shape[-1] )), num_total_pos=student_distill_pairs['pos_num'], use_sigmoid=False)) distill_bbox_loss.append( self.bbox_loss(student_distill_pairs['pred_bboxes_pos'], teacher_distill_pairs['pred_bboxes_pos'].detach(), weight_targets=student_distill_pairs['bbox_weight'] ) if 'pred_bboxes_pos' in student_distill_pairs and \ 'pred_bboxes_pos' in teacher_distill_pairs and \ 'bbox_weight' in student_distill_pairs else paddle.zeros([1])) distill_dfl_loss.append( self.distribution_focal_loss( student_distill_pairs['pred_dist_pos'].reshape((-1, student_distill_pairs['pred_dist_pos'].shape[-1])), teacher_distill_pairs['pred_dist_pos'].detach().reshape((-1, teacher_distill_pairs['pred_dist_pos'].shape[-1])), \ weight_targets=student_distill_pairs['bbox_weight'] ) if 'pred_dist_pos' in student_distill_pairs and \ 'pred_dist_pos' in teacher_distill_pairs and \ 'bbox_weight' in student_distill_pairs else paddle.zeros([1])) distill_cls_loss = paddle.add_n(distill_cls_loss) distill_bbox_loss = paddle.add_n(distill_bbox_loss) distill_dfl_loss = paddle.add_n(distill_dfl_loss) logits_loss = distill_bbox_loss * self.bbox_loss_weight + distill_cls_loss * self.qfl_loss_weight + distill_dfl_loss * self.dfl_loss_weight if self.logits_ld_distill: loss_kd = self.main_kd( student_distill_pairs['mask_positive_select'], student_distill_pairs['pred_cls_scores'], teacher_distill_pairs['pred_cls_scores'], student_model.yolo_head.num_classes, ) logits_loss += loss_kd else: logits_loss = paddle.zeros([1]) if self.feat_distill and self.loss_weight_feat > 0: feat_loss_list = [] inputs = student_model.inputs assert 'gt_bbox' in inputs assert self.feat_distill_place in student_distill_pairs assert self.feat_distill_place in teacher_distill_pairs stu_feats = student_distill_pairs[self.feat_distill_place] tea_feats = teacher_distill_pairs[self.feat_distill_place] for i, loss_module in enumerate(self.distill_feat_loss_modules): feat_loss_list.append( loss_module(stu_feats[i], tea_feats[i], inputs)) feat_loss = paddle.add_n(feat_loss_list) else: feat_loss = paddle.zeros([1]) student_model.yolo_head.distill_pairs.clear() teacher_model.yolo_head.distill_pairs.clear() return logits_loss * self.loss_weight_logits, feat_loss * self.loss_weight_feat @register class CWDFeatureLoss(nn.Layer): def __init__(self, student_channels, teacher_channels, normalize=False, tau=1.0, weight=1.0): super(CWDFeatureLoss, self).__init__() self.normalize = normalize self.tau = tau self.loss_weight = weight if student_channels != teacher_channels: self.align = nn.Conv2D( student_channels, teacher_channels, kernel_size=1, stride=1, padding=0) else: self.align = None def distill_softmax(self, x, tau): _, _, w, h = paddle.shape(x) x = paddle.reshape(x, [-1, w * h]) x /= tau return F.softmax(x, axis=1) def forward(self, preds_s, preds_t, inputs=None): assert preds_s.shape[-2:] == preds_t.shape[-2:] N, C, H, W = preds_s.shape eps = 1e-5 if self.align is not None: preds_s = self.align(preds_s) if self.normalize: preds_s = feature_norm(preds_s) preds_t = feature_norm(preds_t) softmax_pred_s = self.distill_softmax(preds_s, self.tau) softmax_pred_t = self.distill_softmax(preds_t, self.tau) loss = paddle.sum(-softmax_pred_t * paddle.log(eps + softmax_pred_s) + softmax_pred_t * paddle.log(eps + softmax_pred_t)) return self.loss_weight * loss / (C * N) @register class FGDFeatureLoss(nn.Layer): """ Focal and Global Knowledge Distillation for Detectors The code is reference from https://github.com/yzd-v/FGD/blob/master/mmdet/distillation/losses/fgd.py Args: student_channels (int): The number of channels in the student's FPN feature map. Default to 256. teacher_channels (int): The number of channels in the teacher's FPN feature map. Default to 256. normalize (bool): Whether to normalize the feature maps. temp (float, optional): The temperature coefficient. Defaults to 0.5. alpha_fgd (float, optional): The weight of fg_loss. Defaults to 0.001 beta_fgd (float, optional): The weight of bg_loss. Defaults to 0.0005 gamma_fgd (float, optional): The weight of mask_loss. Defaults to 0.001 lambda_fgd (float, optional): The weight of relation_loss. Defaults to 0.000005 """ def __init__(self, student_channels, teacher_channels, normalize=False, loss_weight=1.0, temp=0.5, alpha_fgd=0.001, beta_fgd=0.0005, gamma_fgd=0.001, lambda_fgd=0.000005): super(FGDFeatureLoss, self).__init__() self.normalize = normalize self.loss_weight = loss_weight self.temp = temp self.alpha_fgd = alpha_fgd self.beta_fgd = beta_fgd self.gamma_fgd = gamma_fgd self.lambda_fgd = lambda_fgd kaiming_init = parameter_init("kaiming") zeros_init = parameter_init("constant", 0.0) if student_channels != teacher_channels: self.align = nn.Conv2D( student_channels, teacher_channels, kernel_size=1, stride=1, padding=0, weight_attr=kaiming_init) student_channels = teacher_channels else: self.align = None self.conv_mask_s = nn.Conv2D( student_channels, 1, kernel_size=1, weight_attr=kaiming_init) self.conv_mask_t = nn.Conv2D( teacher_channels, 1, kernel_size=1, weight_attr=kaiming_init) self.stu_conv_block = nn.Sequential( nn.Conv2D( student_channels, student_channels // 2, kernel_size=1, weight_attr=zeros_init), nn.LayerNorm([student_channels // 2, 1, 1]), nn.ReLU(), nn.Conv2D( student_channels // 2, student_channels, kernel_size=1, weight_attr=zeros_init)) self.tea_conv_block = nn.Sequential( nn.Conv2D( teacher_channels, teacher_channels // 2, kernel_size=1, weight_attr=zeros_init), nn.LayerNorm([teacher_channels // 2, 1, 1]), nn.ReLU(), nn.Conv2D( teacher_channels // 2, teacher_channels, kernel_size=1, weight_attr=zeros_init)) def spatial_channel_attention(self, x, t=0.5): shape = paddle.shape(x) N, C, H, W = shape _f = paddle.abs(x) spatial_map = paddle.reshape( paddle.mean( _f, axis=1, keepdim=True) / t, [N, -1]) spatial_map = F.softmax(spatial_map, axis=1, dtype="float32") * H * W spatial_att = paddle.reshape(spatial_map, [N, H, W]) channel_map = paddle.mean( paddle.mean( _f, axis=2, keepdim=False), axis=2, keepdim=False) channel_att = F.softmax(channel_map / t, axis=1, dtype="float32") * C return [spatial_att, channel_att] def spatial_pool(self, x, mode="teacher"): batch, channel, width, height = x.shape x_copy = x x_copy = paddle.reshape(x_copy, [batch, channel, height * width]) x_copy = x_copy.unsqueeze(1) if mode.lower() == "student": context_mask = self.conv_mask_s(x) else: context_mask = self.conv_mask_t(x) context_mask = paddle.reshape(context_mask, [batch, 1, height * width]) context_mask = F.softmax(context_mask, axis=2) context_mask = context_mask.unsqueeze(-1) context = paddle.matmul(x_copy, context_mask) context = paddle.reshape(context, [batch, channel, 1, 1]) return context def mask_loss(self, stu_channel_att, tea_channel_att, stu_spatial_att, tea_spatial_att): def _func(a, b): return paddle.sum(paddle.abs(a - b)) / len(a) mask_loss = _func(stu_channel_att, tea_channel_att) + _func( stu_spatial_att, tea_spatial_att) return mask_loss def feature_loss(self, stu_feature, tea_feature, mask_fg, mask_bg, tea_channel_att, tea_spatial_att): mask_fg = mask_fg.unsqueeze(axis=1) mask_bg = mask_bg.unsqueeze(axis=1) tea_channel_att = tea_channel_att.unsqueeze(axis=-1).unsqueeze(axis=-1) tea_spatial_att = tea_spatial_att.unsqueeze(axis=1) fea_t = paddle.multiply(tea_feature, paddle.sqrt(tea_spatial_att)) fea_t = paddle.multiply(fea_t, paddle.sqrt(tea_channel_att)) fg_fea_t = paddle.multiply(fea_t, paddle.sqrt(mask_fg)) bg_fea_t = paddle.multiply(fea_t, paddle.sqrt(mask_bg)) fea_s = paddle.multiply(stu_feature, paddle.sqrt(tea_spatial_att)) fea_s = paddle.multiply(fea_s, paddle.sqrt(tea_channel_att)) fg_fea_s = paddle.multiply(fea_s, paddle.sqrt(mask_fg)) bg_fea_s = paddle.multiply(fea_s, paddle.sqrt(mask_bg)) fg_loss = F.mse_loss(fg_fea_s, fg_fea_t, reduction="sum") / len(mask_fg) bg_loss = F.mse_loss(bg_fea_s, bg_fea_t, reduction="sum") / len(mask_bg) return fg_loss, bg_loss def relation_loss(self, stu_feature, tea_feature): context_s = self.spatial_pool(stu_feature, "student") context_t = self.spatial_pool(tea_feature, "teacher") out_s = stu_feature + self.stu_conv_block(context_s) out_t = tea_feature + self.tea_conv_block(context_t) rela_loss = F.mse_loss(out_s, out_t, reduction="sum") / len(out_s) return rela_loss def mask_value(self, mask, xl, xr, yl, yr, value): mask[xl:xr, yl:yr] = paddle.maximum(mask[xl:xr, yl:yr], value) return mask def forward(self, stu_feature, tea_feature, inputs): assert stu_feature.shape[-2:] == stu_feature.shape[-2:] assert "gt_bbox" in inputs.keys() and "im_shape" in inputs.keys() gt_bboxes = inputs['gt_bbox'] ins_shape = [ inputs['im_shape'][i] for i in range(inputs['im_shape'].shape[0]) ] index_gt = [] for i in range(len(gt_bboxes)): if gt_bboxes[i].size > 2: index_gt.append(i) # only distill feature with labeled GTbox if len(index_gt) != len(gt_bboxes): index_gt_t = paddle.to_tensor(index_gt) stu_feature = paddle.index_select(stu_feature, index_gt_t) tea_feature = paddle.index_select(tea_feature, index_gt_t) ins_shape = [ins_shape[c] for c in index_gt] gt_bboxes = [gt_bboxes[c] for c in index_gt] assert len(gt_bboxes) == tea_feature.shape[0] if self.align is not None: stu_feature = self.align(stu_feature) if self.normalize: stu_feature = feature_norm(stu_feature) tea_feature = feature_norm(tea_feature) tea_spatial_att, tea_channel_att = self.spatial_channel_attention( tea_feature, self.temp) stu_spatial_att, stu_channel_att = self.spatial_channel_attention( stu_feature, self.temp) mask_fg = paddle.zeros(tea_spatial_att.shape) mask_bg = paddle.ones_like(tea_spatial_att) one_tmp = paddle.ones([*tea_spatial_att.shape[1:]]) zero_tmp = paddle.zeros([*tea_spatial_att.shape[1:]]) mask_fg.stop_gradient = True mask_bg.stop_gradient = True one_tmp.stop_gradient = True zero_tmp.stop_gradient = True wmin, wmax, hmin, hmax = [], [], [], [] if len(gt_bboxes) == 0: loss = self.relation_loss(stu_feature, tea_feature) return self.lambda_fgd * loss N, _, H, W = stu_feature.shape for i in range(N): tmp_box = paddle.ones_like(gt_bboxes[i]) tmp_box.stop_gradient = True tmp_box[:, 0] = gt_bboxes[i][:, 0] / ins_shape[i][1] * W tmp_box[:, 2] = gt_bboxes[i][:, 2] / ins_shape[i][1] * W tmp_box[:, 1] = gt_bboxes[i][:, 1] / ins_shape[i][0] * H tmp_box[:, 3] = gt_bboxes[i][:, 3] / ins_shape[i][0] * H zero = paddle.zeros_like(tmp_box[:, 0], dtype="int32") ones = paddle.ones_like(tmp_box[:, 2], dtype="int32") zero.stop_gradient = True ones.stop_gradient = True wmin.append( paddle.cast(paddle.floor(tmp_box[:, 0]), "int32").maximum(zero)) wmax.append(paddle.cast(paddle.ceil(tmp_box[:, 2]), "int32")) hmin.append( paddle.cast(paddle.floor(tmp_box[:, 1]), "int32").maximum(zero)) hmax.append(paddle.cast(paddle.ceil(tmp_box[:, 3]), "int32")) area_recip = 1.0 / ( hmax[i].reshape([1, -1]) + 1 - hmin[i].reshape([1, -1])) / ( wmax[i].reshape([1, -1]) + 1 - wmin[i].reshape([1, -1])) for j in range(len(gt_bboxes[i])): if gt_bboxes[i][j].sum() > 0: mask_fg[i] = self.mask_value( mask_fg[i], hmin[i][j], hmax[i][j] + 1, wmin[i][j], wmax[i][j] + 1, area_recip[0][j]) mask_bg[i] = paddle.where(mask_fg[i] > zero_tmp, zero_tmp, one_tmp) if paddle.sum(mask_bg[i]): mask_bg[i] /= paddle.sum(mask_bg[i]) fg_loss, bg_loss = self.feature_loss(stu_feature, tea_feature, mask_fg, mask_bg, tea_channel_att, tea_spatial_att) mask_loss = self.mask_loss(stu_channel_att, tea_channel_att, stu_spatial_att, tea_spatial_att) rela_loss = self.relation_loss(stu_feature, tea_feature) loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \ + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss return loss * self.loss_weight @register class PKDFeatureLoss(nn.Layer): """ PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient. Args: loss_weight (float): Weight of loss. Defaults to 1.0. resize_stu (bool): If True, we'll down/up sample the features of the student model to the spatial size of those of the teacher model if their spatial sizes are different. And vice versa. Defaults to True. """ def __init__(self, student_channels=256, teacher_channels=256, normalize=True, loss_weight=1.0, resize_stu=True): super(PKDFeatureLoss, self).__init__() self.normalize = normalize self.loss_weight = loss_weight self.resize_stu = resize_stu def forward(self, stu_feature, tea_feature, inputs=None): size_s, size_t = stu_feature.shape[2:], tea_feature.shape[2:] if size_s[0] != size_t[0]: if self.resize_stu: stu_feature = F.interpolate( stu_feature, size_t, mode='bilinear') else: tea_feature = F.interpolate( tea_feature, size_s, mode='bilinear') assert stu_feature.shape == tea_feature.shape if self.normalize: stu_feature = feature_norm(stu_feature) tea_feature = feature_norm(tea_feature) loss = F.mse_loss(stu_feature, tea_feature) / 2 return loss * self.loss_weight @register class MimicFeatureLoss(nn.Layer): def __init__(self, student_channels=256, teacher_channels=256, normalize=True, loss_weight=1.0): super(MimicFeatureLoss, self).__init__() self.normalize = normalize self.loss_weight = loss_weight self.mse_loss = nn.MSELoss() if student_channels != teacher_channels: self.align = nn.Conv2D( student_channels, teacher_channels, kernel_size=1, stride=1, padding=0) else: self.align = None def forward(self, stu_feature, tea_feature, inputs=None): if self.align is not None: stu_feature = self.align(stu_feature) if self.normalize: stu_feature = feature_norm(stu_feature) tea_feature = feature_norm(tea_feature) loss = self.mse_loss(stu_feature, tea_feature) return loss * self.loss_weight @register class MGDFeatureLoss(nn.Layer): def __init__(self, student_channels=256, teacher_channels=256, normalize=True, loss_weight=1.0, loss_func='mse'): super(MGDFeatureLoss, self).__init__() self.normalize = normalize self.loss_weight = loss_weight assert loss_func in ['mse', 'ssim'] self.loss_func = loss_func self.mse_loss = nn.MSELoss(reduction='sum') self.ssim_loss = SSIM(11) kaiming_init = parameter_init("kaiming") if student_channels != teacher_channels: self.align = nn.Conv2D( student_channels, teacher_channels, kernel_size=1, stride=1, padding=0, weight_attr=kaiming_init, bias_attr=False) else: self.align = None self.generation = nn.Sequential( nn.Conv2D( teacher_channels, teacher_channels, kernel_size=3, padding=1), nn.ReLU(), nn.Conv2D( teacher_channels, teacher_channels, kernel_size=3, padding=1)) def forward(self, stu_feature, tea_feature, inputs=None): N = stu_feature.shape[0] if self.align is not None: stu_feature = self.align(stu_feature) stu_feature = self.generation(stu_feature) if self.normalize: stu_feature = feature_norm(stu_feature) tea_feature = feature_norm(tea_feature) if self.loss_func == 'mse': loss = self.mse_loss(stu_feature, tea_feature) / N elif self.loss_func == 'ssim': ssim_loss = self.ssim_loss(stu_feature, tea_feature) loss = paddle.clip((1 - ssim_loss) / 2, 0, 1) else: raise ValueError return loss * self.loss_weight class SSIM(nn.Layer): def __init__(self, window_size=11, size_average=True): super(SSIM, self).__init__() self.window_size = window_size self.size_average = size_average self.channel = 1 self.window = self.create_window(window_size, self.channel) def gaussian(self, window_size, sigma): gauss = paddle.to_tensor([ math.exp(-(x - window_size // 2)**2 / float(2 * sigma**2)) for x in range(window_size) ]) return gauss / gauss.sum() def create_window(self, window_size, channel): _1D_window = self.gaussian(window_size, 1.5).unsqueeze(1) _2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0) window = _2D_window.expand([channel, 1, window_size, window_size]) return window def _ssim(self, img1, img2, window, window_size, channel, size_average=True): mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) mu1_sq = mu1.pow(2) mu2_sq = mu2.pow(2) mu1_mu2 = mu1 * mu2 sigma1_sq = F.conv2d( img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq sigma2_sq = F.conv2d( img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq sigma12 = F.conv2d( img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 C1 = 0.01**2 C2 = 0.03**2 ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 1e-12 + (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) if size_average: return ssim_map.mean() else: return ssim_map.mean([1, 2, 3]) def forward(self, img1, img2): channel = img1.shape[1] if channel == self.channel and self.window.dtype == img1.dtype: window = self.window else: window = self.create_window(self.window_size, channel) self.window = window self.channel = channel return self._ssim(img1, img2, window, self.window_size, channel, self.size_average)