From 9a35071b93e4cea59a3158d3a7c1f81f87e5f59d Mon Sep 17 00:00:00 2001 From: liuyebo <1515783401@qq.com> Date: Thu, 6 Jun 2024 10:39:34 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=89=B9=E9=87=8F=E5=A4=84?= =?UTF-8?q?=E7=90=86=E8=83=BD=E5=8A=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/photo_review.py | 3 ++ photo_review/photo_review.py | 45 +++++++++++++++----------- photo_review/util/ucloud.py | 2 +- visual_model_test/visual_model_test.py | 13 +++++++- 4 files changed, 42 insertions(+), 21 deletions(-) diff --git a/config/photo_review.py b/config/photo_review.py index 08dc4aa..02d542a 100644 --- a/config/photo_review.py +++ b/config/photo_review.py @@ -9,3 +9,6 @@ SLEEP_MINUTES = 5 # 是否发送报错邮件 SEND_ERROR_EMAIL = True + +# 信息抽取批量处理大小 +IE_BATCH_SIZE = 32 diff --git a/photo_review/photo_review.py b/photo_review/photo_review.py index 890ee25..a2710ba 100644 --- a/photo_review/photo_review.py +++ b/photo_review/photo_review.py @@ -8,7 +8,7 @@ from sqlalchemy import update from config.keys import PATIENT_NAME, ADMISSION_DATE, DISCHARGE_DATE, MEDICAL_EXPENSES, PERSONAL_CASH_PAYMENT, \ PERSONAL_ACCOUNT_PAYMENT, PERSONAL_FUNDED_AMOUNT, MEDICAL_INSURANCE_TYPE, HOSPITAL, DEPARTMENT, DOCTOR from config.mysql import MysqlSession -from config.photo_review import PHHD_BATCH_SIZE, SLEEP_MINUTES +from config.photo_review import PHHD_BATCH_SIZE, SLEEP_MINUTES, IE_BATCH_SIZE from photo_review.entity.bd_yljg import BdYljg from photo_review.entity.bd_ylks import BdYlks from photo_review.entity.zx_ie_cost import ZxIeCost @@ -24,29 +24,36 @@ from photo_review.util.util import get_default_datetime # 关键信息提取 def information_extraction(schema, phrecs, task_path): - results = {} + result = {} + docs = [] + doc_phrecs = [] for phrec in phrecs: pic_path = get_private_url(phrec.cfjaddress) if pic_path: - ie = Taskflow("information_extraction", schema=schema, model="uie-x-base", task_path=task_path, - layout_analysis=True) - # 批量抽取写法:(ie([{"doc": "./data/6.jpg"}, {"doc": "./data/7.jpg"}]) - result = ie({"doc": pic_path}) + docs.append({"doc": pic_path}) + doc_phrecs.append(phrec) - result_json = json.dumps(result, ensure_ascii=False) - if len(result_json) > 5000: - result_json = result_json[:5000] - # 提取完保存每张图片的结果 - now = get_default_datetime() - session = MysqlSession() - zx_ocr = ZxOcr(pk_phhd=phrec.pk_phhd, pk_phrec=phrec.pk_phrec, cfjaddress=phrec.cfjaddress, - content=result_json, create_time=now, update_time=now) - session.add(zx_ocr) - session.commit() - session.close() + ie = Taskflow("information_extraction", schema=schema, model="uie-x-base", task_path=task_path, + layout_analysis=True, batch_size=IE_BATCH_SIZE) + ie_results = ie(docs) - results.update(result[0]) - return results + now = get_default_datetime() + for i in range(len(ie_results)): + ie_result = ie_results[i] + phrec = doc_phrecs[i] + result_json = json.dumps(ie_result, ensure_ascii=False) + if len(result_json) > 5000: + result_json = result_json[:5000] + session = MysqlSession() + zx_ocr = ZxOcr(pk_phhd=phrec.pk_phhd, pk_phrec=phrec.pk_phrec, cfjaddress=phrec.cfjaddress, + content=result_json, create_time=now, update_time=now) + session.add(zx_ocr) + session.commit() + session.close() + + result.update(ie_result) + + return result # 从keys中获取准确率最高的value diff --git a/photo_review/util/ucloud.py b/photo_review/util/ucloud.py index 53d3420..e568ccd 100644 --- a/photo_review/util/ucloud.py +++ b/photo_review/util/ucloud.py @@ -23,5 +23,5 @@ def get_private_url(key): # url = get_ufile_handler.public_download_url(bucket, key) # 获取私有空间下载url, expires为下载链接有效期,单位为秒 - url = get_ufile_handler.private_download_url(bucket, key, expires=300) + url = get_ufile_handler.private_download_url(bucket, key, expires=3600) return url diff --git a/visual_model_test/visual_model_test.py b/visual_model_test/visual_model_test.py index 16253f6..9b5db4c 100644 --- a/visual_model_test/visual_model_test.py +++ b/visual_model_test/visual_model_test.py @@ -49,6 +49,17 @@ def visual_model_test(model_type, test_img, task_path, schema): write_visual_result(test_img, result=my_results[0]) +def batch_test(test_imgs, task_path, schema): + docs = [] + for test_img in test_imgs: + docs.append({"doc": test_img}) + my_ie = Taskflow("information_extraction", schema=schema, model="uie-x-base", task_path=task_path, + layout_analysis=True, batch_size=16) + # 批量抽取写法:(ie([{"doc": "./data/6.jpg"}, {"doc": "./data/7.jpg"}]) + my_results = my_ie(docs) + pprint(my_results) + + def main(model_type, pic_name=None): # 开始时间 start_time = time.time() @@ -60,7 +71,7 @@ def main(model_type, pic_name=None): elif model_type == "settlement": task_path = "../config/model/settlement_list_model" test_img_path = get_private_url(pic_name) if pic_name else "img/PH20240511000638_1_094306_1.jpg" - schema = ["姓名", "入院日期", "出院日期", "费用总额", "个人现金支付", "个人账户支付", "自费", "医保类型"] + schema = ["患者姓名", "入院日期", "出院日期", "费用总额", "个人现金支付", "个人账户支付", "自费金额", "医保类型"] elif model_type == "discharge": task_path = "../config/model/discharge_record_model" test_img_path = get_private_url(pic_name) if pic_name else "img/PH20240401000003_3_001938_2.jpg"