Document detection
object_detection/evaluate.py (new file, 110 lines)
@@ -0,0 +1,110 @@
#!/usr/bin/env python
import argparse
import importlib
import json
import os
import pprint

import torch

from core.config import SystemConfig
from core.dbs import datasets
from core.nnet.py_factory import NetworkFactory
from core.test import test_func

torch.backends.cudnn.benchmark = False


def parse_args():
    parser = argparse.ArgumentParser(description="Evaluation Script")
    parser.add_argument("cfg_file", help="config file", type=str)
    parser.add_argument("--testiter", dest="testiter",
                        help="test at iteration i",
                        default=None, type=int)
    parser.add_argument("--split", dest="split",
                        help="which split to use",
                        default="validation", type=str)
    parser.add_argument("--suffix", dest="suffix", default=None, type=str)
    parser.add_argument("--debug", action="store_true")

    args = parser.parse_args()
    return args


def make_dirs(directories):
    for directory in directories:
        if not os.path.exists(directory):
            os.makedirs(directory)


def test(db, system_config, model, args):
    split    = args.split
    testiter = args.testiter
    debug    = args.debug
    suffix   = args.suffix

    result_dir = system_config.result_dir
    result_dir = os.path.join(result_dir, str(testiter), split)

    if suffix is not None:
        result_dir = os.path.join(result_dir, suffix)

    make_dirs([result_dir])

    test_iter = system_config.max_iter if testiter is None else testiter
    print("loading parameters at iteration: {}".format(test_iter))

    print("building neural network...")
    nnet = NetworkFactory(system_config, model)
    print("loading parameters...")
    nnet.load_params(test_iter)

    nnet.cuda()
    nnet.eval_mode()
    test_func(system_config, db, nnet, result_dir, debug=debug)


def main(args):
    if args.suffix is None:
        cfg_file = os.path.join("./configs", args.cfg_file + ".json")
    else:
        cfg_file = os.path.join("./configs", args.cfg_file + "-{}.json".format(args.suffix))
    print("cfg_file: {}".format(cfg_file))

    with open(cfg_file, "r") as f:
        config = json.load(f)

    config["system"]["snapshot_name"] = args.cfg_file
    system_config = SystemConfig().update_config(config["system"])

    model_file = "core.models.{}".format(args.cfg_file)
    model_file = importlib.import_module(model_file)
    model = model_file.model()

    train_split = system_config.train_split
    val_split   = system_config.val_split
    test_split  = system_config.test_split

    split = {
        "training": train_split,
        "validation": val_split,
        "testing": test_split
    }[args.split]

    print("loading all datasets...")
    dataset = system_config.dataset
    print("split: {}".format(split))
    testing_db = datasets[dataset](config["db"], split=split, sys_config=system_config)

    print("system config...")
    pprint.pprint(system_config.full)

    print("db config...")
    pprint.pprint(testing_db.configs)

    test(testing_db, system_config, model, args)


if __name__ == "__main__":
    args = parse_args()
    main(args)
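For reference (not part of the committed file): the positional cfg_file argument is used twice, once to locate the config at ./configs/<cfg_file>.json (or ./configs/<cfg_file>-<suffix>.json when --suffix is given) and once to import the model from core.models.<cfg_file>. A minimal invocation might look like the sketch below; "CornerNet" is only a placeholder config/model name, and the commands assume they are run from the directory that contains configs/.

    # Evaluate the snapshot saved at iteration 50000 on the validation split,
    # reading ./configs/CornerNet.json and importing core.models.CornerNet;
    # results are written under <result_dir>/50000/validation.
    python evaluate.py CornerNet --testiter 50000 --split validation

    # With a suffix, the config is read from ./configs/CornerNet-dcn.json and
    # an extra "dcn" level is appended to the result directory.
    python evaluate.py CornerNet --suffix dcn --testiter 50000 --split testing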
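The shape of the JSON config can likewise be read off main(): a "system" section, which is passed to SystemConfig.update_config and whose dataset, train_split / val_split / test_split, result_dir, and max_iter fields are used by this script (some may also fall back to defaults inside SystemConfig), and a "db" section, which is handed unchanged to the dataset constructor. The "dataset" value must be a key of core.dbs.datasets, and snapshot_name need not appear in the file since the script overwrites it with the config name. A sketch with placeholder values only:

    {
        "system": {
            "dataset": "COCO",
            "train_split": "trainval",
            "val_split": "minival",
            "test_split": "testdev",
            "result_dir": "results",
            "max_iter": 500000
        },
        "db": {}
    }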