# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ppdet.core.workspace import register
from ppdet.modeling.keypoint_utils import resize, flip_back
from paddle.nn.initializer import TruncatedNormal, Constant, Normal
from ppdet.modeling.layers import ConvTranspose2d, BatchNorm2d

trunc_normal_ = TruncatedNormal(std=.02)
normal_ = Normal(std=0.001)
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)

__all__ = ['TopdownHeatmapSimpleHead']


@register
class TopdownHeatmapSimpleHead(nn.Layer):
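    """Top-down heatmap simple head.

    The head optionally re-assembles multi-level backbone features (see
    ``_init_inputs`` / ``_transform_inputs``), upsamples them with a stack of
    stride-2 transposed convolutions and predicts one heatmap per keypoint
    with a final convolution.
    """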
    def __init__(self,
                 in_channels=768,
                 out_channels=17,
                 num_deconv_layers=3,
                 num_deconv_filters=(256, 256, 256),
                 num_deconv_kernels=(4, 4, 4),
                 extra=None,
                 in_index=0,
                 input_transform=None,
                 align_corners=False,
                 upsample=0,
                 flip_pairs=None,
                 shift_heatmap=False,
                 target_type='GaussianHeatmap'):
        super(TopdownHeatmapSimpleHead, self).__init__()

        self.in_channels = in_channels
        self.upsample = upsample
        self.flip_pairs = flip_pairs
        self.shift_heatmap = shift_heatmap
        self.target_type = target_type

        self._init_inputs(in_channels, in_index, input_transform)
        self.in_index = in_index
        self.align_corners = align_corners

        if extra is not None and not isinstance(extra, dict):
            raise TypeError('extra should be dict or None.')

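        # Upsampling branch: `num_deconv_layers` stride-2 transposed
        # convolutions (deconv + BN + ReLU), or an identity mapping when no
        # deconv layers are requested.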
        if num_deconv_layers > 0:
            self.deconv_layers = self._make_deconv_layer(
                num_deconv_layers,
                num_deconv_filters,
                num_deconv_kernels, )
        elif num_deconv_layers == 0:
            self.deconv_layers = nn.Identity()
        else:
            raise ValueError(
                f'num_deconv_layers ({num_deconv_layers}) should >= 0.')

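        # Prediction branch: optional intermediate convs configured through
        # `extra`, followed by a conv projecting to `out_channels` heatmaps
        # (or an identity mapping when extra['final_conv_kernel'] == 0).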
        identity_final_layer = False
        if extra is not None and 'final_conv_kernel' in extra:
            assert extra['final_conv_kernel'] in [0, 1, 3]
            if extra['final_conv_kernel'] == 3:
                padding = 1
            elif extra['final_conv_kernel'] == 1:
                padding = 0
            else:
                # 0 for Identity mapping.
                identity_final_layer = True
            kernel_size = extra['final_conv_kernel']
        else:
            kernel_size = 1
            padding = 0

        if identity_final_layer:
            self.final_layer = nn.Identity()
        else:
            conv_channels = num_deconv_filters[
                -1] if num_deconv_layers > 0 else self.in_channels

            layers = []
            if extra is not None:
                num_conv_layers = extra.get('num_conv_layers', 0)
                num_conv_kernels = extra.get('num_conv_kernels',
                                             [1] * num_conv_layers)

                for i in range(num_conv_layers):
                    layers.append(
                        nn.Conv2D(
                            in_channels=conv_channels,
                            out_channels=conv_channels,
                            kernel_size=num_conv_kernels[i],
                            stride=1,
                            padding=(num_conv_kernels[i] - 1) // 2))
                    layers.append(nn.BatchNorm2D(conv_channels))
                    layers.append(nn.ReLU())

            layers.append(
                nn.Conv2D(
                    in_channels=conv_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=(padding, padding)))

            if len(layers) > 1:
                self.final_layer = nn.Sequential(*layers)
            else:
                self.final_layer = layers[0]

        self.init_weights()

    @staticmethod
    def _get_deconv_cfg(deconv_kernel):
        """Get configurations for deconv layers."""
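        # With stride 2 a transposed convolution produces
        #   out = (in - 1) * 2 - 2 * padding + kernel + output_padding,
        # so every supported (kernel, padding, output_padding) combination
        # below exactly doubles the spatial size.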
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')

        return deconv_kernel, padding, output_padding

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms."""
        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder."""
        if not isinstance(inputs, list):
            if self.upsample > 0:
                inputs = resize(
                    input=F.relu(inputs),
                    scale_factor=self.upsample,
                    mode='bilinear',
                    align_corners=self.align_corners)
            return inputs

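        # Multi-level inputs: either resize the selected levels to the first
        # one and concatenate along channels ('resize_concat'), keep the
        # selected levels as a list ('multiple_select'), or pick a single
        # level by index.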
        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = paddle.concat(upsampled_inputs, axis=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    def forward(self, x):
        """Forward function."""
        x = self._transform_inputs(x)
        x = self.deconv_layers(x)
        x = self.final_layer(x)

        return x

    def inference_model(self, x, flip_pairs=None):
        """Inference function.

        Args:
            x (paddle.Tensor[N,C,H,W]): Input features.
            flip_pairs (None | list[tuple]):
                Pairs of keypoints which are mirrored.

        Returns:
            output_heatmap (Tensor): Output heatmaps.
        """
        output = self.forward(x)

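        # Flip test: map the flipped heatmaps back to the original image
        # orientation and swap the channels of mirrored keypoint pairs.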
        if flip_pairs is not None:
            output_heatmap = flip_back(
                output, self.flip_pairs, target_type=self.target_type)
            # feature is not aligned, shift flipped heatmap for higher accuracy
            if self.shift_heatmap:
                output_heatmap[:, :, :, 1:] = output_heatmap[:, :, :, :-1]
        else:
            output_heatmap = output
        return output_heatmap

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Make deconv layers."""
        if num_layers != len(num_filters):
            error_msg = f'num_layers({num_layers}) ' \
                        f'!= length of num_filters({len(num_filters)})'
            raise ValueError(error_msg)
        if num_layers != len(num_kernels):
            error_msg = f'num_layers({num_layers}) ' \
                        f'!= length of num_kernels({len(num_kernels)})'
            raise ValueError(error_msg)

        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i])

            planes = num_filters[i]
            layers.append(
                ConvTranspose2d(
                    in_channels=self.in_channels,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=False))
            layers.append(nn.BatchNorm2D(planes))
            layers.append(nn.ReLU())
            self.in_channels = planes

        return nn.Sequential(*layers)

    def init_weights(self):
        """Initialize model weights."""
        if not isinstance(self.deconv_layers, nn.Identity):
            for m in self.deconv_layers:
                if isinstance(m, nn.BatchNorm2D):
                    ones_(m.weight)
                    ones_(m.bias)
        if not isinstance(self.final_layer, nn.Conv2D):
            for m in self.final_layer:
                if isinstance(m, nn.Conv2D):
                    normal_(m.weight)
                    zeros_(m.bias)
                elif isinstance(m, nn.BatchNorm2D):
                    ones_(m.weight)
                    ones_(m.bias)
        else:
            normal_(self.final_layer.weight)
            zeros_(self.final_layer.bias)
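

# Minimal usage sketch (illustrative only; assumes the default constructor
# arguments above and a single-level feature map): three stride-2 deconv
# stages upsample the input by 8x and the final 1x1 conv emits one heatmap
# per keypoint.
if __name__ == '__main__':
    head = TopdownHeatmapSimpleHead(in_channels=768, out_channels=17)
    feats = paddle.randn([1, 768, 24, 18])  # dummy backbone feature map
    heatmaps = head(feats)  # shape: [1, 17, 192, 144]
    print(heatmaps.shape)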