优化批量处理能力

This commit is contained in:
2024-06-06 10:39:34 +08:00
parent adbf93eca1
commit 9a35071b93
4 changed files with 42 additions and 21 deletions

View File

@@ -49,6 +49,17 @@ def visual_model_test(model_type, test_img, task_path, schema):
write_visual_result(test_img, result=my_results[0])
def batch_test(test_imgs, task_path, schema):
    """Run information extraction over several images in one Taskflow call.

    Each entry of ``test_imgs`` is wrapped as ``{"doc": entry}`` and the whole
    list is handed to a single ``uie-x-base`` information-extraction Taskflow,
    so the pipeline is built once for the batch. The results are
    pretty-printed; nothing is returned.

    Args:
        test_imgs: Iterable of image references (presumably file paths or
            URLs -- confirm against the caller).
        task_path: Forwarded to ``Taskflow`` as ``task_path``.
        schema: Forwarded to ``Taskflow`` as ``schema``.
    """
    batch_inputs = [{"doc": image} for image in test_imgs]
    extractor = Taskflow(
        "information_extraction",
        schema=schema,
        model="uie-x-base",
        task_path=task_path,
        layout_analysis=True,
        batch_size=16,
    )
    # Batch call form: ie([{"doc": "./data/6.jpg"}, {"doc": "./data/7.jpg"}])
    extraction_results = extractor(batch_inputs)
    pprint(extraction_results)
def main(model_type, pic_name=None):
# 开始时间
start_time = time.time()
@@ -60,7 +71,7 @@ def main(model_type, pic_name=None):
elif model_type == "settlement":
task_path = "../config/model/settlement_list_model"
test_img_path = get_private_url(pic_name) if pic_name else "img/PH20240511000638_1_094306_1.jpg"
schema = ["姓名", "入院日期", "出院日期", "费用总额", "个人现金支付", "个人账户支付", "自费", "医保类型"]
schema = ["患者姓名", "入院日期", "出院日期", "费用总额", "个人现金支付", "个人账户支付", "自费金额", "医保类型"]
elif model_type == "discharge":
task_path = "../config/model/discharge_record_model"
test_img_path = get_private_url(pic_name) if pic_name else "img/PH20240401000003_3_001938_2.jpg"