# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import paddle import paddle.nn as nn import paddle.nn.functional as F from ppdet.core.workspace import register from ppdet.modeling.keypoint_utils import resize, flip_back from paddle.nn.initializer import TruncatedNormal, Constant, Normal from ppdet.modeling.layers import ConvTranspose2d, BatchNorm2d trunc_normal_ = TruncatedNormal(std=.02) normal_ = Normal(std=0.001) zeros_ = Constant(value=0.) ones_ = Constant(value=1.) 
__all__ = ['TopdownHeatmapSimpleHead']


@register
class TopdownHeatmapSimpleHead(nn.Layer):
    """Top-down keypoint head: an optional stack of (deconv -> BN -> ReLU)
    upsampling blocks followed by a final conv that predicts one heatmap per
    keypoint.

    Args:
        in_channels (int | list | tuple): Channels of the input feature
            map(s); must be a list/tuple iff ``input_transform`` is set.
        out_channels (int): Number of output heatmaps (keypoints).
        num_deconv_layers (int): Number of deconv upsampling layers, >= 0.
            0 means no upsampling (identity).
        num_deconv_filters (tuple): Output channels of each deconv layer.
        num_deconv_kernels (tuple): Kernel size of each deconv layer
            (each must be 2, 3 or 4; see ``_get_deconv_cfg``).
        extra (dict | None): Optional config for the final layer(s). Keys
            used: 'final_conv_kernel' (0 = identity, 1 or 3 = conv),
            'num_conv_layers' and 'num_conv_kernels' for extra conv blocks.
        in_index (int | list | tuple): Index (or indices) of the backbone
            feature(s) to use.
        input_transform (str | None): None, 'resize_concat' or
            'multiple_select'; see ``_init_inputs``.
        align_corners (bool): Passed to ``resize``.
        upsample (int): If > 0 and the input is a single tensor, bilinearly
            upsample it by this factor before decoding.
        flip_pairs (list | None): Mirrored keypoint index pairs used by
            flip-test post-processing.
        shift_heatmap (bool): Shift the flipped heatmap one pixel right to
            compensate for flip misalignment.
        target_type (str): Heatmap target type passed to ``flip_back``.

    Raises:
        TypeError: If ``extra`` is neither dict nor None.
        ValueError: If ``num_deconv_layers`` is negative.
    """

    def __init__(self,
                 in_channels=768,
                 out_channels=17,
                 num_deconv_layers=3,
                 num_deconv_filters=(256, 256, 256),
                 num_deconv_kernels=(4, 4, 4),
                 extra=None,
                 in_index=0,
                 input_transform=None,
                 align_corners=False,
                 upsample=0,
                 flip_pairs=None,
                 shift_heatmap=False,
                 target_type='GaussianHeatmap'):
        super(TopdownHeatmapSimpleHead, self).__init__()

        self.in_channels = in_channels
        self.upsample = upsample
        self.flip_pairs = flip_pairs
        self.shift_heatmap = shift_heatmap
        self.target_type = target_type

        # Validates in_channels/in_index/input_transform consistency and may
        # overwrite self.in_channels (e.g. sum of channels for resize_concat).
        self._init_inputs(in_channels, in_index, input_transform)
        self.in_index = in_index
        self.align_corners = align_corners

        if extra is not None and not isinstance(extra, dict):
            raise TypeError('extra should be dict or None.')

        if num_deconv_layers > 0:
            self.deconv_layers = self._make_deconv_layer(
                num_deconv_layers,
                num_deconv_filters,
                num_deconv_kernels, )
        elif num_deconv_layers == 0:
            self.deconv_layers = nn.Identity()
        else:
            raise ValueError(
                f'num_deconv_layers ({num_deconv_layers}) should >= 0.')

        # Decide the final layer: identity, a single 1x1/3x3 conv, or a
        # Sequential of extra conv blocks plus the prediction conv.
        identity_final_layer = False
        if extra is not None and 'final_conv_kernel' in extra:
            assert extra['final_conv_kernel'] in [0, 1, 3]
            if extra['final_conv_kernel'] == 3:
                padding = 1
            elif extra['final_conv_kernel'] == 1:
                padding = 0
            else:
                # 0 means: no final conv at all (identity mapping).
                identity_final_layer = True
            kernel_size = extra['final_conv_kernel']
        else:
            kernel_size = 1
            padding = 0

        if identity_final_layer:
            self.final_layer = nn.Identity()
        else:
            conv_channels = num_deconv_filters[
                -1] if num_deconv_layers > 0 else self.in_channels

            layers = []
            if extra is not None:
                num_conv_layers = extra.get('num_conv_layers', 0)
                num_conv_kernels = extra.get('num_conv_kernels',
                                             [1] * num_conv_layers)
                # Optional channel-preserving conv blocks before prediction.
                for i in range(num_conv_layers):
                    layers.append(
                        nn.Conv2D(
                            in_channels=conv_channels,
                            out_channels=conv_channels,
                            kernel_size=num_conv_kernels[i],
                            stride=1,
                            padding=(num_conv_kernels[i] - 1) // 2))
                    layers.append(nn.BatchNorm2D(conv_channels))
                    layers.append(nn.ReLU())

            # Prediction conv: conv_channels -> out_channels heatmaps.
            layers.append(
                nn.Conv2D(
                    in_channels=conv_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=(padding, padding)))

            if len(layers) > 1:
                self.final_layer = nn.Sequential(*layers)
            else:
                self.final_layer = layers[0]

        self.init_weights()

    @staticmethod
    def _get_deconv_cfg(deconv_kernel):
        """Map a deconv kernel size to ``(kernel, padding, output_padding)``.

        Only kernels 2, 3 and 4 are supported; each combination makes a
        stride-2 deconv produce exactly 2x the input resolution.

        Raises:
            ValueError: For unsupported kernel sizes.
        """
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')

        return deconv_kernel, padding, output_padding

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        When ``input_transform`` is None, a single feature map is selected,
        so ``in_channels``/``in_index`` must be ints. Otherwise they must be
        equal-length lists/tuples; for 'resize_concat' the effective channel
        count is the sum of the selected inputs' channels.
        """
        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                # Selected maps are concatenated channel-wise after resizing.
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def _transform_inputs(self, inputs):
        """Transform backbone outputs into the decoder's single input.

        Args:
            inputs (Tensor | list[Tensor]): Backbone feature map(s).

        Returns:
            Tensor: The selected / merged feature map.
        """
        if not isinstance(inputs, list):
            # Single feature map: optionally upsample before decoding.
            # (The original code had this isinstance check duplicated.)
            if self.upsample > 0:
                inputs = resize(
                    input=F.relu(inputs),
                    scale_factor=self.upsample,
                    mode='bilinear',
                    align_corners=self.align_corners)
            return inputs

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            # Bugfix: paddle.concat takes `axis`, not torch-style `dim`
            # (`dim=1` raised TypeError whenever resize_concat was used).
            inputs = paddle.concat(upsampled_inputs, axis=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    def forward(self, x):
        """Forward: transform inputs, upsample via the deconv stack, then
        predict per-keypoint heatmaps."""
        x = self._transform_inputs(x)
        x = self.deconv_layers(x)
        x = self.final_layer(x)
        return x

    def inference_model(self, x, flip_pairs=None):
        """Inference with optional flip-test post-processing.

        Args:
            x (paddle.Tensor[N,K,H,W]): Input features.
            flip_pairs (None | list[tuple]): Pairs of mirrored keypoints.
                NOTE(review): this argument only acts as an on/off switch —
                ``self.flip_pairs`` (set at init) is what is actually passed
                to ``flip_back``. Confirm this asymmetry is intentional.

        Returns:
            Tensor: Output heatmaps (flipped back and optionally shifted
            when ``flip_pairs`` is given).
        """
        output = self.forward(x)

        if flip_pairs is not None:
            output_heatmap = flip_back(
                output, self.flip_pairs, target_type=self.target_type)
            # Features are not perfectly aligned after flipping; shifting the
            # flipped heatmap one pixel improves accuracy.
            if self.shift_heatmap:
                output_heatmap[:, :, :, 1:] = output_heatmap[:, :, :, :-1]
        else:
            output_heatmap = output
        return output_heatmap

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Build ``num_layers`` stride-2 (deconv -> BN -> ReLU) blocks.

        Side effect: updates ``self.in_channels`` to the last block's output
        channels, which ``__init__`` relies on for the final layer.

        Raises:
            ValueError: If filter/kernel list lengths do not match
                ``num_layers``.
        """
        if num_layers != len(num_filters):
            raise ValueError(f'num_layers({num_layers}) '
                             f'!= length of num_filters({len(num_filters)})')
        if num_layers != len(num_kernels):
            raise ValueError(f'num_layers({num_layers}) '
                             f'!= length of num_kernels({len(num_kernels)})')

        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i])

            planes = num_filters[i]
            layers.append(
                ConvTranspose2d(
                    in_channels=self.in_channels,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=False))
            layers.append(nn.BatchNorm2D(planes))
            layers.append(nn.ReLU())
            # The next deconv consumes this block's output channels.
            self.in_channels = planes

        return nn.Sequential(*layers)

    def init_weights(self):
        """Initialize model weights.

        Conv layers get small-normal weights and zero biases; BatchNorm
        layers get constant-one parameters.
        """
        if not isinstance(self.deconv_layers, nn.Identity):
            for m in self.deconv_layers:
                if isinstance(m, nn.BatchNorm2D):
                    ones_(m.weight)
                    # NOTE(review): BN bias is initialized to 1 here; the
                    # conventional value is 0 — confirm this is intentional.
                    ones_(m.bias)
        if isinstance(self.final_layer, nn.Conv2D):
            normal_(self.final_layer.weight)
            zeros_(self.final_layer.bias)
        elif not isinstance(self.final_layer, nn.Identity):
            # Bugfix: the original iterated `self.final_layer` whenever it
            # was not a Conv2D, which crashed for the nn.Identity case
            # (extra['final_conv_kernel'] == 0); Identity is now skipped.
            for m in self.final_layer:
                if isinstance(m, nn.Conv2D):
                    normal_(m.weight)
                    zeros_(m.bias)
                elif isinstance(m, nn.BatchNorm2D):
                    ones_(m.weight)
                    ones_(m.bias)