# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Code was based on https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# reference: https://arxiv.org/abs/2010.11929

from collections.abc import Callable

import numpy as np
import paddle
import paddle.nn as nn
from paddle.nn.initializer import TruncatedNormal, Constant, Normal
from ppdet.core.workspace import register, serializable

trunc_normal_ = TruncatedNormal(std=.02)


def to_2tuple(x):
    """Return ``x`` unchanged if it is a list/tuple, otherwise ``(x, x)``."""
    if isinstance(x, (list, tuple)):
        return x
    return tuple([x] * 2)


def _build_norm(norm_layer, dim, epsilon=1e-5):
    """Instantiate a normalization layer for feature size ``dim``.

    Args:
        norm_layer (str|Callable): either a Python expression string that
            evaluates to a layer class (e.g. ``'nn.LayerNorm'``), which is
            then called as ``cls(dim, epsilon=epsilon)``, or a callable that
            is called as ``norm_layer(dim)``.
        dim (int): feature size passed to the layer constructor.
        epsilon (float): only used for the string form.

    Returns:
        The constructed normalization layer.

    Raises:
        TypeError: if ``norm_layer`` is neither a str nor a callable.
    """
    if isinstance(norm_layer, str):
        # NOTE(security): eval() executes the string as Python code in this
        # module's namespace. norm_layer must only come from trusted configs.
        return eval(norm_layer)(dim, epsilon=epsilon)
    if isinstance(norm_layer, Callable):
        return norm_layer(dim)
    raise TypeError(
        "The norm_layer must be str or paddle.nn.layer.Layer class")


