# Files
# 2024-08-27 14:42:45 +08:00
# 279 lines, 9.7 KiB, Python
# (page-listing metadata from the source mirror, commented out so the module parses)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ppdet.modeling.keypoint_utils import resize, flip_back
from paddle.nn.initializer import TruncatedNormal, Constant, Normal
from ppdet.modeling.layers import ConvTranspose2d, BatchNorm2d
# Reusable weight-initializer callables for init_weights() below:
# truncated normal (std=0.02) — unused in this chunk but kept for API parity,
# normal (std=0.001) for conv weights, and constants for BN/bias params.
trunc_normal_ = TruncatedNormal(std=.02)
normal_ = Normal(std=0.001)
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)
# Public API of this module.
__all__ = ['TopdownHeatmapSimpleHead']
@register
class TopdownHeatmapSimpleHead(nn.Layer):
    """Top-down heatmap head in the SimpleBaseline style.

    Takes backbone features, upsamples them with a stack of deconv
    (transposed-conv) layers, and predicts one heatmap per keypoint with a
    final conv layer (paper ref: Simple Baselines for Human Pose Estimation
    and Tracking).

    Args:
        in_channels (int | list | tuple): Channels of the input feature
            map(s). A list/tuple is required when ``input_transform`` is set.
        out_channels (int): Number of output channels (keypoints).
        num_deconv_layers (int): Number of deconv layers; 0 means no
            upsampling (identity).
        num_deconv_filters (tuple[int]): Output channels of each deconv.
        num_deconv_kernels (tuple[int]): Kernel size of each deconv
            (each must be 2, 3 or 4).
        extra (dict | None): Optional config: ``final_conv_kernel`` in
            {0, 1, 3} (0 = identity final layer), ``num_conv_layers`` and
            ``num_conv_kernels`` for extra convs before the final layer.
        in_index (int | list | tuple): Index/indices of the backbone output(s)
            to use.
        input_transform (str | None): None, 'resize_concat' or
            'multiple_select'; see ``_init_inputs``.
        align_corners (bool): ``align_corners`` for all resize ops.
        upsample (int | float): If > 0 and the input is a single tensor,
            bilinearly upsample it by this factor first.
        flip_pairs (list[tuple] | None): Mirrored keypoint index pairs, used
            for flip-test in ``inference_model``.
        shift_heatmap (bool): Shift the flipped heatmap one pixel right to
            compensate for feature misalignment.
        target_type (str): Heatmap target type, e.g. 'GaussianHeatmap'.
    """

    def __init__(self,
                 in_channels=768,
                 out_channels=17,
                 num_deconv_layers=3,
                 num_deconv_filters=(256, 256, 256),
                 num_deconv_kernels=(4, 4, 4),
                 extra=None,
                 in_index=0,
                 input_transform=None,
                 align_corners=False,
                 upsample=0,
                 flip_pairs=None,
                 shift_heatmap=False,
                 target_type='GaussianHeatmap'):
        super(TopdownHeatmapSimpleHead, self).__init__()

        self.in_channels = in_channels
        self.upsample = upsample
        self.flip_pairs = flip_pairs
        self.shift_heatmap = shift_heatmap
        self.target_type = target_type

        # Validates in_channels/in_index against input_transform and may
        # rewrite self.in_channels (e.g. sum for 'resize_concat').
        self._init_inputs(in_channels, in_index, input_transform)
        self.in_index = in_index
        self.align_corners = align_corners

        if extra is not None and not isinstance(extra, dict):
            raise TypeError('extra should be dict or None.')

        if num_deconv_layers > 0:
            self.deconv_layers = self._make_deconv_layer(
                num_deconv_layers,
                num_deconv_filters,
                num_deconv_kernels, )
        elif num_deconv_layers == 0:
            self.deconv_layers = nn.Identity()
        else:
            raise ValueError(
                f'num_deconv_layers ({num_deconv_layers}) should >= 0.')

        identity_final_layer = False
        if extra is not None and 'final_conv_kernel' in extra:
            assert extra['final_conv_kernel'] in [0, 1, 3]
            if extra['final_conv_kernel'] == 3:
                padding = 1
            elif extra['final_conv_kernel'] == 1:
                padding = 0
            else:
                # 0 selects an identity mapping instead of a final conv.
                identity_final_layer = True
            kernel_size = extra['final_conv_kernel']
        else:
            kernel_size = 1
            padding = 0

        if identity_final_layer:
            self.final_layer = nn.Identity()
        else:
            # The final conv consumes the last deconv's channels, or the raw
            # input channels when there is no deconv stack.
            conv_channels = num_deconv_filters[
                -1] if num_deconv_layers > 0 else self.in_channels
            layers = []
            if extra is not None:
                num_conv_layers = extra.get('num_conv_layers', 0)
                num_conv_kernels = extra.get('num_conv_kernels',
                                             [1] * num_conv_layers)
                # Optional channel-preserving conv+BN+ReLU refinement stack
                # inserted before the prediction conv.
                for i in range(num_conv_layers):
                    layers.append(
                        nn.Conv2D(
                            in_channels=conv_channels,
                            out_channels=conv_channels,
                            kernel_size=num_conv_kernels[i],
                            stride=1,
                            padding=(num_conv_kernels[i] - 1) // 2))
                    layers.append(nn.BatchNorm2D(conv_channels))
                    layers.append(nn.ReLU())
            layers.append(
                nn.Conv2D(
                    in_channels=conv_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=(padding, padding)))
            if len(layers) > 1:
                self.final_layer = nn.Sequential(*layers)
            else:
                self.final_layer = layers[0]

        self.init_weights()

    @staticmethod
    def _get_deconv_cfg(deconv_kernel):
        """Return (kernel, padding, output_padding) for a deconv layer.

        These combinations exactly double the spatial size at stride 2.

        Raises:
            ValueError: If ``deconv_kernel`` is not 2, 3 or 4.
        """
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')
        return deconv_kernel, padding, output_padding

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        With ``input_transform`` set, ``in_channels`` and ``in_index`` must be
        equal-length lists/tuples; 'resize_concat' additionally collapses
        ``self.in_channels`` to their sum (the features are concatenated on
        the channel axis in ``_transform_inputs``). Without a transform both
        must be plain ints.
        """
        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def _transform_inputs(self, inputs):
        """Transform inputs for the decoder.

        A single (non-list) tensor is optionally ReLU'd and upsampled by
        ``self.upsample``; a list of tensors is combined according to
        ``self.input_transform``.
        """
        if not isinstance(inputs, list):
            # Fix: the original had this isinstance check duplicated
            # (nested twice); one check suffices.
            if self.upsample > 0:
                inputs = resize(
                    input=F.relu(inputs),
                    scale_factor=self.upsample,
                    mode='bilinear',
                    align_corners=self.align_corners)
            return inputs

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            # Fix: paddle.concat takes `axis`, not torch's `dim`; the
            # original `dim=1` raised TypeError on this code path.
            inputs = paddle.concat(upsampled_inputs, axis=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]
        return inputs

    def forward(self, x):
        """Forward function: transform inputs, deconv-upsample, predict."""
        x = self._transform_inputs(x)
        x = self.deconv_layers(x)
        x = self.final_layer(x)
        return x

    def inference_model(self, x, flip_pairs=None):
        """Inference function with optional flip-test.

        Args:
            x (paddle.Tensor[N,C,H,W]): Input features.
            flip_pairs (None | list[tuple]): Pairs of keypoints which are
                mirrored. Non-None triggers the flip-back path.

        Returns:
            Output heatmaps (flipped back when ``flip_pairs`` is given).
        """
        output = self.forward(x)
        if flip_pairs is not None:
            # NOTE(review): the argument is only used as an on/off switch;
            # the actual pairs come from self.flip_pairs. Confirm callers
            # always pass the same pairs as the constructor.
            output_heatmap = flip_back(
                output, self.flip_pairs, target_type=self.target_type)
            # Feature is not aligned; shift the flipped heatmap one pixel
            # right for higher accuracy.
            if self.shift_heatmap:
                output_heatmap[:, :, :, 1:] = output_heatmap[:, :, :, :-1]
        else:
            output_heatmap = output
        return output_heatmap

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Build the deconv upsampling stack (each stage doubles H and W).

        Raises:
            ValueError: If ``num_filters`` or ``num_kernels`` length differs
                from ``num_layers``.
        """
        if num_layers != len(num_filters):
            error_msg = f'num_layers({num_layers}) ' \
                        f'!= length of num_filters({len(num_filters)})'
            raise ValueError(error_msg)
        if num_layers != len(num_kernels):
            error_msg = f'num_layers({num_layers}) ' \
                        f'!= length of num_kernels({len(num_kernels)})'
            raise ValueError(error_msg)

        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i])
            planes = num_filters[i]
            layers.append(
                ConvTranspose2d(
                    in_channels=self.in_channels,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=False))
            layers.append(nn.BatchNorm2D(planes))
            layers.append(nn.ReLU())
            # Next stage consumes this stage's output channels.
            self.in_channels = planes
        return nn.Sequential(*layers)

    def init_weights(self):
        """Initialize model weights.

        BatchNorm params are set to ones, final-layer conv weights to
        N(0, 0.001) with zero bias.
        """
        if not isinstance(self.deconv_layers, nn.Identity):
            for m in self.deconv_layers:
                if isinstance(m, nn.BatchNorm2D):
                    ones_(m.weight)
                    # NOTE(review): BN *bias* is set to ones here; the common
                    # convention (and mmpose upstream) zero-inits BN bias.
                    # Kept as-is to preserve trained-model reproducibility —
                    # confirm against released checkpoints before changing.
                    ones_(m.bias)
        if not isinstance(self.final_layer, nn.Conv2D):
            for m in self.final_layer:
                if isinstance(m, nn.Conv2D):
                    normal_(m.weight)
                    zeros_(m.bias)
                elif isinstance(m, nn.BatchNorm2D):
                    ones_(m.weight)
                    ones_(m.bias)
        else:
            normal_(self.final_layer.weight)
            zeros_(self.final_layer.bias)