diff --git a/test/visual_model_test/visual_model_test.py b/test/visual_model_test/visual_model_test.py index 91aaabc..86de392 100644 --- a/test/visual_model_test/visual_model_test.py +++ b/test/visual_model_test/visual_model_test.py @@ -9,8 +9,8 @@ from paddlenlp.utils.doc_parser import DocParser from photo_review.util.ucloud import get_private_url -def visual_model_test(task_path, test_img, schema): - img = re.split(r'[\\/]', test_img)[-1] +def write_visual_result(image, layout=None, result=None): + img = re.split(r'[\\/]', image)[-1] img = img.split("?")[0] img_name = "" img_type = "jpg" @@ -19,30 +19,42 @@ def visual_model_test(task_path, test_img, schema): img_name = img[:last_dot_index] img_type = img[last_dot_index + 1:] - # 默认模型 - ie = Taskflow("information_extraction", schema=schema, model="uie-x-base") - results = ie({"doc": test_img}) - pprint(results[0]) - DocParser.write_image_with_results( - test_img, - result=results[0], - save_path="./img_result/" + img_name + "_default." + img_type) - # 自己训练的模型 - my_ie = Taskflow("information_extraction", schema=schema, model="uie-x-base", task_path=task_path) - my_results = my_ie({"doc": test_img}) - pprint(my_results[0]) - DocParser.write_image_with_results( - test_img, - result=my_results[0], - save_path="./img_result/" + img_name + "_my." + img_type) + if layout: + print(layout) + DocParser.write_image_with_results( + image, + layout=layout, + save_path="./img_result/" + img_name + "_layout." + img_type) + + if result: + print(result) + DocParser.write_image_with_results( + image, + result=result, + save_path="./img_result/" + img_name + "_result." + img_type) + + +def visual_model_test(model_type, test_img, task_path, schema): + if model_type == "ocr": + doc_parser = DocParser(layout_analysis=True) + parsed_doc = doc_parser.parse({"doc": test_img}) + write_visual_result(test_img, layout=parsed_doc["layout"]) + else: + my_ie = Taskflow("information_extraction", schema=schema, model="uie-x-base", task_path=task_path, + layout_analysis=True) + my_results = my_ie({"doc": test_img}) + write_visual_result(test_img, result=my_results[0]) def main(model_type, pic_name=None): # 开始时间 start_time = time.time() - # 结算清单 - if model_type == "settlement": + if model_type == "ocr": + task_path = None + test_img_path = get_private_url(pic_name) if pic_name else "img/PH20240511000638_1_094306_1.jpg" + schema = None + elif model_type == "settlement": task_path = "../../config/model/settlement_list_model" test_img_path = get_private_url(pic_name) if pic_name else "img/PH20240511000638_1_094306_1.jpg" schema = ["姓名", "入院日期", "出院日期", "费用总额", "个人现金支付", "个人账户支付", "自费", "医保类型"] @@ -54,10 +66,14 @@ def main(model_type, pic_name=None): task_path = "../../config/model/cost_list_model" test_img_path = get_private_url(pic_name) if pic_name else "img/PH20240511000648_4_094542_2.jpg" schema = ["姓名", "入院日期", "出院日期", "费用总额"] + elif model_type == "cost_detail": + task_path = "../../config/model/cost_list_detail_model" + test_img_path = get_private_url(pic_name) if pic_name else "img/PH20240511000648_4_094542_2.jpg" + schema = {"名称": ["类别", "规格", "单价", "数量", "金额"]} else: print("请输入正确的类型!") return - visual_model_test(task_path, test_img_path, schema) + visual_model_test(model_type, test_img_path, task_path, schema) # 结束时间 end_time = time.time() @@ -65,6 +81,8 @@ def main(model_type, pic_name=None): if __name__ == '__main__': - main("settlement") + main("ocr") + # main("settlement") # main("discharge") # main("cost") + # main("cost_detail")