更换文档检测模型
This commit is contained in:
352
paddle_detection/ppdet/slim/distill_model.py
Normal file
352
paddle_detection/ppdet/slim/distill_model.py
Normal file
@@ -0,0 +1,352 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
from ppdet.core.workspace import register, create, load_config
|
||||
from ppdet.utils.checkpoint import load_pretrain_weight
|
||||
from ppdet.utils.logger import setup_logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
__all__ = [
|
||||
'DistillModel',
|
||||
'FGDDistillModel',
|
||||
'CWDDistillModel',
|
||||
'LDDistillModel',
|
||||
'PPYOLOEDistillModel',
|
||||
]
|
||||
|
||||
|
||||
@register
|
||||
class DistillModel(nn.Layer):
|
||||
"""
|
||||
Build common distill model.
|
||||
Args:
|
||||
cfg: The student config.
|
||||
slim_cfg: The teacher and distill config.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, slim_cfg):
|
||||
super(DistillModel, self).__init__()
|
||||
self.arch = cfg.architecture
|
||||
|
||||
self.stu_cfg = cfg
|
||||
self.student_model = create(self.stu_cfg.architecture)
|
||||
if 'pretrain_weights' in self.stu_cfg and self.stu_cfg.pretrain_weights:
|
||||
stu_pretrain = self.stu_cfg.pretrain_weights
|
||||
else:
|
||||
stu_pretrain = None
|
||||
|
||||
slim_cfg = load_config(slim_cfg)
|
||||
self.tea_cfg = slim_cfg
|
||||
self.teacher_model = create(self.tea_cfg.architecture)
|
||||
if 'pretrain_weights' in self.tea_cfg and self.tea_cfg.pretrain_weights:
|
||||
tea_pretrain = self.tea_cfg.pretrain_weights
|
||||
else:
|
||||
tea_pretrain = None
|
||||
self.distill_cfg = slim_cfg
|
||||
|
||||
# load pretrain weights
|
||||
self.is_inherit = False
|
||||
if stu_pretrain:
|
||||
if self.is_inherit and tea_pretrain:
|
||||
load_pretrain_weight(self.student_model, tea_pretrain)
|
||||
logger.debug(
|
||||
"Inheriting! loading teacher weights to student model!")
|
||||
load_pretrain_weight(self.student_model, stu_pretrain)
|
||||
logger.info("Student model has loaded pretrain weights!")
|
||||
if tea_pretrain:
|
||||
load_pretrain_weight(self.teacher_model, tea_pretrain)
|
||||
logger.info("Teacher model has loaded pretrain weights!")
|
||||
|
||||
self.teacher_model.eval()
|
||||
for param in self.teacher_model.parameters():
|
||||
param.trainable = False
|
||||
|
||||
self.distill_loss = self.build_loss(self.distill_cfg)
|
||||
|
||||
def build_loss(self, distill_cfg):
|
||||
if 'distill_loss' in distill_cfg and distill_cfg.distill_loss:
|
||||
return create(distill_cfg.distill_loss)
|
||||
else:
|
||||
return None
|
||||
|
||||
def parameters(self):
|
||||
return self.student_model.parameters()
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.training:
|
||||
student_loss = self.student_model(inputs)
|
||||
with paddle.no_grad():
|
||||
teacher_loss = self.teacher_model(inputs)
|
||||
|
||||
loss = self.distill_loss(self.teacher_model, self.student_model)
|
||||
student_loss['distill_loss'] = loss
|
||||
student_loss['teacher_loss'] = teacher_loss['loss']
|
||||
student_loss['loss'] += student_loss['distill_loss']
|
||||
return student_loss
|
||||
else:
|
||||
return self.student_model(inputs)
|
||||
|
||||
|
||||
@register
|
||||
class FGDDistillModel(DistillModel):
|
||||
"""
|
||||
Build FGD distill model.
|
||||
Args:
|
||||
cfg: The student config.
|
||||
slim_cfg: The teacher and distill config.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, slim_cfg):
|
||||
super(FGDDistillModel, self).__init__(cfg=cfg, slim_cfg=slim_cfg)
|
||||
assert self.arch in ['RetinaNet', 'PicoDet'
|
||||
], 'Unsupported arch: {}'.format(self.arch)
|
||||
self.is_inherit = True
|
||||
|
||||
def build_loss(self, distill_cfg):
|
||||
assert 'distill_loss_name' in distill_cfg and distill_cfg.distill_loss_name
|
||||
assert 'distill_loss' in distill_cfg and distill_cfg.distill_loss
|
||||
loss_func = dict()
|
||||
name_list = distill_cfg.distill_loss_name
|
||||
for name in name_list:
|
||||
loss_func[name] = create(distill_cfg.distill_loss)
|
||||
return loss_func
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.training:
|
||||
s_body_feats = self.student_model.backbone(inputs)
|
||||
s_neck_feats = self.student_model.neck(s_body_feats)
|
||||
with paddle.no_grad():
|
||||
t_body_feats = self.teacher_model.backbone(inputs)
|
||||
t_neck_feats = self.teacher_model.neck(t_body_feats)
|
||||
|
||||
loss_dict = {}
|
||||
for idx, k in enumerate(self.distill_loss):
|
||||
loss_dict[k] = self.distill_loss[k](s_neck_feats[idx],
|
||||
t_neck_feats[idx], inputs)
|
||||
if self.arch == "RetinaNet":
|
||||
loss = self.student_model.head(s_neck_feats, inputs)
|
||||
elif self.arch == "PicoDet":
|
||||
head_outs = self.student_model.head(
|
||||
s_neck_feats, self.student_model.export_post_process)
|
||||
loss_gfl = self.student_model.head.get_loss(head_outs, inputs)
|
||||
total_loss = paddle.add_n(list(loss_gfl.values()))
|
||||
loss = {}
|
||||
loss.update(loss_gfl)
|
||||
loss.update({'loss': total_loss})
|
||||
else:
|
||||
raise ValueError(f"Unsupported model {self.arch}")
|
||||
|
||||
for k in loss_dict:
|
||||
loss['loss'] += loss_dict[k]
|
||||
loss[k] = loss_dict[k]
|
||||
return loss
|
||||
else:
|
||||
body_feats = self.student_model.backbone(inputs)
|
||||
neck_feats = self.student_model.neck(body_feats)
|
||||
head_outs = self.student_model.head(neck_feats)
|
||||
if self.arch == "RetinaNet":
|
||||
bbox, bbox_num = self.student_model.head.post_process(
|
||||
head_outs, inputs['im_shape'], inputs['scale_factor'])
|
||||
return {'bbox': bbox, 'bbox_num': bbox_num}
|
||||
elif self.arch == "PicoDet":
|
||||
head_outs = self.student_model.head(
|
||||
neck_feats, self.student_model.export_post_process)
|
||||
scale_factor = inputs['scale_factor']
|
||||
bboxes, bbox_num = self.student_model.head.post_process(
|
||||
head_outs,
|
||||
scale_factor,
|
||||
export_nms=self.student_model.export_nms)
|
||||
return {'bbox': bboxes, 'bbox_num': bbox_num}
|
||||
else:
|
||||
raise ValueError(f"Unsupported model {self.arch}")
|
||||
|
||||
|
||||
@register
|
||||
class CWDDistillModel(DistillModel):
|
||||
"""
|
||||
Build CWD distill model.
|
||||
Args:
|
||||
cfg: The student config.
|
||||
slim_cfg: The teacher and distill config.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, slim_cfg):
|
||||
super(CWDDistillModel, self).__init__(cfg=cfg, slim_cfg=slim_cfg)
|
||||
assert self.arch in ['GFL', 'RetinaNet'], 'Unsupported arch: {}'.format(
|
||||
self.arch)
|
||||
|
||||
def build_loss(self, distill_cfg):
|
||||
assert 'distill_loss_name' in distill_cfg and distill_cfg.distill_loss_name
|
||||
assert 'distill_loss' in distill_cfg and distill_cfg.distill_loss
|
||||
loss_func = dict()
|
||||
name_list = distill_cfg.distill_loss_name
|
||||
for name in name_list:
|
||||
loss_func[name] = create(distill_cfg.distill_loss)
|
||||
return loss_func
|
||||
|
||||
def get_loss_retinanet(self, stu_fea_list, tea_fea_list, inputs):
|
||||
loss = self.student_model.head(stu_fea_list, inputs)
|
||||
loss_dict = {}
|
||||
for idx, k in enumerate(self.distill_loss):
|
||||
loss_dict[k] = self.distill_loss[k](stu_fea_list[idx],
|
||||
tea_fea_list[idx])
|
||||
|
||||
loss['loss'] += loss_dict[k]
|
||||
loss[k] = loss_dict[k]
|
||||
return loss
|
||||
|
||||
def get_loss_gfl(self, stu_fea_list, tea_fea_list, inputs):
|
||||
loss = {}
|
||||
head_outs = self.student_model.head(stu_fea_list)
|
||||
loss_gfl = self.student_model.head.get_loss(head_outs, inputs)
|
||||
loss.update(loss_gfl)
|
||||
total_loss = paddle.add_n(list(loss.values()))
|
||||
loss.update({'loss': total_loss})
|
||||
|
||||
feat_loss = {}
|
||||
loss_dict = {}
|
||||
s_cls_feat, t_cls_feat = [], []
|
||||
for s_neck_f, t_neck_f in zip(stu_fea_list, tea_fea_list):
|
||||
conv_cls_feat, _ = self.student_model.head.conv_feat(s_neck_f)
|
||||
cls_score = self.student_model.head.gfl_head_cls(conv_cls_feat)
|
||||
t_conv_cls_feat, _ = self.teacher_model.head.conv_feat(t_neck_f)
|
||||
t_cls_score = self.teacher_model.head.gfl_head_cls(t_conv_cls_feat)
|
||||
s_cls_feat.append(cls_score)
|
||||
t_cls_feat.append(t_cls_score)
|
||||
|
||||
for idx, k in enumerate(self.distill_loss):
|
||||
loss_dict[k] = self.distill_loss[k](s_cls_feat[idx],
|
||||
t_cls_feat[idx])
|
||||
feat_loss[f"neck_f_{idx}"] = self.distill_loss[k](stu_fea_list[idx],
|
||||
tea_fea_list[idx])
|
||||
|
||||
for k in feat_loss:
|
||||
loss['loss'] += feat_loss[k]
|
||||
loss[k] = feat_loss[k]
|
||||
|
||||
for k in loss_dict:
|
||||
loss['loss'] += loss_dict[k]
|
||||
loss[k] = loss_dict[k]
|
||||
return loss
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.training:
|
||||
s_body_feats = self.student_model.backbone(inputs)
|
||||
s_neck_feats = self.student_model.neck(s_body_feats)
|
||||
with paddle.no_grad():
|
||||
t_body_feats = self.teacher_model.backbone(inputs)
|
||||
t_neck_feats = self.teacher_model.neck(t_body_feats)
|
||||
|
||||
if self.arch == "RetinaNet":
|
||||
loss = self.get_loss_retinanet(s_neck_feats, t_neck_feats,
|
||||
inputs)
|
||||
elif self.arch == "GFL":
|
||||
loss = self.get_loss_gfl(s_neck_feats, t_neck_feats, inputs)
|
||||
else:
|
||||
raise ValueError(f"unsupported arch {self.arch}")
|
||||
return loss
|
||||
else:
|
||||
body_feats = self.student_model.backbone(inputs)
|
||||
neck_feats = self.student_model.neck(body_feats)
|
||||
head_outs = self.student_model.head(neck_feats)
|
||||
if self.arch == "RetinaNet":
|
||||
bbox, bbox_num = self.student_model.head.post_process(
|
||||
head_outs, inputs['im_shape'], inputs['scale_factor'])
|
||||
return {'bbox': bbox, 'bbox_num': bbox_num}
|
||||
elif self.arch == "GFL":
|
||||
bbox_pred, bbox_num = head_outs
|
||||
output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
|
||||
return output
|
||||
else:
|
||||
raise ValueError(f"unsupported arch {self.arch}")
|
||||
|
||||
|
||||
@register
|
||||
class LDDistillModel(DistillModel):
|
||||
"""
|
||||
Build LD distill model.
|
||||
Args:
|
||||
cfg: The student config.
|
||||
slim_cfg: The teacher and distill config.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, slim_cfg):
|
||||
super(LDDistillModel, self).__init__(cfg=cfg, slim_cfg=slim_cfg)
|
||||
assert self.arch in ['GFL'], 'Unsupported arch: {}'.format(self.arch)
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.training:
|
||||
s_body_feats = self.student_model.backbone(inputs)
|
||||
s_neck_feats = self.student_model.neck(s_body_feats)
|
||||
s_head_outs = self.student_model.head(s_neck_feats)
|
||||
with paddle.no_grad():
|
||||
t_body_feats = self.teacher_model.backbone(inputs)
|
||||
t_neck_feats = self.teacher_model.neck(t_body_feats)
|
||||
t_head_outs = self.teacher_model.head(t_neck_feats)
|
||||
|
||||
soft_label_list = t_head_outs[0]
|
||||
soft_targets_list = t_head_outs[1]
|
||||
student_loss = self.student_model.head.get_loss(
|
||||
s_head_outs, inputs, soft_label_list, soft_targets_list)
|
||||
total_loss = paddle.add_n(list(student_loss.values()))
|
||||
student_loss['loss'] = total_loss
|
||||
return student_loss
|
||||
else:
|
||||
return self.student_model(inputs)
|
||||
|
||||
|
||||
@register
|
||||
class PPYOLOEDistillModel(DistillModel):
|
||||
"""
|
||||
Build PPYOLOE distill model, only used in PPYOLOE
|
||||
Args:
|
||||
cfg: The student config.
|
||||
slim_cfg: The teacher and distill config.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, slim_cfg):
|
||||
super(PPYOLOEDistillModel, self).__init__(cfg=cfg, slim_cfg=slim_cfg)
|
||||
assert self.arch in ['PPYOLOE'], 'Unsupported arch: {}'.format(
|
||||
self.arch)
|
||||
|
||||
def forward(self, inputs, alpha=0.125):
|
||||
if self.training:
|
||||
with paddle.no_grad():
|
||||
teacher_loss = self.teacher_model(inputs)
|
||||
if hasattr(self.teacher_model.yolo_head, "assigned_labels"):
|
||||
self.student_model.yolo_head.assigned_labels, self.student_model.yolo_head.assigned_bboxes, self.student_model.yolo_head.assigned_scores = \
|
||||
self.teacher_model.yolo_head.assigned_labels, self.teacher_model.yolo_head.assigned_bboxes, self.teacher_model.yolo_head.assigned_scores
|
||||
delattr(self.teacher_model.yolo_head, "assigned_labels")
|
||||
delattr(self.teacher_model.yolo_head, "assigned_bboxes")
|
||||
delattr(self.teacher_model.yolo_head, "assigned_scores")
|
||||
student_loss = self.student_model(inputs)
|
||||
|
||||
logits_loss, feat_loss = self.distill_loss(self.teacher_model,
|
||||
self.student_model)
|
||||
det_total_loss = student_loss['loss']
|
||||
total_loss = alpha * (det_total_loss + logits_loss + feat_loss)
|
||||
student_loss['loss'] = total_loss
|
||||
student_loss['det_loss'] = det_total_loss
|
||||
student_loss['logits_loss'] = logits_loss
|
||||
student_loss['feat_loss'] = feat_loss
|
||||
return student_loss
|
||||
else:
|
||||
return self.student_model(inputs)
|
||||
Reference in New Issue
Block a user