文档检测

This commit is contained in:
2024-08-20 13:18:45 +08:00
parent e6891257b9
commit 299b762cad
59 changed files with 4966 additions and 0 deletions

View File

@@ -0,0 +1,39 @@
import json
from .nnet.py_factory import NetworkFactory
class Base(object):
def __init__(self, db, nnet, func, model=None):
super(Base, self).__init__()
self._db = db
self._nnet = nnet
self._func = func
if model is not None:
self._nnet.load_pretrained_params(model)
self._nnet.cuda()
self._nnet.eval_mode()
def _inference(self, image, *args, **kwargs):
return self._func(self._db, self._nnet, image.copy(), *args, **kwargs)
def __call__(self, image, *args, **kwargs):
categories = self._db.configs["categories"]
bboxes = self._inference(image, *args, **kwargs)
return {self._db.cls2name(j): bboxes[j] for j in range(1, categories + 1)}
def load_cfg(cfg_file):
with open(cfg_file, "r") as f:
cfg = json.load(f)
cfg_sys = cfg["system"]
cfg_db = cfg["db"]
return cfg_sys, cfg_db
def load_nnet(cfg_sys, model):
return NetworkFactory(cfg_sys, model)