文档检测
This commit is contained in:
39
object_detection/core/base.py
Normal file
39
object_detection/core/base.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import json
|
||||
|
||||
from .nnet.py_factory import NetworkFactory
|
||||
|
||||
|
||||
class Base(object):
|
||||
def __init__(self, db, nnet, func, model=None):
|
||||
super(Base, self).__init__()
|
||||
|
||||
self._db = db
|
||||
self._nnet = nnet
|
||||
self._func = func
|
||||
|
||||
if model is not None:
|
||||
self._nnet.load_pretrained_params(model)
|
||||
|
||||
self._nnet.cuda()
|
||||
self._nnet.eval_mode()
|
||||
|
||||
def _inference(self, image, *args, **kwargs):
|
||||
return self._func(self._db, self._nnet, image.copy(), *args, **kwargs)
|
||||
|
||||
def __call__(self, image, *args, **kwargs):
|
||||
categories = self._db.configs["categories"]
|
||||
bboxes = self._inference(image, *args, **kwargs)
|
||||
return {self._db.cls2name(j): bboxes[j] for j in range(1, categories + 1)}
|
||||
|
||||
|
||||
def load_cfg(cfg_file):
|
||||
with open(cfg_file, "r") as f:
|
||||
cfg = json.load(f)
|
||||
|
||||
cfg_sys = cfg["system"]
|
||||
cfg_db = cfg["db"]
|
||||
return cfg_sys, cfg_db
|
||||
|
||||
|
||||
def load_nnet(cfg_sys, model):
|
||||
return NetworkFactory(cfg_sys, model)
|
||||
Reference in New Issue
Block a user