Files
2024-08-27 14:42:45 +08:00

100 lines
3.2 KiB
Python

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
__all__ = ["SparseRCNN"]


@register
class SparseRCNN(BaseArch):
    """Sparse R-CNN detector: backbone -> neck -> head (+ postprocess).

    Args:
        backbone: feature extractor, called on ``self.inputs``.
        neck: module called on the backbone features; its output is named
            ``fpn_feats`` in the forward pass (presumably an FPN — confirm
            against the config).
        head: detection head, called with the neck features and
            ``self.inputs["img_whwh"]``; also provides ``get_loss``.
            Default is the registered config name ``"SparsercnnHead"``.
        postprocess: injected by the workspace (see ``__inject__``); turns
            head outputs into final boxes at inference time. Default is the
            registered config name ``"SparsePostProcess"``.
    """
    __category__ = 'architecture'
    __inject__ = ["postprocess"]

    def __init__(self,
                 backbone,
                 neck,
                 head="SparsercnnHead",
                 postprocess="SparsePostProcess"):
        super(SparseRCNN, self).__init__()
        self.backbone = backbone
        self.neck = neck
        self.head = head
        self.postprocess = postprocess

    @classmethod
    def from_config(cls, cfg, *args, **kwargs):
        """Build backbone, neck and head from ``cfg``, chaining output shapes.

        The neck is created with the backbone's ``out_shape`` as
        ``input_shape``, and the head with the neck's ``out_shape`` as
        ``roi_input_shape``. Extra ``*args``/``**kwargs`` are accepted for
        interface compatibility but are not used.

        Returns:
            dict with keys ``'backbone'``, ``'neck'`` and ``'head'`` (these
            match ``__init__`` parameter names; presumably consumed by the
            workspace as constructor kwargs).
        """
        backbone = create(cfg['backbone'])

        # Distinct local names so the caller's ``kwargs`` is not shadowed.
        neck_kwargs = {'input_shape': backbone.out_shape}
        neck = create(cfg['neck'], **neck_kwargs)

        head_kwargs = {'roi_input_shape': neck.out_shape}
        head = create(cfg['head'], **head_kwargs)

        return {
            'backbone': backbone,
            'neck': neck,
            "head": head,
        }

    def _forward(self):
        """Run the detector on ``self.inputs``.

        Returns:
            In training mode, the raw head outputs. Otherwise a
            ``(bbox_pred, bbox_num)`` pair produced by ``self.postprocess``
            from the head's ``"pred_logits"`` / ``"pred_boxes"`` and the
            ``"scale_factor_whwh"`` / ``"ori_shape"`` inputs.
        """
        body_feats = self.backbone(self.inputs)
        fpn_feats = self.neck(body_feats)
        head_outs = self.head(fpn_feats, self.inputs["img_whwh"])

        if self.training:
            return head_outs

        bbox_pred, bbox_num = self.postprocess(
            head_outs["pred_logits"], head_outs["pred_boxes"],
            self.inputs["scale_factor_whwh"], self.inputs["ori_shape"])
        return bbox_pred, bbox_num

    def get_loss(self):
        """Compute training losses for the current batch.

        Builds one target dict per image from ``gt_class``, ``gt_bbox`` and
        ``img_whwh``, runs the forward pass and delegates to
        ``self.head.get_loss``.

        Returns:
            The head's loss dict with an added ``"loss"`` entry (the sum of
            every entry except ``"acc"``) followed by the original ``"acc"``
            entry.

        Raises:
            KeyError: if the head's loss dict has no ``"acc"`` entry.
        """
        batch_gt_class = self.inputs["gt_class"]
        batch_gt_box = self.inputs["gt_bbox"]
        batch_whwh = self.inputs["img_whwh"]

        targets = []
        # Index-based so the three per-image collections are read in
        # lockstep exactly as provided (they may be lists or tensors).
        for i in range(len(batch_gt_class)):
            boxes = batch_gt_box[i]
            labels = batch_gt_class[i].squeeze(-1)
            img_whwh = batch_whwh[i]
            # One copy of the image size per gt box (rows = boxes.shape[0]);
            # assumes img_whwh is 1-D, e.g. (w, h, w, h) — TODO confirm.
            img_whwh_tgt = img_whwh.unsqueeze(0).tile([int(boxes.shape[0]), 1])
            targets.append({
                "boxes": boxes,
                "labels": labels,
                "img_whwh": img_whwh,
                "img_whwh_tgt": img_whwh_tgt
            })

        outputs = self._forward()
        loss_dict = self.head.get_loss(outputs, targets)

        # "acc" is excluded from the total (presumably a metric, not a loss
        # term) and re-appended after "loss" so the key order is preserved.
        acc = loss_dict.pop("acc")
        total_loss = sum(loss_dict.values())
        loss_dict.update({"loss": total_loss, "acc": acc})
        return loss_dict

    def get_pred(self):
        """Return inference results as ``{'bbox': ..., 'bbox_num': ...}``.

        Expects eval mode (``self.training`` false); in training mode
        ``_forward`` returns raw head outputs instead of a pair.
        """
        bbox_pred, bbox_num = self._forward()
        output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
        return output