251 lines
7.8 KiB
Python
251 lines
7.8 KiB
Python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from itertools import cycle, islice
|
|
from collections import abc
|
|
import paddle
|
|
import paddle.nn as nn
|
|
import paddle.nn.functional as F
|
|
|
|
from ppdet.core.workspace import register, serializable
|
|
from ppdet.utils.logger import setup_logger
|
|
logger = setup_logger('ppdet.engine')
|
|
|
|
__all__ = ['Pose3DLoss']
|
|
|
|
|
|
@register
@serializable
class Pose3DLoss(nn.Layer):
    def __init__(self, weight_3d=1.0, weight_2d=0.0, reduction='none'):
        """
        Pose3DLoss layer: weighted sum of a 3d joint loss and an optional
        2d joint loss

        Args:
            weight_3d (float): weight of 3d loss
            weight_2d (float): base weight of 2d loss, the 2d loss is only
                evaluated when this value is greater than 0
            reduction (str): reduction mode forwarded to the MSE, L1 and
                SmoothL1 criteria created here
        """
        super(Pose3DLoss, self).__init__()
        self.weight_3d = weight_3d
        self.weight_2d = weight_2d
        self.criterion_2dpose = nn.MSELoss(reduction=reduction)
        # The three criteria below are not used by forward(); they are kept
        # as attributes so code that reads them from the layer keeps working.
        self.criterion_3dpose = nn.L1Loss(reduction=reduction)
        self.criterion_smoothl1 = nn.SmoothL1Loss(
            reduction=reduction, delta=1.0)
        self.criterion_vertices = nn.L1Loss()

    def forward(self, pred3d, pred2d, inputs):
        """
        Compute weight_3d * loss_3d, plus the scheduled 2d term when
        weight_2d > 0.

        loss_3d: computed by mpjpe_focal between predicted and gt 3d joints
        loss_2d: 2d joints loss computed by keypoint_2d_loss with
            criterion_2dpose

        Args:
            pred3d: predicted 3d joints
            pred2d: predicted 2d joints
            inputs (dict): must provide 'joints_3d', 'joints_2d',
                'has_3d_joints' and 'has_2d_joints'; 'epoch_id' is only
                required when weight_2d > 0

        Returns:
            the combined loss
        """
        gt_3d_joints = inputs['joints_3d']
        gt_2d_joints = inputs['joints_2d']
        has_3d_joints = inputs['has_3d_joints']
        has_2d_joints = inputs['has_2d_joints']

        loss_3d = mpjpe_focal(pred3d, gt_3d_joints, has_3d_joints)
        loss = self.weight_3d * loss_3d
        if self.weight_2d > 0:
            # Only the 2d branch depends on the epoch, so the key is looked
            # up here; a 3d-only configuration does not need 'epoch_id'.
            epoch = inputs['epoch_id']
            # For non-negative integer epochs: full weight for epochs 0-7,
            # one tenth at epoch 8, zero afterwards.
            weight = self.weight_2d * pow(0.1, (epoch // 8))
            if epoch > 8:
                weight = 0
            # NOTE(review): the 2d term is still evaluated and added with a
            # zero weight after epoch 8. Kept as-is because skipping it would
            # change which tensors take part in the loss graph - confirm
            # before short-circuiting.
            loss_2d = keypoint_2d_loss(self.criterion_2dpose, pred2d,
                                       gt_2d_joints, has_2d_joints)
            loss += weight * loss_2d
        return loss
|
|
|
|
|
|
def filter_3d_joints(pred, gt, has_3d_joints):
    """
    Keep the samples whose has_3d_joints flag equals 1 and center both
    prediction and ground truth on their own root point.

    The root point is the midpoint of the joints at index 2 and 3 (called
    "pelvis" in this module). Only the first three channels of gt are kept.

    Args:
        pred: predicted joints, indexed as [sample, joint, coordinate]
        gt: ground-truth joints, indexed as [sample, joint, channel]
        has_3d_joints: per-sample flags, samples with value 1 are kept

    Returns:
        tuple (pred, gt) of the filtered, root-centered joints
    """
    keep = has_3d_joints == 1
    gt_xyz = gt[keep][:, :, :3]
    pred_kept = pred[keep]

    gt_root = (gt_xyz[:, 2, :] + gt_xyz[:, 3, :]) / 2
    pred_root = (pred_kept[:, 2, :] + pred_kept[:, 3, :]) / 2

    # [:, None, :] re-inserts the joint axis so the root broadcasts over
    # every joint of its sample.
    centered_gt = gt_xyz - gt_root[:, None, :]
    centered_pred = pred_kept - pred_root[:, None, :]
    return centered_pred, centered_gt
|
|
|
|
|
|
def mpjpe(pred, gt, has_3d_joints):
    """
    mPJPE loss: mean Euclidean distance between root-centered predicted and
    ground-truth joints, over the samples kept by filter_3d_joints.

    Args:
        pred: predicted 3d joints
        gt: ground-truth 3d joints
        has_3d_joints: per-sample flags, samples with value 1 are used

    Returns:
        the mean per-joint error
    """
    pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
    # NOTE(review): the signed difference is only capped from above at 1.2;
    # large negative differences are not limited - confirm this is intended.
    capped_diff = paddle.minimum((pred - gt), paddle.to_tensor(1.2))
    joint_dist = paddle.sqrt((capped_diff**2).sum(axis=-1))
    return joint_dist.mean()
|
|
|
|
|
|
def mpjpe_focal(pred, gt, has_3d_joints):
    """
    Focal-style joint loss: per-joint squared distance, re-weighted so that
    joints with a larger than average error contribute more.

    The weight of a joint is 2 * sigmoid(6 * (d - mean(d)) / std(d)), where d
    is the per-joint Euclidean distance over the filtered batch.

    Args:
        pred: predicted 3d joints
        gt: ground-truth 3d joints
        has_3d_joints: per-sample flags, samples with value 1 are used

    Returns:
        the mean of the re-weighted squared distances
    """
    pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
    sq_dist = ((pred - gt)**2).sum(axis=-1)
    dist = paddle.sqrt(sq_dist)
    center = dist.mean()
    spread = dist.std()
    # NOTE(review): spread is not guarded against being zero (e.g. all
    # distances equal) - confirm callers never hit that case.
    focal_weight = 2 * F.sigmoid(6 * (dist - center) / spread)
    sq_dist *= focal_weight
    return sq_dist.mean()
|
|
|
|
|
|
def mpjpe_mse(pred, gt, has_3d_joints, weight=1.):
    """
    Mean squared per-joint distance between root-centered predicted and
    ground-truth joints, over the samples kept by filter_3d_joints.

    Args:
        pred: predicted 3d joints
        gt: ground-truth 3d joints
        has_3d_joints: per-sample flags, samples with value 1 are used
        weight (float): accepted but not used by this function

    Returns:
        the mean squared per-joint distance
    """
    pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
    sq_dist = ((pred - gt)**2).sum(axis=-1)
    return sq_dist.mean()
|
|
|
|
|
|
def mpjpe_criterion(pred, gt, has_3d_joints, criterion_pose3d):
    """
    Joint loss with a user supplied criterion: mean of the square root of
    criterion_pose3d applied to the root-centered joints.

    Args:
        pred: predicted 3d joints
        gt: ground-truth 3d joints
        has_3d_joints: per-sample flags, samples with value 1 are used
        criterion_pose3d: callable taking (pred, gt) and returning a tensor

    Returns:
        the mean of sqrt(criterion_pose3d(pred, gt))
    """
    pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
    criterion_out = criterion_pose3d(pred, gt)
    return paddle.sqrt(criterion_out).mean()
|
|
|
|
|
|
@register
@serializable
def weighted_mpjpe(pred, gt, has_3d_joints):
    """
    Weighted_mPJPE: mean per-joint Euclidean distance where every joint is
    scaled by a fixed weight.

    The weight table has 14 entries and is multiplied against the per-joint
    distances, so the joint axis must be broadcastable against 14 values.

    Args:
        pred: predicted 3d joints
        gt: ground-truth 3d joints
        has_3d_joints: per-sample flags, samples with value 1 are used

    Returns:
        the weighted mean per-joint error
    """
    pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
    # Fixed per-joint weights. The norm of pred that used to be computed
    # here was overwritten before use, so it is no longer evaluated.
    weight = paddle.to_tensor(
        [1.5, 1.3, 1.2, 1.2, 1.3, 1.5, 1.5, 1.3, 1.2, 1.2, 1.3, 1.5, 1., 1.])
    error = (weight * paddle.linalg.norm(pred - gt, p=2, axis=-1)).mean()
    return error
|
|
|
|
|
|
@register
@serializable
def normed_mpjpe(pred, gt, has_3d_joints):
    """
    Normalized MPJPE (scale only), adapted from:
    https://github.com/hrhodin/UnsupervisedGeometryAwareRepresentationLearning/blob/master/losses/poses.py

    The prediction of every sample is rescaled by
    mean(sum(gt * pred)) / mean(sum(pred ** 2)) before mpjpe is evaluated.

    Args:
        pred: predicted 3d joints
        gt: ground-truth 3d joints
        has_3d_joints: per-sample flags, samples with value 1 are used

    Returns:
        mpjpe of the rescaled prediction against the ground truth
    """
    pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
    # Checked after filtering: filter_3d_joints cuts gt to its first three
    # channels, so this is the point where both shapes have to agree.
    assert pred.shape == gt.shape

    # Negative axes address the coordinate axis (-1) and the joint axis (-2)
    # of the [sample, joint, coordinate] tensors returned by the filter.
    norm_predicted = paddle.mean(
        paddle.sum(pred**2, axis=-1, keepdim=True), axis=-2, keepdim=True)
    norm_target = paddle.mean(
        paddle.sum(gt * pred, axis=-1, keepdim=True), axis=-2, keepdim=True)
    scale = norm_target / norm_predicted

    # mpjpe filters its inputs again, so hand it the flags of the samples
    # that were kept (all equal to 1). Its re-centering leaves the already
    # centered joints unchanged up to rounding.
    kept_flags = has_3d_joints[has_3d_joints == 1]
    return mpjpe(scale * pred, gt, kept_flags)
|
|
|
|
|
|
@register
@serializable
def mpjpe_np(pred, gt, has_3d_joints):
    """
    mPJPE_NP: mean per-joint Euclidean distance evaluated with numpy.sqrt.

    Args:
        pred: predicted 3d joints
        gt: ground-truth 3d joints
        has_3d_joints: per-sample flags, samples with value 1 are used

    Returns:
        the mean per-joint error
    """
    # numpy is not imported at module level, so bring it into scope here.
    import numpy as np

    pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
    error = np.sqrt(((pred - gt)**2).sum(axis=-1)).mean()
    return error
|
|
|
|
|
|
@register
@serializable
def mean_per_vertex_error(pred, gt, has_smpl):
    """
    Compute mPVE: mean Euclidean distance between predicted and ground-truth
    vertices of the samples whose has_smpl flag equals 1.

    The distance is evaluated under paddle.no_grad().

    Args:
        pred: predicted vertices
        gt: ground-truth vertices
        has_smpl: per-sample flags, samples with value 1 are used

    Returns:
        the mean per-vertex error
    """
    keep = has_smpl == 1
    pred_kept = pred[keep]
    gt_kept = gt[keep]
    with paddle.no_grad():
        vertex_dist = paddle.sqrt(((pred_kept - gt_kept)**2).sum(axis=-1))
        error = vertex_dist.mean()
    return error
|
|
|
|
|
|
@register
@serializable
def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d,
                     has_pose_2d):
    """
    Compute 2D reprojection loss if 2D keypoint annotations are available.
    The confidence (conf) is binary and indicates whether the keypoints exist or not.

    Args:
        criterion_keypoints: callable taking (pred, target) and returning a
            tensor that can be multiplied by the confidence
        pred_keypoints_2d: predicted 2d keypoints
        gt_keypoints_2d: ground-truth keypoints, the last channel is used as
            confidence and the remaining channels as coordinates
        has_pose_2d: accepted but not used by this function

    Returns:
        the mean of the confidence-weighted criterion output
    """
    conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
    # NOTE(review): the 0.001 factor presumably converts the ground-truth
    # coordinates to the unit of the prediction - confirm.
    scaled_target = gt_keypoints_2d[:, :, :-1] * 0.001
    weighted = conf * criterion_keypoints(pred_keypoints_2d, scaled_target)
    return weighted.mean()
|
|
|
|
|
|
@register
@serializable
def keypoint_3d_loss(criterion_keypoints, pred_keypoints_3d, gt_keypoints_3d,
                     has_pose_3d):
    """
    Compute 3D keypoint loss if 3D keypoint annotations are available.

    Samples whose has_pose_3d flag equals 1 are kept, both sides are centered
    on the midpoint of the joints at index 2 and 3, and the criterion output
    is weighted by the confidence stored in the last channel of the ground
    truth.

    Args:
        criterion_keypoints: callable taking (pred, target) and returning a
            tensor that can be multiplied by the confidence
        pred_keypoints_3d: predicted 3d keypoints
        gt_keypoints_3d: ground-truth keypoints, the last channel is used as
            confidence and the remaining channels as coordinates
        has_pose_3d: per-sample flags, samples with value 1 are used

    Returns:
        the mean confidence-weighted loss, or a one element zero tensor when
        no sample is kept
    """
    keep = has_pose_3d == 1
    conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()[keep]
    gt_kept = gt_keypoints_3d[:, :, :-1].clone()[keep]
    pred_kept = pred_keypoints_3d[keep]

    # Nothing annotated in this batch: return a zero loss.
    if len(gt_kept) == 0:
        return paddle.to_tensor([1.]).fill_(0.)

    gt_root = (gt_kept[:, 2, :] + gt_kept[:, 3, :]) / 2
    gt_kept = gt_kept - gt_root[:, None, :]
    pred_root = (pred_kept[:, 2, :] + pred_kept[:, 3, :]) / 2
    pred_kept = pred_kept - pred_root[:, None, :]
    weighted = conf * criterion_keypoints(pred_kept, gt_kept)
    return weighted.mean()
|
|
|
|
|
|
@register
@serializable
def vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl):
    """
    Compute per-vertex loss if vertex annotations are available.

    Args:
        criterion_vertices: callable taking (pred, target) and returning the
            loss
        pred_vertices: predicted vertices
        gt_vertices: ground-truth vertices
        has_smpl: per-sample flags, samples with value 1 are used

    Returns:
        the criterion output on the kept samples, or a one element zero
        tensor when no sample is kept
    """
    keep = has_smpl == 1
    pred_kept = pred_vertices[keep]
    gt_kept = gt_vertices[keep]

    # Nothing annotated in this batch: return a zero loss.
    if len(gt_kept) == 0:
        return paddle.to_tensor([1.]).fill_(0.)
    return criterion_vertices(pred_kept, gt_kept)
|
|
|
|
|
|
@register
@serializable
def rectify_pose(pose):
    """
    Return a copy of pose whose first three values (the root rotation in
    axis-angle form) are post-multiplied by a rotation of pi around the
    first axis.

    NOTE(review): this function relies on `np` and `cv2`, neither of which
    is imported in the visible part of this module - confirm they are in
    scope before calling it.

    Args:
        pose: array-like supporting .copy() and slice assignment, the first
            three entries are read as an axis-angle rotation

    Returns:
        the rectified copy, the input is left untouched
    """
    rectified = pose.copy()
    half_turn = cv2.Rodrigues(np.array([np.pi, 0, 0]))[0]
    root_rotation = cv2.Rodrigues(rectified[:3])[0]
    combined = root_rotation.dot(half_turn)
    # Convert the combined rotation matrix back to a flat axis-angle vector.
    rectified[:3] = cv2.Rodrigues(combined)[0].reshape(3)
    return rectified
|