Replace document detection model
15  paddle_detection/ppdet/core/__init__.py  Normal file
@@ -0,0 +1,15 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import config
13  paddle_detection/ppdet/core/config/__init__.py  Normal file
@@ -0,0 +1,13 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
248  paddle_detection/ppdet/core/config/schema.py  Normal file
@@ -0,0 +1,248 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import inspect
import importlib
import re

try:
    from docstring_parser import parse as doc_parse
except Exception:

    def doc_parse(*args):
        pass


try:
    from typeguard import check_type
except Exception:

    def check_type(*args):
        pass


__all__ = ['SchemaValue', 'SchemaDict', 'SharedConfig', 'extract_schema']


class SchemaValue(object):
    def __init__(self, name, doc='', type=None):
        super(SchemaValue, self).__init__()
        self.name = name
        self.doc = doc
        self.type = type

    def set_default(self, value):
        self.default = value

    def has_default(self):
        return hasattr(self, 'default')


class SchemaDict(dict):
    def __init__(self, **kwargs):
        super(SchemaDict, self).__init__()
        self.schema = {}
        self.strict = False
        self.doc = ""
        self.update(kwargs)

    def __setitem__(self, key, value):
        # XXX also update regular dict to SchemaDict??
        if isinstance(value, dict) and key in self and isinstance(self[key],
                                                                   SchemaDict):
            self[key].update(value)
        else:
            super(SchemaDict, self).__setitem__(key, value)

    def __missing__(self, key):
        if self.has_default(key):
            return self.schema[key].default
        elif key in self.schema:
            return self.schema[key]
        else:
            raise KeyError(key)

    def copy(self):
        newone = SchemaDict()
        newone.__dict__.update(self.__dict__)
        newone.update(self)
        return newone

    def set_schema(self, key, value):
        assert isinstance(value, SchemaValue)
        self.schema[key] = value

    def set_strict(self, strict):
        self.strict = strict

    def has_default(self, key):
        return key in self.schema and self.schema[key].has_default()

    def is_default(self, key):
        if not self.has_default(key):
            return False
        if hasattr(self[key], '__dict__'):
            return True
        else:
            return key not in self or self[key] == self.schema[key].default

    def find_default_keys(self):
        return [
            k for k in list(self.keys()) + list(self.schema.keys())
            if self.is_default(k)
        ]

    def mandatory(self):
        return any([k for k in self.schema.keys() if not self.has_default(k)])

    def find_missing_keys(self):
        missing = [
            k for k in self.schema.keys()
            if k not in self and not self.has_default(k)
        ]
        placeholders = [k for k in self if self[k] in ('<missing>', '<value>')]
        return missing + placeholders

    def find_extra_keys(self):
        return list(set(self.keys()) - set(self.schema.keys()))

    def find_mismatch_keys(self):
        mismatch_keys = []
        for arg in self.schema.values():
            if arg.type is not None:
                try:
                    check_type("{}.{}".format(self.name, arg.name),
                               self[arg.name], arg.type)
                except Exception:
                    mismatch_keys.append(arg.name)
        return mismatch_keys

    def validate(self):
        missing_keys = self.find_missing_keys()
        if missing_keys:
            raise ValueError("Missing param for class<{}>: {}".format(
                self.name, ", ".join(missing_keys)))
        extra_keys = self.find_extra_keys()
        if extra_keys and self.strict:
            raise ValueError("Extraneous param for class<{}>: {}".format(
                self.name, ", ".join(extra_keys)))
        mismatch_keys = self.find_mismatch_keys()
        if mismatch_keys:
            raise TypeError("Wrong param type for class<{}>: {}".format(
                self.name, ", ".join(mismatch_keys)))


class SharedConfig(object):
    """
    Representation class for `__shared__` annotations, which work as follows:

    - if `key` is set for the module in config file, its value will take
      precedence
    - if `key` is not set for the module but present in the config file, its
      value will be used
    - otherwise, use the provided `default_value` as fallback

    Args:
        key: config[key] will be injected
        default_value: fallback value
    """

    def __init__(self, key, default_value=None):
        super(SharedConfig, self).__init__()
        self.key = key
        self.default_value = default_value


def extract_schema(cls):
    """
    Extract schema from a given class

    Args:
        cls (type): Class from which to extract.

    Returns:
        schema (SchemaDict): Extracted schema.
    """
    ctor = cls.__init__
    # python 2 compatibility
    if hasattr(inspect, 'getfullargspec'):
        argspec = inspect.getfullargspec(ctor)
        annotations = argspec.annotations
        has_kwargs = argspec.varkw is not None
    else:
        argspec = inspect.getfullargspec(ctor)
        # python 2 type hinting workaround, see pep-3107
        # however, since `typeguard` does not support python 2, type checking
        # is still python 3 only for now
        annotations = getattr(ctor, '__annotations__', {})
        has_kwargs = argspec.varkw is not None

    names = [arg for arg in argspec.args if arg != 'self']
    defaults = argspec.defaults
    num_defaults = argspec.defaults is not None and len(argspec.defaults) or 0
    num_required = len(names) - num_defaults

    docs = cls.__doc__
    if docs is None and getattr(cls, '__category__', None) == 'op':
        docs = cls.__call__.__doc__
    try:
        docstring = doc_parse(docs)
    except Exception:
        docstring = None

    if docstring is None:
        comments = {}
    else:
        comments = {}
        for p in docstring.params:
            match_obj = re.match('^([a-zA-Z_]+[a-zA-Z_0-9]*).*', p.arg_name)
            if match_obj is not None:
                comments[match_obj.group(1)] = p.description

    schema = SchemaDict()
    schema.name = cls.__name__
    schema.doc = ""
    if docs is not None:
        start_pos = docs[0] == '\n' and 1 or 0
        schema.doc = docs[start_pos:].split("\n")[0].strip()
    # XXX handle paddle's weird doc convention
    if '**' == schema.doc[:2] and '**' == schema.doc[-2:]:
        schema.doc = schema.doc[2:-2].strip()
    schema.category = hasattr(cls, '__category__') and getattr(
        cls, '__category__') or 'module'
    schema.strict = not has_kwargs
    schema.pymodule = importlib.import_module(cls.__module__)
    schema.inject = getattr(cls, '__inject__', [])
    schema.shared = getattr(cls, '__shared__', [])
    for idx, name in enumerate(names):
        comment = name in comments and comments[name] or name
        if name in schema.inject:
            type_ = None
        else:
            type_ = name in annotations and annotations[name] or None
        value_schema = SchemaValue(name, comment, type_)
        if name in schema.shared:
            assert idx >= num_required, "shared config must have default value"
            default = defaults[idx - num_required]
            value_schema.set_default(SharedConfig(name, default))
        elif idx >= num_required:
            default = defaults[idx - num_required]
            value_schema.set_default(default)
        schema.set_schema(name, value_schema)

    return schema
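The sketch below is not part of the commit; it is a minimal illustration of how the schema machinery above is typically driven. The class TestAnchorHead, its arguments, and the `ppdet` import path are assumptions for the example only.

    # Hypothetical usage of extract_schema -- not from the diff above.
    from ppdet.core.config.schema import extract_schema, SharedConfig


    class TestAnchorHead(object):
        # parameters listed in __shared__ fall back to a global config entry
        __shared__ = ['num_classes']

        def __init__(self, stride=16, num_classes=80):
            self.stride = stride
            self.num_classes = num_classes


    schema = extract_schema(TestAnchorHead)
    print(schema.name)                      # TestAnchorHead
    print(schema.schema['stride'].default)  # 16
    # shared parameters are wrapped in SharedConfig instead of a plain default
    print(isinstance(schema.schema['num_classes'].default, SharedConfig))  # True
    schema.update(stride=8)
    schema.validate()                       # passes: nothing missing, no type annotations to check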
118  paddle_detection/ppdet/core/config/yaml_helpers.py  Normal file
@@ -0,0 +1,118 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import inspect

import yaml
from .schema import SharedConfig

__all__ = ['serializable', 'Callable']


def represent_dictionary_order(self, dict_data):
    return self.represent_mapping('tag:yaml.org,2002:map', dict_data.items())


def setup_orderdict():
    from collections import OrderedDict
    yaml.add_representer(OrderedDict, represent_dictionary_order)


def _make_python_constructor(cls):
    def python_constructor(loader, node):
        if isinstance(node, yaml.SequenceNode):
            args = loader.construct_sequence(node, deep=True)
            return cls(*args)
        else:
            kwargs = loader.construct_mapping(node, deep=True)
            try:
                return cls(**kwargs)
            except Exception as ex:
                print("Error when constructing {} instance from yaml config".
                      format(cls.__name__))
                raise ex

    return python_constructor


def _make_python_representer(cls):
    # python 2 compatibility
    if hasattr(inspect, 'getfullargspec'):
        argspec = inspect.getfullargspec(cls)
    else:
        argspec = inspect.getfullargspec(cls.__init__)
    argnames = [arg for arg in argspec.args if arg != 'self']

    def python_representer(dumper, obj):
        if argnames:
            data = {name: getattr(obj, name) for name in argnames}
        else:
            data = obj.__dict__
        if '_id' in data:
            del data['_id']
        return dumper.represent_mapping(u'!{}'.format(cls.__name__), data)

    return python_representer


def serializable(cls):
    """
    Add loader and dumper for given class, which must be
    "trivially serializable"

    Args:
        cls: class to be serialized

    Returns: cls
    """
    yaml.add_constructor(u'!{}'.format(cls.__name__),
                         _make_python_constructor(cls))
    yaml.add_representer(cls, _make_python_representer(cls))
    return cls


yaml.add_representer(SharedConfig,
                     lambda d, o: d.represent_data(o.default_value))


@serializable
class Callable(object):
    """
    Helper to be used in Yaml for creating arbitrary class objects

    Args:
        full_type (str): the full module path to target function
    """

    def __init__(self, full_type, args=[], kwargs={}):
        super(Callable, self).__init__()
        self.full_type = full_type
        self.args = args
        self.kwargs = kwargs

    def __call__(self):
        if '.' in self.full_type:
            idx = self.full_type.rfind('.')
            module = importlib.import_module(self.full_type[:idx])
            func_name = self.full_type[idx + 1:]
        else:
            try:
                module = importlib.import_module('builtins')
            except Exception:
                module = importlib.import_module('__builtin__')
            func_name = self.full_type

        func = getattr(module, func_name)
        return func(*self.args, **self.kwargs)
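Again for illustration only, not part of the diff: a @serializable class can be built from a !ClassName YAML tag and dumped back, and Callable defers construction of an object described in YAML. NormalizeBox and its fields are made-up names; only the ppdet import path is assumed.

    import yaml

    from ppdet.core.config.yaml_helpers import serializable, Callable


    @serializable
    class NormalizeBox(object):          # hypothetical example class
        def __init__(self, mean=0.0, std=1.0):
            self.mean = mean
            self.std = std


    obj = yaml.load("!NormalizeBox {mean: 0.5, std: 2.0}", Loader=yaml.Loader)
    print(type(obj).__name__, obj.mean, obj.std)   # NormalizeBox 0.5 2.0
    print(yaml.dump(obj))                          # dumps back as a mapping tagged !NormalizeBox

    # Callable wraps "call this function later with these args"
    sort_numbers = Callable('builtins.sorted', args=[[3, 1, 2]])
    print(sort_numbers())                          # [1, 2, 3]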
292  paddle_detection/ppdet/core/workspace.py  Normal file
@@ -0,0 +1,292 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import importlib
import os
import sys

import yaml
import collections

try:
    collectionsAbc = collections.abc
except AttributeError:
    collectionsAbc = collections

from .config.schema import SchemaDict, SharedConfig, extract_schema
from .config.yaml_helpers import serializable

__all__ = [
    'global_config',
    'load_config',
    'merge_config',
    'get_registered_modules',
    'create',
    'register',
    'serializable',
    'dump_value',
]


def dump_value(value):
    # XXX this is hackish, but collections.abc is not available in python 2
    if hasattr(value, '__dict__') or isinstance(value, (dict, tuple, list)):
        value = yaml.dump(value, default_flow_style=True)
        value = value.replace('\n', '')
        value = value.replace('...', '')
        return "'{}'".format(value)
    else:
        # primitive types
        return str(value)


class AttrDict(dict):
    """Single level attribute dict, NOT recursive"""

    def __init__(self, **kwargs):
        super(AttrDict, self).__init__()
        super(AttrDict, self).update(kwargs)

    def __getattr__(self, key):
        if key in self:
            return self[key]
        raise AttributeError("object has no attribute '{}'".format(key))

    def __setattr__(self, key, value):
        self[key] = value

    def copy(self):
        new_dict = AttrDict()
        for k, v in self.items():
            new_dict.update({k: v})
        return new_dict


global_config = AttrDict()

BASE_KEY = '_BASE_'


# parse and load _BASE_ recursively
def _load_config_with_base(file_path):
    with open(file_path) as f:
        file_cfg = yaml.load(f, Loader=yaml.Loader)

    # NOTE: cfgs outside have higher priority than cfgs in _BASE_
    if BASE_KEY in file_cfg:
        all_base_cfg = AttrDict()
        base_ymls = list(file_cfg[BASE_KEY])
        for base_yml in base_ymls:
            if base_yml.startswith("~"):
                base_yml = os.path.expanduser(base_yml)
            if not base_yml.startswith('/'):
                base_yml = os.path.join(os.path.dirname(file_path), base_yml)

            with open(base_yml) as f:
                base_cfg = _load_config_with_base(base_yml)
                all_base_cfg = merge_config(base_cfg, all_base_cfg)

        del file_cfg[BASE_KEY]
        return merge_config(file_cfg, all_base_cfg)

    return file_cfg


def load_config(file_path):
    """
    Load config from file.

    Args:
        file_path (str): Path of the config file to be loaded.

    Returns: global config
    """
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"

    # load config from file and merge into global config
    cfg = _load_config_with_base(file_path)
    cfg['filename'] = os.path.splitext(os.path.split(file_path)[-1])[0]
    merge_config(cfg)

    return global_config


def dict_merge(dct, merge_dct):
    """ Recursive dict merge. Inspired by :meth:``dict.update()``, instead of
    updating only top-level keys, dict_merge recurses down into dicts nested
    to an arbitrary depth, updating keys. The ``merge_dct`` is merged into
    ``dct``.

    Args:
        dct: dict onto which the merge is executed
        merge_dct: dct merged into dct

    Returns: dct
    """
    for k, v in merge_dct.items():
        if (k in dct and isinstance(dct[k], dict) and
                isinstance(merge_dct[k], collectionsAbc.Mapping)):
            dict_merge(dct[k], merge_dct[k])
        else:
            dct[k] = merge_dct[k]
    return dct


def merge_config(config, another_cfg=None):
    """
    Merge config into global config or another_cfg.

    Args:
        config (dict): Config to be merged.

    Returns: global config
    """
    global global_config
    dct = another_cfg or global_config
    return dict_merge(dct, config)


def get_registered_modules():
    return {k: v for k, v in global_config.items() if isinstance(v, SchemaDict)}


def make_partial(cls):
    op_module = importlib.import_module(cls.__op__.__module__)
    op = getattr(op_module, cls.__op__.__name__)
    cls.__category__ = getattr(cls, '__category__', None) or 'op'

    def partial_apply(self, *args, **kwargs):
        kwargs_ = self.__dict__.copy()
        kwargs_.update(kwargs)
        return op(*args, **kwargs_)

    if getattr(cls, '__append_doc__', True):  # XXX should default to True?
        if sys.version_info[0] > 2:
            cls.__doc__ = "Wrapper for `{}` OP".format(op.__name__)
            cls.__init__.__doc__ = op.__doc__
            cls.__call__ = partial_apply
            cls.__call__.__doc__ = op.__doc__
        else:
            # XXX work around for python 2
            partial_apply.__doc__ = op.__doc__
            cls.__call__ = partial_apply
    return cls


def register(cls):
    """
    Register a given module class.

    Args:
        cls (type): Module class to be registered.

    Returns: cls
    """
    if cls.__name__ in global_config:
        raise ValueError("Module class already registered: {}".format(
            cls.__name__))
    if hasattr(cls, '__op__'):
        cls = make_partial(cls)
    global_config[cls.__name__] = extract_schema(cls)
    return cls


def create(cls_or_name, **kwargs):
    """
    Create an instance of given module class.

    Args:
        cls_or_name (type or str): Class of which to create instance.

    Returns: instance of type `cls_or_name`
    """
    assert type(cls_or_name) in [type, str
                                 ], "should be a class or name of a class"
    name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__
    if name in global_config:
        if isinstance(global_config[name], SchemaDict):
            pass
        elif hasattr(global_config[name], "__dict__"):
            # support instance return directly
            return global_config[name]
        else:
            raise ValueError("The module {} is not registered".format(name))
    else:
        raise ValueError("The module {} is not registered".format(name))

    config = global_config[name]
    cls = getattr(config.pymodule, name)
    cls_kwargs = {}
    cls_kwargs.update(global_config[name])

    # parse `shared` annotation of registered modules
    if getattr(config, 'shared', None):
        for k in config.shared:
            target_key = config[k]
            shared_conf = config.schema[k].default
            assert isinstance(shared_conf, SharedConfig)
            if target_key is not None and not isinstance(target_key,
                                                          SharedConfig):
                continue  # value is given for the module
            elif shared_conf.key in global_config:
                # `key` is present in config
                cls_kwargs[k] = global_config[shared_conf.key]
            else:
                cls_kwargs[k] = shared_conf.default_value

    # parse `inject` annotation of registered modules
    if getattr(cls, 'from_config', None):
        cls_kwargs.update(cls.from_config(config, **kwargs))

    if getattr(config, 'inject', None):
        for k in config.inject:
            target_key = config[k]
            # optional dependency
            if target_key is None:
                continue

            if isinstance(target_key, dict) or hasattr(target_key, '__dict__'):
                if 'name' not in target_key.keys():
                    continue
                inject_name = str(target_key['name'])
                if inject_name not in global_config:
                    raise ValueError(
                        "Missing injection name {}; check its name in the cfg file".
                        format(k))
                target = global_config[inject_name]
                for i, v in target_key.items():
                    if i == 'name':
                        continue
                    target[i] = v
                if isinstance(target, SchemaDict):
                    cls_kwargs[k] = create(inject_name)
            elif isinstance(target_key, str):
                if target_key not in global_config:
                    raise ValueError("Missing injection config:", target_key)
                target = global_config[target_key]
                if isinstance(target, SchemaDict):
                    cls_kwargs[k] = create(target_key)
                elif hasattr(target, '__dict__'):  # serialized object
                    cls_kwargs[k] = target
            else:
                raise ValueError("Unsupported injection type:", target_key)
    # prevent modification of global config values of reference types
    # (e.g., list, dict) from within the created module instances
    #kwargs = copy.deepcopy(kwargs)
    return cls(**cls_kwargs)
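Not part of the commit: a short sketch of the register/create workflow provided by workspace.py. TestNeck and the config keys are illustrative assumptions; in the real pipeline the values would come from load_config() on a YAML file rather than being assigned by hand.

    # Hypothetical register/create round trip -- not from the diff above.
    from ppdet.core.workspace import register, create, global_config


    @register                      # registers an extract_schema() of the class
    class TestNeck(object):
        __shared__ = ['num_classes']

        def __init__(self, channels=64, num_classes=80):
            self.channels = channels
            self.num_classes = num_classes


    # normally load_config('...yml') merges the YAML (plus _BASE_ files) into
    # global_config; here equivalent entries are set directly for brevity
    global_config['num_classes'] = 4
    global_config['TestNeck'].update({'channels': 128})

    neck = create('TestNeck')
    print(neck.channels, neck.num_classes)   # 128 4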