40 lines
998 B
Python
40 lines
998 B
Python
import json
|
|
|
|
from .nnet.py_factory import NetworkFactory
|
|
|
|
|
|
class Base(object):
    """Callable wrapper bundling a dataset object, a network and an inference function.

    On construction the network optionally loads pretrained parameters from
    ``model`` and then has ``cuda()`` and ``eval_mode()`` called on it.
    Calling the instance runs the stored inference function on a copy of the
    image and returns a dict keyed by class name.
    """

    def __init__(self, db, nnet, func, model=None):
        """Store the collaborators and prepare the network.

        Args:
            db: dataset object; must expose ``configs["categories"]`` and
                ``cls2name(index)``.
            nnet: network object; must expose ``load_pretrained_params``,
                ``cuda`` and ``eval_mode``.
            func: inference callable, invoked as
                ``func(db, nnet, image, *args, **kwargs)``.
            model: optional value forwarded to
                ``nnet.load_pretrained_params``; skipped when ``None``.
        """
        super(Base, self).__init__()

        self._db, self._nnet, self._func = db, nnet, func

        network = self._nnet
        if model is not None:
            network.load_pretrained_params(model)

        # NOTE(review): by name these presumably move the network to the GPU
        # and switch it to inference mode. Both run unconditionally, so a
        # CUDA-capable nnet appears to be assumed -- confirm.
        network.cuda()
        network.eval_mode()

    def _inference(self, image, *args, **kwargs):
        """Run the stored inference function on a copy of ``image``.

        ``func`` receives ``image.copy()`` rather than the caller's object.
        Extra positional and keyword arguments are forwarded unchanged.
        """
        runner = self._func
        return runner(self._db, self._nnet, image.copy(), *args, **kwargs)

    def __call__(self, image, *args, **kwargs):
        """Run inference and return the per-class results keyed by class name.

        Only indices ``1 .. db.configs["categories"]`` of the inference
        result are reported. Index 0 is skipped (presumably a background or
        unused slot -- confirm against the inference function).
        """
        num_classes = self._db.configs["categories"]
        detections = self._inference(image, *args, **kwargs)

        by_name = {}
        for cls_id in range(1, num_classes + 1):
            # Resolve the name before indexing so the evaluation order
            # matches a key-then-value dict comprehension.
            cls_name = self._db.cls2name(cls_id)
            by_name[cls_name] = detections[cls_id]
        return by_name
|
|
|
|
|
|
def load_cfg(cfg_file):
    """Read a JSON configuration file and return its two sections.

    Args:
        cfg_file: path to a JSON file whose top-level object contains the
            keys ``"system"`` and ``"db"``.

    Returns:
        A ``(cfg_sys, cfg_db)`` tuple holding the values stored under
        ``"system"`` and ``"db"`` respectively. Any other top-level keys
        are ignored.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the file is not valid UTF-8 JSON
            (``json.JSONDecodeError`` / ``UnicodeDecodeError``).
        KeyError: if either required section is missing.
    """
    # JSON text is UTF-8 by specification (RFC 8259). Decode explicitly
    # instead of relying on the platform's locale-dependent default encoding.
    with open(cfg_file, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    cfg_sys = cfg["system"]
    cfg_db = cfg["db"]
    return cfg_sys, cfg_db
|
|
|
|
|
|
def load_nnet(cfg_sys, model):
    """Build the network wrapper from the system config and a model object.

    Args:
        cfg_sys: system-level configuration (presumably the ``"system"``
            section produced by ``load_cfg`` -- confirm at call sites).
        model: passed through to ``NetworkFactory`` unchanged.

    Returns:
        The ``NetworkFactory`` instance constructed from ``cfg_sys`` and
        ``model``.
    """
    factory = NetworkFactory(cfg_sys, model)
    return factory
|