更换文档检测模型
This commit is contained in:
89
paddle_detection/ppdet/slim/ofa.py
Normal file
89
paddle_detection/ppdet/slim/ofa.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
from ppdet.core.workspace import load_config, merge_config, create
|
||||
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
|
||||
from ppdet.utils.logger import setup_logger
|
||||
from ppdet.core.workspace import register, serializable
|
||||
|
||||
from paddle.utils import try_import
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
@register
|
||||
@serializable
|
||||
class OFA(object):
|
||||
def __init__(self, ofa_config):
|
||||
super(OFA, self).__init__()
|
||||
self.ofa_config = ofa_config
|
||||
|
||||
def __call__(self, model, param_state_dict):
|
||||
|
||||
paddleslim = try_import('paddleslim')
|
||||
from paddleslim.nas.ofa import OFA, RunConfig, utils
|
||||
from paddleslim.nas.ofa.convert_super import Convert, supernet
|
||||
task = self.ofa_config['task']
|
||||
expand_ratio = self.ofa_config['expand_ratio']
|
||||
|
||||
skip_neck = self.ofa_config['skip_neck']
|
||||
skip_head = self.ofa_config['skip_head']
|
||||
|
||||
run_config = self.ofa_config['RunConfig']
|
||||
if 'skip_layers' in run_config:
|
||||
skip_layers = run_config['skip_layers']
|
||||
else:
|
||||
skip_layers = []
|
||||
|
||||
# supernet config
|
||||
sp_config = supernet(expand_ratio=expand_ratio)
|
||||
# convert to supernet
|
||||
model = Convert(sp_config).convert(model)
|
||||
|
||||
skip_names = []
|
||||
if skip_neck:
|
||||
skip_names.append('neck.')
|
||||
if skip_head:
|
||||
skip_names.append('head.')
|
||||
|
||||
for name, sublayer in model.named_sublayers():
|
||||
for n in skip_names:
|
||||
if n in name:
|
||||
skip_layers.append(name)
|
||||
|
||||
run_config['skip_layers'] = skip_layers
|
||||
run_config = RunConfig(**run_config)
|
||||
|
||||
# build ofa model
|
||||
ofa_model = OFA(model, run_config=run_config)
|
||||
|
||||
ofa_model.set_epoch(0)
|
||||
ofa_model.set_task(task)
|
||||
|
||||
input_spec = [{
|
||||
"image": paddle.ones(
|
||||
shape=[1, 3, 640, 640], dtype='float32'),
|
||||
"im_shape": paddle.full(
|
||||
[1, 2], 640, dtype='float32'),
|
||||
"scale_factor": paddle.ones(
|
||||
shape=[1, 2], dtype='float32')
|
||||
}]
|
||||
|
||||
ofa_model._clear_search_space(input_spec=input_spec)
|
||||
ofa_model._build_ss = True
|
||||
check_ss = ofa_model._sample_config('expand_ratio', phase=None)
|
||||
# tokenize the search space
|
||||
ofa_model.tokenize()
|
||||
# check token map, search cands and search space
|
||||
logger.info('Token map is {}'.format(ofa_model.token_map))
|
||||
logger.info('Search candidates is {}'.format(ofa_model.search_cands))
|
||||
logger.info('The length of search_space is {}, search_space is {}'.
|
||||
format(len(ofa_model._ofa_layers), ofa_model._ofa_layers))
|
||||
# set model state dict into ofa model
|
||||
utils.set_state_dict(ofa_model.model, param_state_dict)
|
||||
return ofa_model
|
||||
Reference in New Issue
Block a user