def drop_path(x, drop_prob=0., training=False):
    """Drop paths (Stochastic Depth) per sample.

    Meant for the main path of residual blocks. The original name is
    misleading, because 'Drop Connect' is a different form of dropout from a
    separate paper. See the discussion at
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956

    Args:
        x (Tensor): input tensor. Axis 0 is treated as the sample axis.
        drop_prob (float|None): probability of dropping a sample's path.
            ``None`` is treated the same as ``0.`` (nothing is dropped).
        training (bool): paths are only dropped when this is True.

    Returns:
        Tensor: ``x`` itself when nothing is dropped. Otherwise each sample
        of ``x`` is either zeroed or scaled by ``1 / (1 - drop_prob)``.
    """
    if drop_prob is None or drop_prob == 0. or not training:
        return x
    keep_prob = paddle.to_tensor(1.0 - drop_prob).astype(x.dtype)
    # One random value per sample, broadcast over all remaining axes.
    shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape).astype(x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize to 0/1
    output = x.divide(keep_prob) * random_tensor
    return output


class DropPath(nn.Layer):
    """Layer wrapper around :func:`drop_path`.

    Drops paths (Stochastic Depth) per sample when applied in the main path
    of residual blocks. Paths are only dropped while the layer is in
    training mode.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Identity(nn.Layer):
    """Layer that returns its input unchanged."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, input):
        return input


class Mlp(nn.Layer):
    """Feed-forward network: fc1 -> activation -> fc2 -> dropout.

    Args:
        in_features (int): input feature size.
        hidden_features (int|None): hidden size. Defaults to ``in_features``.
        out_features (int|None): output size. Defaults to ``in_features``.
        act_layer (type): activation layer class, instantiated with no args.
        drop (float): dropout probability, applied once after ``fc2``.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Layer):
    """Multi-head self-attention with a single fused qkv projection.

    Args:
        dim (int): token feature size ``C``. The code assumes ``C`` is
            divisible by ``num_heads``.
        num_heads (int): number of attention heads.
        qkv_bias (bool): whether the qkv projection has a bias.
        qk_scale (float|None): logit scale. Falls back to
            ``head_dim ** -0.5`` when falsy.
        attn_drop (float): dropout on the attention weights.
        proj_drop (float): dropout after the output projection.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        # x: (B, N, C); B is left dynamic via -1 in the reshapes below.
        N, C = x.shape[1:]
        # -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //
                                   self.num_heads)).transpose((2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale
        attn = nn.functional.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)

        # (B, heads, N, head_dim) -> (B, N, C)
        x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((-1, N, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Layer):
    """Pre-norm transformer block.

    Computes ``x + drop_path(attn(norm1(x)))`` followed by
    ``x + drop_path(mlp(norm2(x)))``.

    ``norm_layer`` may be an expression string such as ``'nn.LayerNorm'``,
    which is constructed with ``epsilon``, or a callable taking ``dim``.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer='nn.LayerNorm',
                 epsilon=1e-5):
        super().__init__()
        self.norm1 = _build_norm(norm_layer, dim, epsilon)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
        self.norm2 = _build_norm(norm_layer, dim, epsilon)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=mlp_hidden_dim,
                       act_layer=act_layer,
                       drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Layer):
    """Image to patch embedding via a strided convolution.

    The conv uses ``kernel_size=patch_size``, ``stride=patch_size[0] // ratio``
    and a padding of ``4 + 2 * (ratio // 2 - 1)`` on each side. For
    ``ratio=1`` the padding is 2.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 embed_dim=768,
                 ratio=1):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (
            img_size[0] // patch_size[0]) * (ratio**2)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2D(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=(patch_size[0] // ratio),
            padding=(4 + 2 * (ratio // 2 - 1), 4 + 2 * (ratio // 2 - 1)))

    def forward(self, x):
        # x: (B, C, H, W). H and W must equal the configured img_size.
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        return x


@register
@serializable
class ViT(nn.Layer):
    """Vision Transformer backbone operating on image patches.

    This module differs from ppdet's VisionTransformer
    (ppdet/modeling/backbones/vision_transformer.py) as follows:

    1. ``PatchEmbed.proj`` sets an explicit padding of
       ``(4 + 2 * (ratio // 2 - 1), 4 + 2 * (ratio // 2 - 1))``;
       VisionTransformer does not.
    2. The Attention module uses a standard qkv projection, while
       VisionTransformer provides more options.
    3. The MLP module applies Dropout once; VisionTransformer applies it twice.
    4. VisionTransformer provides an FPN layer; this module does not.

    ``norm_layer`` may be an expression string such as ``'nn.LayerNorm'``,
    which is constructed with ``epsilon``, or a callable taking the feature
    size.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer='nn.LayerNorm',
                 epsilon=1e-5,
                 ratio=1,
                 pretrained=None,
                 **kwargs):
        super().__init__()
        self.pretrained = pretrained
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            ratio=ratio)
        num_patches = self.patch_embed.num_patches

        # One slot per patch plus one extra slot (index 0). forward_features
        # adds slot 0 to every patch token.
        self.pos_embed = self.create_parameter(
            shape=(1, num_patches + 1, embed_dim),
            default_initializer=trunc_normal_)
        self.add_parameter("pos_embed", self.pos_embed)

        # Stochastic-depth rate increases linearly from 0 to drop_path_rate.
        dpr = np.linspace(0, drop_path_rate, depth, dtype='float32')

        self.blocks = nn.LayerList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                epsilon=epsilon) for i in range(depth)
        ])

        # Built the same way as the per-block norms, so callables work too.
        self.last_norm = _build_norm(norm_layer, embed_dim, epsilon)

        trunc_normal_(self.pos_embed)
        self._init_weights()

    def _init_weights(self):
        """Load weights from ``self.pretrained`` if it is set.

        ``self.pretrained`` is downloaded when it is an http(s) URL and
        otherwise treated as a local file path.
        """
        pretrained = self.pretrained
        if not pretrained:
            return
        if pretrained.startswith(('http://', 'https://')):
            path = paddle.utils.download.get_weights_path_from_url(
                pretrained)
        else:
            path = pretrained
        load_state_dict = paddle.load(path)
        self.set_state_dict(load_state_dict)
        print("Load load_state_dict:", path)

    def forward_features(self, x):
        """Encode an image batch into a patch feature map.

        Args:
            x (Tensor): images of shape (B, in_chans, H, W) matching img_size.

        Returns:
            Tensor: features of shape (B, embed_dim, Hp, Wp), where Hp and Wp
            are the spatial sizes of the patch-embedding output.
        """
        # Keep the batch size as a runtime tensor, and do not overwrite it
        # with the static x.shape[0]. That value may be -1 when the graph is
        # built, which would put two -1 entries into the reshape below.
        B = paddle.shape(x)[0]
        x = self.patch_embed(x)
        _, _, Hp, Wp = x.shape
        x = x.flatten(2).transpose([0, 2, 1])  # (B, Hp*Wp, D)
        # No class token is prepended. Slot 0 of pos_embed is broadcast-added
        # to every patch on top of the per-patch embeddings.
        x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]

        for blk in self.blocks:
            x = blk(x)

        x = self.last_norm(x)
        xp = paddle.reshape(
            paddle.transpose(
                x, perm=[0, 2, 1]), shape=[B, -1, Hp, Wp])
        return xp